Compare commits

..

25 Commits

Author SHA1 Message Date
Wenxi
959cf444f8 fix: set event hook for wrapping values into SensitiveValue (#9177) 2026-03-09 17:37:33 +00:00
Wenxi
2ebccea6d6 fix: move available context tokens to useChatController and remove arbitrary 50% cap (#9174) 2026-03-09 16:32:28 +00:00
Wenxi
5fe7a474db chore: update decryption utility (#9176) 2026-03-09 16:32:14 +00:00
Wenxi
9d7dc3da21 fix: ph ssl upgrade on redirect for local development (#9175) 2026-03-08 23:35:59 +00:00
Wenxi
2899be4c5e fix: remove unnecessary multitenant check in migration (#9172)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-08 20:53:11 +00:00
Nikolas Garza
64ee7fc23f fix(fe): fix broken slack bot admin pages (#9168) 2026-03-08 20:11:17 +00:00
Justin Tahara
e07764285d chore(llm): Adding Integration test for Model state cache 2/2 (#9142) 2026-03-08 19:07:11 +00:00
Justin Tahara
cc2e6ffa8a fix(user files): Add frontend precheck for oversized user uploads 3/3 (#9159) 2026-03-08 18:47:25 +00:00
Justin Tahara
d3ee5c9b59 fix(user files): Enforce user upload file size limit in projects/chat upload path 2/3 (#9158) 2026-03-08 17:42:44 +00:00
Justin Tahara
dfa0efc093 fix(user files): Add configurable user file max upload size setting 1/3 (#9157) 2026-03-08 17:01:55 +00:00
Danelegend
9aad4077f1 feat: Tool call arg streaming (#9095) 2026-03-07 09:02:39 +00:00
Wenxi
29d9ebf7b3 feat: rotate encryption key utility (#9162) 2026-03-07 06:17:21 +00:00
Jamison Lahman
f1df36e306 feat(cli): package as docker image (#9167)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-07 03:18:47 +00:00
Wenxi
1611604269 chore(tests): add shared enable_ee fixture and test README (#9165) 2026-03-07 01:55:38 +00:00
Danelegend
c2a71091dc feat: jsonriver implementation w/ delta (#9161) 2026-03-07 00:23:24 +00:00
Jamison Lahman
cc008699e5 fix(a11y): InputSelect supports keyboard navigation (#9160)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-07 00:06:36 +00:00
Jamison Lahman
48802618db fix(fe): fix API Key Role dropdown options (#9154) 2026-03-06 22:13:52 +00:00
Justin Tahara
6917953b86 chore(projects): Turn off DR in Projects (#9150) 2026-03-06 22:08:14 +00:00
Jamison Lahman
e7cf027f8a chore(zizmor): fix rust-toolchain commit (#9153) 2026-03-06 21:53:57 +00:00
roshan
41fb1480bb docs(cli): improve onyx-cli SKILL.md and fix README default server URL (#9152) 2026-03-06 21:47:18 +00:00
Raunak Bhagat
bdc2bfdcee fix(fe): account for wrapper padding in textarea auto-resize (#9151) 2026-03-06 21:30:25 +00:00
Evan Lohn
8816d52b27 fix: vespa filter restrictions (#9138) 2026-03-06 21:08:07 +00:00
roshan
6590f1d7ba feat(cli): add PyPI and release workflow badges to README (#9148) 2026-03-06 21:01:42 +00:00
roshan
c527f75557 fix(ci): release workflow and ods build file improvements (#9149)
Co-authored-by: Claude Sonnet 4.6 <noreply@anthropic.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-06 21:00:36 +00:00
Jamison Lahman
472d1788a7 fix(fe): add horizontal padding to chat page (#9147) 2026-03-06 20:46:56 +00:00
112 changed files with 5294 additions and 2230 deletions

View File

@@ -106,13 +106,34 @@ onyx-cli ask --json "What authentication methods do we support?"
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
Each line is a JSON object with this envelope:
```json
{"type": "<event_type>", "event": { ... }}
```
| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation with `citation_number` and `document_id` |
| `citation_info` | Source citation — see shape below |
`citation_info` event shape:
```json
{
"type": "citation_info",
"event": {
"citation_number": 1,
"document_id": "abc123def456",
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
}
}
```
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
### Specify an agent
@@ -129,6 +150,10 @@ Uses a specific Onyx agent/persona instead of the default.
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## Statelessness
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use
Use `onyx-cli ask` when:

View File

@@ -57,7 +57,7 @@ jobs:
cache-dependency-path: ./desktop/package-lock.json
- name: Setup Rust
uses: dtolnay/rust-toolchain@4be9e76fd7c4901c61fb841f559994984270fce7
uses: dtolnay/rust-toolchain@efa25f7f19611383d5b0ccf2d1c8914531636bf9
with:
toolchain: stable
targets: ${{ matrix.target }}

View File

@@ -26,8 +26,7 @@ jobs:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false
version: "0.9.9"
@@ -38,3 +37,178 @@ jobs:
working-directory: cli
- run: uv publish
working-directory: cli
docker-amd64:
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-cli-amd64
- extras=ecr-cache
environment: deploy
permissions:
id-token: write
timeout-minutes: 30
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push AMD64
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
with:
context: ./cli
file: ./cli/Dockerfile
platforms: linux/amd64
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
docker-arm64:
runs-on:
- runs-on
- runner=2cpu-linux-arm64
- run-id=${{ github.run_id }}-cli-arm64
- extras=ecr-cache
environment: deploy
permissions:
id-token: write
timeout-minutes: 30
outputs:
digest: ${{ steps.build.outputs.digest }}
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Checkout
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Build and push ARM64
id: build
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
with:
context: ./cli
file: ./cli/Dockerfile
platforms: linux/arm64
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
cache-to: type=inline
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
merge-docker:
needs:
- docker-amd64
- docker-arm64
runs-on:
- runs-on
- runner=2cpu-linux-x64
- run-id=${{ github.run_id }}-cli-merge
environment: deploy
permissions:
id-token: write
timeout-minutes: 10
env:
REGISTRY_IMAGE: onyxdotapp/onyx-cli
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
- name: Configure AWS credentials
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
with:
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
aws-region: us-east-2
- name: Get AWS Secrets
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
with:
secret-ids: |
DOCKER_USERNAME, deploy/docker-username
DOCKER_TOKEN, deploy/docker-token
parse-json-secrets: true
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
- name: Login to Docker Hub
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
with:
username: ${{ env.DOCKER_USERNAME }}
password: ${{ env.DOCKER_TOKEN }}
- name: Create and push manifest
env:
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
TAG: ${{ github.ref_name }}
run: |
SANITIZED_TAG="${TAG#cli/}"
IMAGES=(
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
)
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
docker buildx imagetools create \
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
-t "${REGISTRY_IMAGE}:latest" \
"${IMAGES[@]}"
else
docker buildx imagetools create \
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
"${IMAGES[@]}"
fi

View File

@@ -22,12 +22,10 @@ jobs:
- { goos: "windows", goarch: "arm64" }
- { goos: "darwin", goarch: "amd64" }
- { goos: "darwin", goarch: "arm64" }
- { goos: "", goarch: "" }
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
with:
persist-credentials: false
fetch-depth: 0
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
with:
enable-cache: false

View File

@@ -544,6 +544,8 @@ To run them:
npx playwright test <TEST_NAME>
```
For shared fixtures, best practices, and detailed guidance, see `backend/tests/README.md`.
## Logs
When (1) writing integration tests or (2) doing live tests (e.g. curl / playwright) you can get access

View File

@@ -141,6 +141,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
# Run Craft template setup at build time when ENABLE_CRAFT=true

View File

@@ -1,43 +0,0 @@
"""add timestamps to user table
Revision ID: 27fb147a843f
Revises: a3b8d9e2f1c4
Create Date: 2026-03-08 17:18:40.828644
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "27fb147a843f"
down_revision = "a3b8d9e2f1c4"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Add the ``created_at`` and ``updated_at`` columns to the ``user`` table.

    Both columns are timezone-aware datetimes, non-nullable, and use
    ``now()`` as their server-side default.
    """
    # created_at is added first, then updated_at; both share one definition.
    for column_name in ("created_at", "updated_at"):
        op.add_column(
            "user",
            sa.Column(
                column_name,
                sa.DateTime(timezone=True),
                server_default=sa.func.now(),
                nullable=False,
            ),
        )
def downgrade() -> None:
    """Drop the timestamp columns from ``user`` in reverse order of creation."""
    for column_name in ("updated_at", "created_at"):
        op.drop_column("user", column_name)

View File

@@ -11,7 +11,6 @@ from sqlalchemy import text
from alembic import op
from onyx.configs.app_configs import DB_READONLY_PASSWORD
from onyx.configs.app_configs import DB_READONLY_USER
from shared_configs.configs import MULTI_TENANT
# revision identifiers, used by Alembic.
@@ -22,59 +21,52 @@ depends_on = None
def upgrade() -> None:
if MULTI_TENANT:
# Enable pg_trgm extension if not already enabled
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
# Enable pg_trgm extension if not already enabled
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
# Create the read-only db user if it does not already exist.
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
# the user is created in the standard migration.
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
op.execute(
text(
f"""
DO $$
BEGIN
-- Check if the read-only user already exists
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- Create the read-only user with the specified password
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
-- First revoke all privileges to ensure a clean slate
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Grant only the CONNECT privilege to allow the user to connect to the database
-- but not perform any operations without additional specific grants
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)
def downgrade() -> None:
if MULTI_TENANT:
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
# the user is dropped in the alembic_tenants migration.
op.execute(
text(
f"""
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- First revoke all privileges from the database
-- Check if the read-only user already exists
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- Create the read-only user with the specified password
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
-- First revoke all privileges to ensure a clean slate
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Then revoke all privileges from the public schema
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
-- Then drop the user
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
-- Grant only the CONNECT privilege to allow the user to connect to the database
-- but not perform any operations without additional specific grants
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
"""
)
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
)
def downgrade() -> None:
op.execute(
text(
f"""
DO $$
BEGIN
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
-- First revoke all privileges from the database
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
-- Then revoke all privileges from the public schema
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
-- Then drop the user
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
END IF;
END
$$;
"""
)
)
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))

View File

@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
@lru_cache(maxsize=1)
@lru_cache(maxsize=2)
def _get_trimmed_key(key: str) -> bytes:
encoded_key = key.encode()
key_length = len(encoded_key)
if key_length < 16:
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
elif key_length > 32:
key = key[:32]
elif key_length not in (16, 24, 32):
valid_lengths = [16, 24, 32]
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
return encoded_key
# Trim to the largest valid AES key size that fits
valid_lengths = [32, 24, 16]
for size in valid_lengths:
if key_length >= size:
return encoded_key[:size]
raise AssertionError("unreachable")
def _encrypt_string(input_str: str) -> bytes:
if not ENCRYPTION_KEY_SECRET:
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
return input_str.encode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
trimmed = _get_trimmed_key(effective_key)
iv = urandom(16)
padder = padding.PKCS7(algorithms.AES.block_size).padder()
padded_data = padder.update(input_str.encode()) + padder.finalize()
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
encryptor = cipher.encryptor()
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
return iv + encrypted_data
def _decrypt_bytes(input_bytes: bytes) -> str:
if not ENCRYPTION_KEY_SECRET:
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
if not effective_key:
return input_bytes.decode()
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
trimmed = _get_trimmed_key(effective_key)
try:
iv = input_bytes[:16]
encrypted_data = input_bytes[16:]
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
cipher = Cipher(
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
)
decryptor = cipher.decryptor()
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
return decrypted_data.decode()
return decrypted_data.decode()
except (ValueError, UnicodeDecodeError):
if key is not None:
# Explicit key was provided — don't fall back silently
raise
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
# Does NOT handle data encrypted with a different key — that
# ciphertext is not valid UTF-8 and will raise below.
logger.warning(
"AES decryption failed — falling back to raw decode. "
"Run the re-encrypt secrets script to rotate to the current key."
)
try:
return input_bytes.decode()
except UnicodeDecodeError:
raise ValueError(
"Data is not valid UTF-8 — likely encrypted with a different key. "
"Run the re-encrypt secrets script to rotate to the current key."
) from None
def encrypt_string_to_bytes(input_str: str) -> bytes:
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(input_str)
return versioned_encryption_fn(input_str, key=key)
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(input_bytes)
return versioned_decryption_fn(input_bytes, key=key)
def test_encryption() -> None:

View File

@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
from onyx.chat.emitter import Emitter
from onyx.chat.models import ChatMessageSimple
from onyx.chat.models import LlmStepResult
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
from onyx.configs.constants import MessageType
@@ -54,6 +55,7 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.tools.models import ToolCallKickoff
from onyx.tracing.framework.create import generation_span
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.jsonriver import Parser
from onyx.utils.logger import setup_logger
from onyx.utils.postgres_sanitization import sanitize_string
from onyx.utils.text_processing import find_all_json_objects
@@ -1009,6 +1011,7 @@ def run_llm_step_pkt_generator(
)
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
arg_parsers: dict[int, Parser] = {}
reasoning_start = False
answer_start = False
accumulated_reasoning = ""
@@ -1215,7 +1218,14 @@ def run_llm_step_pkt_generator(
yield from _close_reasoning_if_active()
for tool_call_delta in delta.tool_calls:
# maybe_emit_argument_delta depends on _update_tool_call_with_delta being called first so the delta is already attached
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
yield from maybe_emit_argument_delta(
tool_calls_in_progress=id_to_tool_call_map,
tool_call_delta=tool_call_delta,
placement=_current_placement(),
parsers=arg_parsers,
)
# Flush any tail text buffered while checking for split "<function_calls" markers.
filtered_content_tail = xml_tool_call_content_filter.flush()

View File

@@ -0,0 +1,77 @@
from collections.abc import Generator
from collections.abc import Mapping
from typing import Any
from typing import Type
from onyx.llm.model_response import ChatCompletionDeltaToolCall
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
from onyx.tools.interface import Tool
from onyx.utils.jsonriver import Parser
def _get_tool_class(
    tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
    tool_call_delta: ChatCompletionDeltaToolCall,
) -> Type[Tool] | None:
    """Resolve the Tool subclass targeted by a streaming tool-call delta.

    The tool name is read from the in-progress entry stored under the
    delta's ``index``. Returns ``None`` when that entry is missing or its
    ``name`` is absent/empty; otherwise returns whatever
    ``TOOL_NAME_TO_CLASS.get`` yields for that name.
    """
    in_progress_call = tool_calls_in_progress.get(tool_call_delta.index, {})
    resolved_name = in_progress_call.get("name")
    return TOOL_NAME_TO_CLASS.get(resolved_name) if resolved_name else None
def maybe_emit_argument_delta(
    tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
    tool_call_delta: ChatCompletionDeltaToolCall,
    placement: Placement,
    parsers: dict[int, Parser],
) -> Generator[Packet, None, None]:
    """Yield a packet carrying newly streamed string argument content.

    The argument fragment on ``tool_call_delta`` is fed to the
    ``jsonriver.Parser`` stored in ``parsers`` under the delta's index (one
    is created on first use), and the string values found in the resulting
    dict deltas are concatenated per argument name.

    Nothing is yielded when:
      * no tool class is resolved for the delta, or that class opts out via
        ``should_emit_argument_deltas()``;
      * the delta has no function payload or an empty ``arguments`` fragment;
      * the parser output contains no string values.

    NOTE: Non-string values (numbers, booleans, null, arrays, objects) are
    skipped here.

    ``parsers`` is mutated in place: a new ``Parser`` is stored for any
    tool-call index that is not yet present.
    """
    tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
    if not (tool_cls and tool_cls.should_emit_argument_deltas()):
        return

    function_delta = tool_call_delta.function
    if not function_delta:
        return
    raw_fragment = function_delta.arguments
    if not raw_fragment:
        return

    call_index = tool_call_delta.index
    if call_index not in parsers:
        parsers[call_index] = Parser()

    # Collect only string-valued entries, merging repeated keys in order.
    argument_deltas: dict[str, str] = {}
    for parsed_delta in parsers[call_index].feed(raw_fragment):
        if not isinstance(parsed_delta, dict):
            continue
        for arg_name, fragment in parsed_delta.items():
            if isinstance(fragment, str):
                argument_deltas[arg_name] = argument_deltas.get(arg_name, "") + fragment

    if not argument_deltas:
        return

    in_progress_call = tool_calls_in_progress[call_index]
    argument_packet = ToolCallArgumentDelta(
        tool_type=in_progress_call.get("name", ""),
        argument_deltas=argument_deltas,
    )
    yield Packet(placement=placement, obj=argument_packet)

View File

@@ -68,6 +68,10 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
)
# Maximum upload size for a single user file (chat/projects) in MB.
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
# If set to true, will show extra/uncommon connectors in the "Other" category
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"

View File

@@ -36,9 +36,11 @@ from sqlalchemy import Text
from sqlalchemy import text
from sqlalchemy import UniqueConstraint
from sqlalchemy.dialects import postgresql
from sqlalchemy import event
from sqlalchemy.engine.interfaces import Dialect
from sqlalchemy.orm import DeclarativeBase
from sqlalchemy.orm import Mapped
from sqlalchemy.orm import Mapper
from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from sqlalchemy.types import LargeBinary
@@ -117,10 +119,50 @@ class Base(DeclarativeBase):
__abstract__ = True
class EncryptedString(TypeDecorator):
class _EncryptedBase(TypeDecorator):
    """Shared base for encrypted column types whose Python values are SensitiveValue."""

    impl = LargeBinary
    # This type's behavior is fully deterministic and doesn't depend on any external factors.
    cache_ok = True
    # Subclasses set this to True when the column stores JSON rather than a plain string.
    _is_json: bool = False

    def wrap_raw(self, value: Any) -> SensitiveValue:
        """Encrypt a raw value and return it wrapped in a SensitiveValue.

        JSON columns (``_is_json`` True) accept only ``dict`` and are
        serialized with ``json.dumps`` first; string columns accept only
        ``str``.

        Raises:
            TypeError: if ``value`` is not of the type the column expects.
        """
        if self._is_json:
            if isinstance(value, dict):
                raw_str = json.dumps(value)
            else:
                raise TypeError(
                    f"EncryptedJson column expected dict, got {type(value).__name__}"
                )
        elif isinstance(value, str):
            raw_str = value
        else:
            raise TypeError(
                f"EncryptedString column expected str, got {type(value).__name__}"
            )

        ciphertext = encrypt_string_to_bytes(raw_str)
        return SensitiveValue(
            encrypted_bytes=ciphertext,
            decrypt_fn=decrypt_bytes_to_string,
            is_json=self._is_json,
        )

    def compare_values(self, x: Any, y: Any) -> bool:
        """Compare two column values, unwrapping SensitiveValue on either side.

        If either operand is ``None`` the operands are compared directly;
        otherwise any SensitiveValue is replaced by its unmasked value
        before the equality check.
        """
        if x is not None and y is not None:
            left = x.get_value(apply_mask=False) if isinstance(x, SensitiveValue) else x
            right = (
                y.get_value(apply_mask=False) if isinstance(y, SensitiveValue) else y
            )
            return left == right
        return x == y
class EncryptedString(_EncryptedBase):
_is_json: bool = False
def process_bind_param(
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
@@ -144,20 +186,9 @@ class EncryptedString(TypeDecorator):
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class EncryptedJson(TypeDecorator):
impl = LargeBinary
# This type's behavior is fully deterministic and doesn't depend on any external factors.
cache_ok = True
class EncryptedJson(_EncryptedBase):
_is_json: bool = True
def process_bind_param(
self,
@@ -165,9 +196,7 @@ class EncryptedJson(TypeDecorator):
dialect: Dialect, # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw dicts and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
json_str = json.dumps(value)
return encrypt_string_to_bytes(json_str)
@@ -184,14 +213,40 @@ class EncryptedJson(TypeDecorator):
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
_REGISTERED_ATTRS: set[str] = set()
@event.listens_for(Mapper, "mapper_configured")
def _register_sensitive_value_set_events(
    mapper: Mapper,
    class_: type,
) -> None:
    """Attach a "set" listener to every encrypted column attribute of ``class_``.

    The listener replaces any assigned value that is neither ``None`` nor
    already a SensitiveValue with the result of the column type's
    ``wrap_raw``. Each ``<qualname>.<attribute>`` pair is registered at most
    once, tracked through ``_REGISTERED_ATTRS``.
    """
    for column_prop in mapper.column_attrs:
        for column in column_prop.columns:
            encrypted_type = column.type
            if not isinstance(encrypted_type, _EncryptedBase):
                continue

            mapped_attr = getattr(class_, column_prop.key)

            # Guard against double-registration (e.g. if mapper is
            # re-configured in test setups)
            registration_key = f"{class_.__qualname__}.{column_prop.key}"
            if registration_key in _REGISTERED_ATTRS:
                continue
            _REGISTERED_ATTRS.add(registration_key)

            # The column type is bound through a default argument so each
            # listener keeps its own type instead of the loop's last value.
            @event.listens_for(mapped_attr, "set", retval=True)
            def _wrap_value(
                target: Any,  # noqa: ARG001
                value: Any,
                oldvalue: Any,  # noqa: ARG001
                initiator: Any,  # noqa: ARG001
                _col_type: _EncryptedBase = encrypted_type,
            ) -> Any:
                if value is None or isinstance(value, SensitiveValue):
                    return value
                return _col_type.wrap_raw(value)
class NullFilteredString(TypeDecorator):
@@ -280,16 +335,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
TIMESTAMPAware(timezone=True), nullable=True
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
updated_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
onupdate=func.now(),
nullable=False,
)
default_model: Mapped[str] = mapped_column(Text, nullable=True)
# organized in typical structured fashion
# formatted as `displayName__provider__modelName`

View File

@@ -0,0 +1,161 @@
"""Rotate encryption key for all encrypted columns.
Dynamically discovers all columns using EncryptedString / EncryptedJson,
decrypts each value with the old key, and re-encrypts with the current
ENCRYPTION_KEY_SECRET.
The operation is idempotent: rows already encrypted with the current key
are skipped. Commits are made in batches so a crash mid-rotation can be
safely resumed by re-running.
"""
import json
from typing import Any
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.db.models import Base
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import global_version
logger = setup_logger()
_BATCH_SIZE = 500
def _can_decrypt_with_current_key(data: bytes) -> bool:
    """Report whether ``data`` decrypts under the current ENCRYPTION_KEY_SECRET.

    The key is passed explicitly to ``decrypt_bytes_to_string`` so that a
    decryption failure surfaces as an exception rather than being masked by
    the raw-decode fallback used when no key argument is given. Any
    exception is treated as "cannot decrypt".
    """
    try:
        decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
    except Exception:
        return False
    else:
        return True
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
    """Collect every mapped column typed as EncryptedString or EncryptedJson.

    Iterates all mappers registered on ``Base`` and inspects each column
    attribute's underlying columns.

    Returns:
        A list of (ModelClass, column_attr_name, [pk_attr_names], is_json)
        tuples, one per matching column.
    """
    discovered: list[tuple[type, str, list[str], bool]] = []
    for mapper in Base.registry.mappers:
        pk_names = [pk_col.key for pk_col in mapper.primary_key]
        for prop in mapper.column_attrs:
            for column in prop.columns:
                column_type = column.type
                # EncryptedJson is tested before EncryptedString, matching
                # the original check order.
                if isinstance(column_type, EncryptedJson):
                    is_json = True
                elif isinstance(column_type, EncryptedString):
                    is_json = False
                else:
                    continue
                discovered.append((mapper.class_, prop.key, pk_names, is_json))
    return discovered
def rotate_encryption_key(
db_session: Session,
old_key: str | None,
dry_run: bool = False,
) -> dict[str, int]:
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
Args:
db_session: Active database session.
old_key: The previous encryption key. Pass None or "" if values were
not previously encrypted with a key.
dry_run: If True, count rows that need rotation without modifying data.
Returns:
Dict of "table.column" -> number of rows re-encrypted (or would be).
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
is preserved on crash. Already-rotated rows are detected and skipped,
making the operation safe to re-run.
"""
# Preconditions: rotation is refused unless the EE build is active and a
# target key (ENCRYPTION_KEY_SECRET) is configured.
if not global_version.is_ee_version():
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
if not ENCRYPTION_KEY_SECRET:
raise RuntimeError(
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
"Set the target encryption key in the environment before running."
)
encrypted_columns = _discover_encrypted_columns()
# Maps "table.column" -> count; only columns with at least one hit are added.
totals: dict[str, int] = {}
for model_cls, col_name, pk_names, is_json in encrypted_columns:
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
col_attr = getattr(model_cls, col_name)
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
# Read raw bytes directly, bypassing the TypeDecorator
raw_col = col_attr.property.columns[0]
# Select the primary-key attributes followed by the raw value, so the raw
# bytes are always the last element of each row.
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
# Rows are fetched eagerly with .all() before any update/commit is issued.
rows = db_session.execute(stmt).all()
# reencrypted: rows rotated (or that would be, in a dry run) for this column.
# batch_pending: updates executed since the last commit.
reencrypted = 0
batch_pending = 0
for row in rows:
raw_bytes: bytes | None = row[-1]
if raw_bytes is None:
continue
# Idempotency check: values that already decrypt with the current key
# are left untouched.
if _can_decrypt_with_current_key(raw_bytes):
continue
try:
# With no old key, the stored bytes are treated as plain UTF-8 text.
if not old_key:
decrypted_str = raw_bytes.decode("utf-8")
else:
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
# For EncryptedJson, parse back to dict so the TypeDecorator
# can json.dumps() it cleanly (avoids double-encoding).
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
except (ValueError, UnicodeDecodeError) as e:
# Rows that cannot be decrypted or parsed are logged with their
# primary-key values and skipped; they do not abort the run.
pk_vals = [row[i] for i in range(len(pk_names))]
logger.warning(
f"Could not decrypt/parse {table_name}.{col_name} "
f"row {pk_vals} — skipping: {e}"
)
continue
if not dry_run:
# The update writes through the mapped attribute name; presumably the
# column's TypeDecorator re-encrypts with the current key on bind —
# TODO confirm against EncryptedString / EncryptedJson.
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
update_stmt = (
update(model_cls).where(*pk_filters).values({col_name: value})
)
db_session.execute(update_stmt)
batch_pending += 1
if batch_pending >= _BATCH_SIZE:
db_session.commit()
batch_pending = 0
# Per the docstring, dry runs are counted here as well ("or would be").
reencrypted += 1
# Flush remaining rows in this column
if batch_pending > 0:
db_session.commit()
if reencrypted > 0:
totals[f"{table_name}.{col_name}"] = reencrypted
logger.info(
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
f"{reencrypted} value(s) in {table_name}.{col_name}"
)
return totals

View File

@@ -11,7 +11,6 @@ from sqlalchemy.orm import Session
from sqlalchemy.sql import expression
from sqlalchemy.sql.elements import ColumnElement
from sqlalchemy.sql.elements import KeyedColumnElement
from sqlalchemy.sql.expression import or_
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
@@ -25,7 +24,6 @@ from onyx.db.models import Persona__User
from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -164,13 +162,7 @@ def _get_accepted_user_where_clause(
where_clause.append(User.role != UserRole.EXT_PERM_USER)
if email_filter_string is not None:
personal_name_col: KeyedColumnElement[Any] = User.__table__.c.personal_name
where_clause.append(
or_(
email_col.ilike(f"%{email_filter_string}%"),
personal_name_col.ilike(f"%{email_filter_string}%"),
)
)
where_clause.append(email_col.ilike(f"%{email_filter_string}%"))
if roles_filter:
where_clause.append(User.role.in_(roles_filter))
@@ -366,28 +358,3 @@ def delete_user_from_db(
# NOTE: edge case may exist with race conditions
# with this `invited user` scheme generally.
remove_user_from_invited_users(user_to_delete.email)
def batch_get_user_groups(
    db_session: Session,
    user_ids: list[UUID],
) -> dict[UUID, list[tuple[int, str]]]:
    """Look up group memberships for many users with one query.

    Args:
        db_session: Session used to run the membership query.
        user_ids: Users whose groups should be fetched.

    Returns:
        Mapping of user_id -> list of (group_id, group_name) tuples. Every
        requested user id is present as a key, with an empty list when no
        membership rows were returned for it. An empty ``user_ids`` yields
        an empty dict without touching the database.
    """
    if not user_ids:
        return {}

    membership_query = (
        select(User__UserGroup.user_id, UserGroup.id, UserGroup.name)
        .join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
        .where(User__UserGroup.user_id.in_(user_ids))
    )
    membership_rows = db_session.execute(membership_query).all()

    # Pre-seed every requested user so callers never need a missing-key check.
    groups_by_user: dict[UUID, list[tuple[int, str]]] = {}
    for requested_id in user_ids:
        groups_by_user[requested_id] = []

    for membership in membership_rows:
        member_id, group_id, group_name = membership
        groups_by_user[member_id].append((group_id, group_name))
    return groups_by_user

View File

@@ -0,0 +1,103 @@
# Vector DB Filter Semantics
How `IndexFilters` fields combine into the final query filter. Applies to both Vespa and OpenSearch.
## Filter categories
| Category | Fields | Join logic |
|---|---|---|
| **Visibility** | `hidden` | Always applied (unless `include_hidden`) |
| **Tenant** | `tenant_id` | AND (multi-tenant only) |
| **ACL** | `access_control_list` | OR within, AND with rest |
| **Narrowing** | `source_type`, `tags`, `time_cutoff` | `source_type` and `tags` are each OR'd within; `time_cutoff` is a single lower bound; all are AND'd with rest |
| **Knowledge scope** | `document_set`, `user_file_ids`, `attached_document_ids`, `hierarchy_node_ids` | OR within group, AND with rest |
| **Additive scope** | `project_id`, `persona_id` | OR'd into knowledge scope **only when** a knowledge scope filter already exists |
## How filters combine
All categories are AND'd together. Within the knowledge scope category, individual filters are OR'd.
```
NOT hidden
AND tenant = T -- if multi-tenant
AND (acl contains A1 OR acl contains A2)
AND (source_type = S1 OR ...) -- if set
AND (tag = T1 OR ...) -- if set
AND <knowledge scope> -- see below
AND time >= cutoff -- if set
```
## Knowledge scope rules
The knowledge scope filter controls **what knowledge an assistant can access**.
### No explicit knowledge attached
When `document_set`, `user_file_ids`, `attached_document_ids`, and `hierarchy_node_ids` are all empty/None:
- **No knowledge scope filter is applied.** The assistant can see everything (subject to ACL).
- `project_id` and `persona_id` are ignored — they never restrict on their own.
### One explicit knowledge type
```
-- Only document sets
AND (document_sets contains "Engineering" OR document_sets contains "Legal")
-- Only user files
AND (document_id = "uuid-1" OR document_id = "uuid-2")
```
### Multiple explicit knowledge types (OR'd)
```
-- Document sets + user files
AND (
document_sets contains "Engineering"
OR document_id = "uuid-1"
)
```
### Explicit knowledge + overflowing user files
When an explicit knowledge restriction is in effect **and** `project_id` or `persona_id` is set (user files overflowed the LLM context window), the additive scopes widen the filter:
```
-- Document sets + persona user files overflowed
AND (
document_sets contains "Engineering"
OR personas contains 42
)
-- User files + project files overflowed
AND (
document_id = "uuid-1"
OR user_project contains 7
)
```
### Only project_id or persona_id (no explicit knowledge)
No knowledge scope filter. The assistant searches everything.
```
-- Just ACL, no restriction
NOT hidden
AND (acl contains ...)
```
## Field reference
| Filter field | Vespa field | Vespa type | Purpose |
|---|---|---|---|
| `document_set` | `document_sets` | `weightedset<string>` | Connector doc sets attached to assistant |
| `user_file_ids` | `document_id` | `string` | User files uploaded to assistant |
| `attached_document_ids` | `document_id` | `string` | Documents explicitly attached (OpenSearch only) |
| `hierarchy_node_ids` | `ancestor_hierarchy_node_ids` | `array<int>` | Folder/space nodes (OpenSearch only) |
| `project_id` | `user_project` | `array<int>` | Project tag for overflowing user files |
| `persona_id` | `personas` | `array<int>` | Persona tag for overflowing user files |
| `access_control_list` | `access_control_list` | `weightedset<string>` | ACL entries for the requesting user |
| `source_type` | `source_type` | `string` | Connector source type (e.g. `web`, `jira`) |
| `tags` | `metadata_list` | `array<string>` | Document metadata tags |
| `time_cutoff` | `doc_updated_at` | `long` | Minimum document update timestamp |
| `tenant_id` | `tenant_id` | `string` | Tenant isolation (multi-tenant) |

View File

@@ -698,41 +698,6 @@ class DocumentQuery:
"""
return {"terms": {ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME: node_ids}}
def _get_assistant_knowledge_filter(
attached_doc_ids: list[str] | None,
node_ids: list[int] | None,
file_ids: list[UUID] | None,
document_sets: list[str] | None,
) -> dict[str, Any]:
"""Combined filter for assistant knowledge.
When an assistant has attached knowledge, search should be scoped to:
- Documents explicitly attached (by document ID), OR
- Documents under attached hierarchy nodes (by ancestor node IDs), OR
- User-uploaded files attached to the assistant, OR
- Documents in the assistant's document sets (if any)
"""
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
if attached_doc_ids:
knowledge_filter["bool"]["should"].append(
_get_attached_document_id_filter(attached_doc_ids)
)
if node_ids:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(node_ids)
)
if file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
return knowledge_filter
filter_clauses: list[dict[str, Any]] = []
if not include_hidden:
@@ -758,41 +723,53 @@ class DocumentQuery:
# document's metadata list.
filter_clauses.append(_get_tag_filter(tags))
# Check if this is an assistant knowledge search (has any assistant-scoped knowledge)
has_assistant_knowledge = (
# Knowledge scope: explicit knowledge attachments restrict what
# an assistant can see. When none are set the assistant
# searches everything.
#
# project_id / persona_id are additive: they make overflowing
# user files findable but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
has_knowledge_scope = (
attached_document_ids
or hierarchy_node_ids
or user_file_ids
or document_sets
)
if has_assistant_knowledge:
# If assistant has attached knowledge, scope search to that knowledge.
# Document sets are included in the OR filter so directly attached
# docs are always findable even if not in the document sets.
filter_clauses.append(
_get_assistant_knowledge_filter(
attached_document_ids,
hierarchy_node_ids,
user_file_ids,
document_sets,
if has_knowledge_scope:
knowledge_filter: dict[str, Any] = {
"bool": {"should": [], "minimum_should_match": 1}
}
if attached_document_ids:
knowledge_filter["bool"]["should"].append(
_get_attached_document_id_filter(attached_document_ids)
)
)
elif user_file_ids:
# Fallback for non-assistant user file searches (e.g., project searches)
# If at least one user file ID is provided, the caller will only
# retrieve documents where the document ID is in this input list of
# file IDs.
filter_clauses.append(_get_user_file_id_filter(user_file_ids))
if project_id is not None:
# If a project ID is provided, the caller will only retrieve
# documents where the project ID provided here is present in the
# document's user projects list.
filter_clauses.append(_get_user_project_filter(project_id))
if persona_id is not None:
filter_clauses.append(_get_persona_filter(persona_id))
if hierarchy_node_ids:
knowledge_filter["bool"]["should"].append(
_get_hierarchy_node_filter(hierarchy_node_ids)
)
if user_file_ids:
knowledge_filter["bool"]["should"].append(
_get_user_file_id_filter(user_file_ids)
)
if document_sets:
knowledge_filter["bool"]["should"].append(
_get_document_set_filter(document_sets)
)
# Additive: widen scope to also cover overflowing user
# files, but only when an explicit restriction is already
# in effect.
if project_id is not None:
knowledge_filter["bool"]["should"].append(
_get_user_project_filter(project_id)
)
if persona_id is not None:
knowledge_filter["bool"]["should"].append(
_get_persona_filter(persona_id)
)
filter_clauses.append(knowledge_filter)
if time_cutoff is not None:
# If a time cutoff is provided, the caller will only retrieve

View File

@@ -23,11 +23,8 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
def build_tenant_id_filter(tenant_id: str, include_trailing_and: bool = False) -> str:
filter_str = f'({TENANT_ID} contains "{tenant_id}")'
if include_trailing_and:
filter_str += " and "
return filter_str
def build_tenant_id_filter(tenant_id: str) -> str:
return f'({TENANT_ID} contains "{tenant_id}")'
def build_vespa_filters(
@@ -37,30 +34,22 @@ def build_vespa_filters(
remove_trailing_and: bool = False, # Set to True when using as a complete Vespa query
) -> str:
def _build_or_filters(key: str, vals: list[str] | None) -> str:
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields."""
"""For string-based 'contains' filters, e.g. WSET fields or array<string> fields.
Returns a bare clause like '(key contains "v1" or key contains "v2")' or ""."""
if not key or not vals:
return ""
eq_elems = [f'{key} contains "{val}"' for val in vals if val]
if not eq_elems:
return ""
or_clause = " or ".join(eq_elems)
return f"({or_clause}) and "
return f"({' or '.join(eq_elems)})"
def _build_int_or_filters(key: str, vals: list[int] | None) -> str:
"""
For an integer field filter.
If vals is not None, we want *only* docs whose key matches one of vals.
"""
# If `vals` is None => skip the filter entirely
"""For an integer field filter.
Returns a bare clause or ""."""
if vals is None or not vals:
return ""
# Otherwise build the OR filter
eq_elems = [f"{key} = {val}" for val in vals]
or_clause = " or ".join(eq_elems)
result = f"({or_clause}) and "
return result
return f"({' or '.join(eq_elems)})"
def _build_kg_filter(
kg_entities: list[str] | None,
@@ -73,16 +62,12 @@ def build_vespa_filters(
combined_filter_parts = []
def _build_kge(entity: str) -> str:
# TYPE-SUBTYPE::ID -> "TYPE-SUBTYPE::ID"
# TYPE-SUBTYPE::* -> ({prefix: true}"TYPE-SUBTYPE")
# TYPE::* -> ({prefix: true}"TYPE")
GENERAL = "::*"
if entity.endswith(GENERAL):
return f'({{prefix: true}}"{entity.split(GENERAL, 1)[0]}")'
else:
return f'"{entity}"'
# OR the entities (give new design)
if kg_entities:
filter_parts = []
for kg_entity in kg_entities:
@@ -104,8 +89,7 @@ def build_vespa_filters(
# TODO: remove kg terms entirely from prompts and codebase
# AND the combined filter parts
return f"({' and '.join(combined_filter_parts)}) and "
return f"({' and '.join(combined_filter_parts)})"
def _build_kg_source_filters(
kg_sources: list[str] | None,
@@ -114,16 +98,14 @@ def build_vespa_filters(
return ""
source_phrases = [f'{DOCUMENT_ID} contains "{source}"' for source in kg_sources]
return f"({' or '.join(source_phrases)}) and "
return f"({' or '.join(source_phrases)})"
def _build_kg_chunk_id_zero_only_filter(
kg_chunk_id_zero_only: bool,
) -> str:
if not kg_chunk_id_zero_only:
return ""
return "(chunk_id = 0 ) and "
return "(chunk_id = 0)"
def _build_time_filter(
cutoff: datetime | None,
@@ -135,8 +117,8 @@ def build_vespa_filters(
cutoff_secs = int(cutoff.timestamp())
if include_untimed:
return f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
return f"({DOC_UPDATED_AT} >= {cutoff_secs}) and "
return f"!({DOC_UPDATED_AT} < {cutoff_secs})"
return f"({DOC_UPDATED_AT} >= {cutoff_secs})"
def _build_user_project_filter(
project_id: int | None,
@@ -147,8 +129,7 @@ def build_vespa_filters(
pid = int(project_id)
except Exception:
return ""
# Vespa YQL 'contains' expects a string literal; quote the integer
return f'({USER_PROJECT} contains "{pid}") and '
return f'({USER_PROJECT} contains "{pid}")'
def _build_persona_filter(
persona_id: int | None,
@@ -160,73 +141,94 @@ def build_vespa_filters(
except Exception:
logger.warning(f"Invalid persona ID: {persona_id}")
return ""
return f'({PERSONAS} contains "{pid}") and '
return f'({PERSONAS} contains "{pid}")'
# Start building the filter string
filter_str = f"!({HIDDEN}=true) and " if not include_hidden else ""
def _append(parts: list[str], clause: str) -> None:
if clause:
parts.append(clause)
# Collect all top-level filter clauses, then join with " and " at the end.
filter_parts: list[str] = []
if not include_hidden:
filter_parts.append(f"!({HIDDEN}=true)")
# TODO: add error condition if MULTI_TENANT and no tenant_id filter is set
# If running in multi-tenant mode
if filters.tenant_id and MULTI_TENANT:
filter_str += build_tenant_id_filter(
filters.tenant_id, include_trailing_and=True
)
filter_parts.append(build_tenant_id_filter(filters.tenant_id))
# ACL filters
if filters.access_control_list is not None:
filter_str += _build_or_filters(
ACCESS_CONTROL_LIST, filters.access_control_list
_append(
filter_parts,
_build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list),
)
# Source type filters
source_strs = (
[s.value for s in filters.source_type] if filters.source_type else None
)
filter_str += _build_or_filters(SOURCE_TYPE, source_strs)
_append(filter_parts, _build_or_filters(SOURCE_TYPE, source_strs))
# Tag filters
tag_attributes = None
if filters.tags:
# build e.g. "tag_key|tag_value"
tag_attributes = [
f"{tag.tag_key}{INDEX_SEPARATOR}{tag.tag_value}" for tag in filters.tags
]
filter_str += _build_or_filters(METADATA_LIST, tag_attributes)
_append(filter_parts, _build_or_filters(METADATA_LIST, tag_attributes))
# Document sets
filter_str += _build_or_filters(DOCUMENT_SETS, filters.document_set)
# Knowledge scope: explicit knowledge attachments (document_sets,
# user_file_ids) restrict what an assistant can see. When none are
# set, the assistant can see everything.
#
# project_id / persona_id are additive: they make overflowing user
# files findable in Vespa but must NOT trigger the restriction on
# their own (an agent with no explicit knowledge should search
# everything).
knowledge_scope_parts: list[str] = []
_append(
knowledge_scope_parts, _build_or_filters(DOCUMENT_SETS, filters.document_set)
)
# Convert UUIDs to strings for user_file_ids
user_file_ids_str = (
[str(uuid) for uuid in filters.user_file_ids] if filters.user_file_ids else None
)
filter_str += _build_or_filters(DOCUMENT_ID, user_file_ids_str)
_append(knowledge_scope_parts, _build_or_filters(DOCUMENT_ID, user_file_ids_str))
# User project filter (array<int> attribute membership)
filter_str += _build_user_project_filter(filters.project_id)
# Only include project/persona scopes when an explicit knowledge
# restriction is already in effect — they widen the scope to also
# cover overflowing user files but never restrict on their own.
if knowledge_scope_parts:
_append(knowledge_scope_parts, _build_user_project_filter(filters.project_id))
_append(knowledge_scope_parts, _build_persona_filter(filters.persona_id))
# Persona filter (array<int> attribute membership)
filter_str += _build_persona_filter(filters.persona_id)
if len(knowledge_scope_parts) > 1:
filter_parts.append("(" + " or ".join(knowledge_scope_parts) + ")")
elif len(knowledge_scope_parts) == 1:
filter_parts.append(knowledge_scope_parts[0])
# Time filter
filter_str += _build_time_filter(filters.time_cutoff)
_append(filter_parts, _build_time_filter(filters.time_cutoff))
# # Knowledge Graph Filters
# filter_str += _build_kg_filter(
# _append(filter_parts, _build_kg_filter(
# kg_entities=filters.kg_entities,
# kg_relationships=filters.kg_relationships,
# kg_terms=filters.kg_terms,
# )
# ))
# filter_str += _build_kg_source_filters(filters.kg_sources)
# _append(filter_parts, _build_kg_source_filters(filters.kg_sources))
# filter_str += _build_kg_chunk_id_zero_only_filter(
# _append(filter_parts, _build_kg_chunk_id_zero_only_filter(
# filters.kg_chunk_id_zero_only or False
# )
# ))
# Trim trailing " and "
if remove_trailing_and and filter_str.endswith(" and "):
filter_str = filter_str[:-5]
filter_str = " and ".join(filter_parts)
if filter_str and not remove_trailing_and:
filter_str += " and "
return filter_str

View File

@@ -10,6 +10,8 @@ from pydantic import Field
from sqlalchemy.orm import Session
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.db.llm import fetch_default_llm_model
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -35,6 +37,38 @@ def get_safe_filename(upload: UploadFile) -> str:
return upload.filename
def get_upload_size_bytes(upload: UploadFile) -> int | None:
    """Best-effort file size in bytes without consuming the stream.

    Uses ``upload.size`` when it is set. Otherwise the size is measured by
    seeking the underlying file object to its end and reading the offset.

    Args:
        upload: The uploaded file to measure.

    Returns:
        The size in bytes, or None when it could not be determined (the
        failure is logged as a warning rather than raised).
    """
    if upload.size is not None:
        return upload.size

    try:
        current_pos = upload.file.tell()
        try:
            upload.file.seek(0, 2)  # whence=2: relative to end of stream
            return upload.file.tell()
        finally:
            # Rewind even when measuring fails part-way, so a failed probe
            # never leaves the stream parked at EOF for later readers.
            upload.file.seek(current_pos)
    except Exception as e:
        logger.warning(
            "Could not determine upload size via stream seek "
            f"(filename='{get_safe_filename(upload)}', "
            f"error_type={type(e).__name__}, error={e})"
        )
        return None
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
    """Tell whether an upload's known size is strictly above ``max_bytes``.

    When the size cannot be determined the check is skipped: a warning is
    logged and the upload is reported as not too large.
    """
    known_size = get_upload_size_bytes(upload)
    if known_size is not None:
        return known_size > max_bytes

    logger.warning(
        "Could not determine upload size; skipping size-limit check for "
        f"'{get_safe_filename(upload)}'"
    )
    return False
# Guard against extremely large images
Image.MAX_IMAGE_PIXELS = 12000 * 12000
@@ -159,6 +193,18 @@ def categorize_uploaded_files(
for upload in files:
try:
filename = get_safe_filename(upload)
# Size limit is a hard safety cap and is enforced even when token
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
results.rejected.append(
RejectedFile(
filename=filename,
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
)
)
continue
extension = get_file_ext(filename)
# If image, estimate tokens via dedicated method first

View File

@@ -5,7 +5,6 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
from uuid import UUID
import jwt
from email_validator import EmailNotValidError
@@ -19,7 +18,6 @@ from fastapi import Query
from fastapi import Request
from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.anonymous_user import fetch_anonymous_user_info
@@ -69,7 +67,6 @@ from onyx.db.user_preferences import update_user_role
from onyx.db.user_preferences import update_user_shortcut_enabled
from onyx.db.user_preferences import update_user_temperature_override_enabled
from onyx.db.user_preferences import update_user_theme_preference
from onyx.db.users import batch_get_user_groups
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_all_users
from onyx.db.users import get_page_of_filtered_users
@@ -101,7 +98,6 @@ from onyx.server.manage.models import UserSpecificAssistantPreferences
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
from onyx.server.models import MinimalUserSnapshot
from onyx.server.models import UserGroupInfo
from onyx.server.usage_limits import is_tenant_on_trial_fn
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
@@ -207,32 +203,9 @@ def list_accepted_users(
total_items=0,
)
user_ids = [user.id for user in filtered_accepted_users]
groups_by_user = batch_get_user_groups(db_session, user_ids)
# Batch-fetch SCIM mappings to mark synced users
scim_synced_ids: set[UUID] = set()
try:
from onyx.db.models import ScimUserMapping
scim_mappings = db_session.scalars(
select(ScimUserMapping.user_id).where(ScimUserMapping.user_id.in_(user_ids))
).all()
scim_synced_ids = set(scim_mappings)
except Exception:
pass
return PaginatedReturn(
items=[
FullUserSnapshot.from_user_model(
user,
groups=[
UserGroupInfo(id=gid, name=gname)
for gid, gname in groups_by_user.get(user.id, [])
],
is_scim_synced=user.id in scim_synced_ids,
)
for user in filtered_accepted_users
FullUserSnapshot.from_user_model(user) for user in filtered_accepted_users
],
total_items=total_accepted_users_count,
)
@@ -296,10 +269,24 @@ def list_all_users(
if accepted_page is None or invited_page is None or slack_users_page is None:
return AllUsersResponse(
accepted=[
FullUserSnapshot.from_user_model(user) for user in accepted_users
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
is_active=user.is_active,
password_configured=user.password_configured,
)
for user in accepted_users
],
slack_users=[
FullUserSnapshot.from_user_model(user) for user in slack_users
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
is_active=user.is_active,
password_configured=user.password_configured,
)
for user in slack_users
],
invited=[InvitedUserSnapshot(email=email) for email in invited_emails],
accepted_pages=1,
@@ -309,10 +296,26 @@ def list_all_users(
# Otherwise, return paginated results
return AllUsersResponse(
accepted=[FullUserSnapshot.from_user_model(user) for user in accepted_users][
accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE
],
slack_users=[FullUserSnapshot.from_user_model(user) for user in slack_users][
accepted=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
is_active=user.is_active,
password_configured=user.password_configured,
)
for user in accepted_users
][accepted_page * USERS_PAGE_SIZE : (accepted_page + 1) * USERS_PAGE_SIZE],
slack_users=[
FullUserSnapshot(
id=user.id,
email=user.email,
role=user.role,
is_active=user.is_active,
password_configured=user.password_configured,
)
for user in slack_users
][
slack_users_page
* USERS_PAGE_SIZE : (slack_users_page + 1)
* USERS_PAGE_SIZE

View File

@@ -1,4 +1,3 @@
import datetime
from typing import Generic
from typing import Optional
from typing import TypeVar
@@ -32,41 +31,21 @@ class MinimalUserSnapshot(BaseModel):
email: str
class UserGroupInfo(BaseModel):
id: int
name: str
class FullUserSnapshot(BaseModel):
id: UUID
email: str
role: UserRole
is_active: bool
password_configured: bool
personal_name: str | None
created_at: datetime.datetime
updated_at: datetime.datetime
groups: list[UserGroupInfo]
is_scim_synced: bool
@classmethod
def from_user_model(
cls,
user: User,
groups: list[UserGroupInfo] | None = None,
is_scim_synced: bool = False,
) -> "FullUserSnapshot":
def from_user_model(cls, user: User) -> "FullUserSnapshot":
return cls(
id=user.id,
email=user.email,
role=user.role,
is_active=user.is_active,
password_configured=user.password_configured,
personal_name=user.personal_name,
created_at=user.created_at,
updated_at=user.updated_at,
groups=groups or [],
is_scim_synced=is_scim_synced,
)

View File

@@ -41,6 +41,7 @@ class StreamingType(Enum):
REASONING_DONE = "reasoning_done"
CITATION_INFO = "citation_info"
TOOL_CALL_DEBUG = "tool_call_debug"
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
MEMORY_TOOL_START = "memory_tool_start"
MEMORY_TOOL_DELTA = "memory_tool_delta"
@@ -259,6 +260,15 @@ class CustomToolDelta(BaseObj):
file_ids: list[str] | None = None
class ToolCallArgumentDelta(BaseObj):
type: Literal["tool_call_argument_delta"] = (
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
)
tool_type: str
argument_deltas: dict[str, Any]
################################################
# File Reader Packets
################################################
@@ -379,6 +389,7 @@ PacketObj = Union[
# Citation Packets
CitationInfo,
ToolCallDebug,
ToolCallArgumentDelta,
# Deep Research Packets
DeepResearchPlanStart,
DeepResearchPlanDelta,

View File

@@ -78,6 +78,7 @@ class Settings(BaseModel):
# User Knowledge settings
user_knowledge_enabled: bool | None = True
user_file_max_upload_size_mb: int | None = None
# Connector settings
show_extra_connectors: bool | None = True

View File

@@ -3,6 +3,7 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
from onyx.configs.constants import KV_SETTINGS_KEY
from onyx.configs.constants import OnyxRedisLocks
from onyx.key_value_store.factory import get_kv_store
@@ -50,6 +51,7 @@ def load_settings() -> Settings:
if DISABLE_USER_KNOWLEDGE:
settings.user_knowledge_enabled = False
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
return settings

View File

@@ -56,3 +56,23 @@ def get_built_in_tool_ids() -> list[str]:
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
return BUILT_IN_TOOL_MAP[in_code_tool_id]
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
    """Map each built-in tool's LLM-facing name to its class.

    The name is read from the class's own ``__dict__`` entry ``name``: a
    plain string is used as-is, while a property is resolved by calling its
    getter with the class itself in place of an instance.

    Raises:
        ValueError: If a class in BUILT_IN_TOOL_MAP has no usable ``name``.
    """
    name_to_class: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
    for tool_cls in BUILT_IN_TOOL_MAP.values():
        declared = tool_cls.__dict__.get("name")
        if isinstance(declared, str):
            llm_name = declared
        elif isinstance(declared, property) and declared.fget is not None:
            # No instance exists at import time, so the getter receives the
            # class where ``self`` would normally go.
            llm_name = declared.fget(tool_cls)
        else:
            raise ValueError(
                f"Built-in tool {tool_cls.__name__} must define a valid LLM-facing tool name"
            )
        name_to_class[llm_name] = tool_cls
    return name_to_class
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()

View File

@@ -92,3 +92,7 @@ class Tool(abc.ABC, Generic[TOverride]):
**llm_kwargs: Any,
) -> ToolResponse:
raise NotImplementedError
@classmethod
def should_emit_argument_deltas(cls) -> bool:
return False

View File

@@ -376,3 +376,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
rich_response=None,
llm_facing_response=llm_response,
)
@classmethod
@override
def should_emit_argument_deltas(cls) -> bool:
return True

View File

@@ -11,16 +11,20 @@ logger = setup_logger()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _encrypt_string(input_str: str) -> bytes:
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
if ENCRYPTION_KEY_SECRET:
logger.warning("MIT version of Onyx does not support encryption of secrets.")
elif key is not None:
logger.debug("MIT encrypt called with explicit key — key ignored.")
return input_str.encode()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _decrypt_bytes(input_bytes: bytes) -> str:
# No need to double warn. If you wish to learn more about encryption features
# refer to the Onyx EE code
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
if ENCRYPTION_KEY_SECRET:
logger.warning("MIT version of Onyx does not support decryption of secrets.")
elif key is not None:
logger.debug("MIT decrypt called with explicit key — key ignored.")
return input_bytes.decode()
@@ -86,15 +90,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
return masked
def encrypt_string_to_bytes(intput_str: str) -> bytes:
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"
)
return versioned_encryption_fn(intput_str)
return versioned_encryption_fn(intput_str, key=key)
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
versioned_decryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_decrypt_bytes"
)
return versioned_decryption_fn(intput_bytes)
return versioned_decryption_fn(intput_bytes, key=key)

View File

@@ -0,0 +1,17 @@
"""
jsonriver - A streaming JSON parser for Python
Parse JSON incrementally as it streams in, e.g. from a network request or a language model.
Gives you a sequence of increasingly complete values.
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from .parse import _Parser as Parser
from .parse import JsonObject
from .parse import JsonValue
__all__ = ["Parser", "JsonValue", "JsonObject"]
__version__ = "0.0.1"

View File

@@ -0,0 +1,427 @@
"""
JSON parser for streaming incremental parsing
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from __future__ import annotations
import copy
from enum import IntEnum
from typing import cast
from typing import Union
from .tokenize import _Input
from .tokenize import json_token_type_to_string
from .tokenize import JsonTokenType
from .tokenize import Tokenizer
# Type definitions for JSON values
JsonValue = Union[None, bool, float, str, list["JsonValue"], dict[str, "JsonValue"]]
JsonObject = dict[str, JsonValue]
class _StateEnum(IntEnum):
"""Parser state machine states"""
Initial = 0
InString = 1
InArray = 2
InObjectExpectingKey = 3
InObjectExpectingValue = 4
class _State:
    """Common shape of a parser stack entry: a tag plus the value being built."""

    # Which concrete state this entry is.
    type: _StateEnum
    # Payload; its shape depends on `type` (see the subclasses).
    value: JsonValue | tuple[str, JsonObject] | None


class _InitialState(_State):
    """Stack entry used before the first token; carries no value."""

    def __init__(self) -> None:
        self.value = None
        self.type = _StateEnum.Initial


class _InStringState(_State):
    """Stack entry for a string in progress; `value` starts as the empty string."""

    def __init__(self) -> None:
        self.value = ""
        self.type = _StateEnum.InString


class _InArrayState(_State):
    """Stack entry for an array in progress; `value` is the list being filled."""

    def __init__(self) -> None:
        self.value: list[JsonValue] = []
        self.type = _StateEnum.InArray


class _InObjectExpectingKeyState(_State):
    """Stack entry for an object waiting for its next key; `value` is the dict."""

    def __init__(self) -> None:
        self.value: JsonObject = {}
        self.type = _StateEnum.InObjectExpectingKey


class _InObjectExpectingValueState(_State):
    """Stack entry for an object waiting for a value; `value` is `(key, dict)`."""

    def __init__(self, key: str, obj: JsonObject) -> None:
        self.value = (key, obj)
        self.type = _StateEnum.InObjectExpectingValue
# Sentinel value to distinguish "not set" from "set to None/null"
class _Unset:
pass
_UNSET = _Unset()
class _Parser:
"""
Incremental JSON parser
Feed chunks of JSON text via feed() and get back progressively
more complete JSON values.
"""
def __init__(self) -> None:
    """Create a parser with an empty input buffer and a fresh tokenizer."""
    # Bookkeeping flags for the pump loop.
    self._finished = False
    self._progressed = False
    # No value parsed and no snapshot emitted yet.
    self._toplevel_value: JsonValue | _Unset = _UNSET
    self._prev_snapshot: JsonValue | _Unset = _UNSET
    self._state_stack: list[_State] = [_InitialState()]
    # The tokenizer reads from the shared input and calls back into this
    # object through the handle_* methods.
    self._input = _Input()
    self.tokenizer = Tokenizer(self._input, self)
def feed(self, chunk: str) -> list[JsonValue]:
    """Append `chunk` to the input and return the deltas it produced.

    Each returned element describes what changed relative to the previously
    reported value (see `_compute_delta` for the delta format). Once the
    parser is marked finished, further chunks are ignored and an empty list
    is returned.
    """
    if not self._finished:
        self._input.feed(chunk)
        return self._collect_deltas()
    return []
@staticmethod
def _compute_delta(prev: JsonValue | None, current: JsonValue) -> JsonValue | None:
if prev is None:
return current
if isinstance(current, dict) and isinstance(prev, dict):
result: JsonObject = {}
for key in current:
cur_val = current[key]
prev_val = prev.get(key)
if key not in prev:
result[key] = cur_val
elif isinstance(cur_val, str) and isinstance(prev_val, str):
if cur_val != prev_val:
result[key] = cur_val[len(prev_val) :]
elif isinstance(cur_val, list) and isinstance(prev_val, list):
if cur_val != prev_val:
new_items = cur_val[len(prev_val) :]
# check if the last existing element was updated
if (
prev_val
and len(cur_val) >= len(prev_val)
and cur_val[len(prev_val) - 1] != prev_val[-1]
):
result[key] = [cur_val[len(prev_val) - 1]] + new_items
elif new_items:
result[key] = new_items
elif cur_val != prev_val:
result[key] = cur_val
return result if result else None
if isinstance(current, str) and isinstance(prev, str):
delta = current[len(prev) :]
return delta if delta else None
if isinstance(current, list) and isinstance(prev, list):
if current != prev:
new_items = current[len(prev) :]
if (
prev
and len(current) >= len(prev)
and current[len(prev) - 1] != prev[-1]
):
return [current[len(prev) - 1]] + new_items
return new_items if new_items else None
return None
if current != prev:
return current
return None
def finish(self) -> list[JsonValue]:
    """Declare the input complete and flush whatever was still pending.

    Marks the input buffer complete, runs the pump once more so tokens that
    were waiting for more input (e.g. a number, which has no terminator) can
    be emitted, then asks the input to verify that no trailing content
    remains.

    Returns:
        The deltas produced by that final pump.
    """
    source = self._input
    source.mark_complete()
    flushed = self._collect_deltas()
    source.expect_end_of_content()
    return flushed
def _collect_deltas(self) -> list[JsonValue]:
    """Pump the tokenizer until it stops making progress; return the deltas.

    After every pump that progressed, the top-level value is deep-copied and
    compared with the previous snapshot. The first snapshot is reported
    whole; later ones are reported as deltas. When a pump makes no progress
    and the state stack is empty, the parser is marked finished.
    """
    deltas: list[JsonValue] = []
    while True:
        self._progressed = False
        self.tokenizer.pump()

        if not self._progressed:
            if not self._state_stack:
                self._finished = True
            return deltas

        if self._toplevel_value is _UNSET:
            raise RuntimeError(
                "Internal error: toplevel_value should not be unset "
                "after progressing"
            )

        # Deep copy so later in-place mutation of the live value cannot
        # alter the snapshot used for the next comparison.
        snapshot = copy.deepcopy(cast(JsonValue, self._toplevel_value))
        previous = self._prev_snapshot
        if isinstance(previous, _Unset):
            deltas.append(snapshot)
        else:
            delta = self._compute_delta(previous, snapshot)
            if delta is not None:
                deltas.append(delta)
        self._prev_snapshot = snapshot
# TokenHandler protocol implementation
def handle_null(self) -> None:
    """Tokenizer callback for a `null` literal."""
    token = JsonTokenType.Null
    self._handle_value_token(token, None)


def handle_boolean(self, value: bool) -> None:
    """Tokenizer callback for a `true`/`false` literal."""
    token = JsonTokenType.Boolean
    self._handle_value_token(token, value)


def handle_number(self, value: float) -> None:
    """Tokenizer callback for a number that has been fully read."""
    token = JsonTokenType.Number
    self._handle_value_token(token, value)
def handle_string_start(self) -> None:
    """Tokenizer callback for an opening quote.

    Pushes a string state and installs an empty-string placeholder in the
    enclosing container (or as the top-level value). A string that opens in
    key position gets a string state but is not counted as progress.

    Raises:
        ValueError: if a string is already open.
    """
    state = self._current_state()
    kind = state.type

    # Object keys are not visible in the output until their value starts.
    if kind != _StateEnum.InObjectExpectingKey and not self._progressed:
        self._progressed = True

    if kind == _StateEnum.InString:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(JsonTokenType.StringStart)} "
            f"token in the middle of string"
        )

    if kind == _StateEnum.InObjectExpectingKey:
        self._state_stack.append(_InStringState())
    elif kind == _StateEnum.Initial:
        self._state_stack.pop()
        self._toplevel_value = self._progress_value(JsonTokenType.StringStart, None)
    elif kind == _StateEnum.InArray:
        placeholder = self._progress_value(JsonTokenType.StringStart, None)
        items = cast(list[JsonValue], state.value)
        items.append(placeholder)
    elif kind == _StateEnum.InObjectExpectingValue:
        key, obj = cast(tuple[str, JsonObject], state.value)
        obj[key] = self._progress_value(JsonTokenType.StringStart, None)
def handle_string_middle(self, value: str) -> None:
    """Tokenizer callback for a run of string content.

    Appends `value` to the open string and pushes the grown string into its
    parent. Content belonging to an object key is not counted as progress.

    Raises:
        ValueError: if no string is currently open.
    """
    state = self._current_state()
    stack = self._state_stack
    parent = stack[-2] if len(stack) >= 2 else None

    if not self._progressed:
        # Only key strings (parent is waiting for a key) are silent.
        if parent is None or parent.type != _StateEnum.InObjectExpectingKey:
            self._progressed = True

    if state.type != _StateEnum.InString:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(JsonTokenType.StringMiddle)} "
            f"token when not in string"
        )

    assert isinstance(state.value, str)
    state.value += value
    self._update_string_parent(state.value, parent)
def handle_string_end(self) -> None:
    """Tokenizer callback for a closing quote.

    Pops the string state and hands the final text to whatever is now on top
    of the stack (or records it as the top-level value if nothing is left).

    Raises:
        ValueError: if no string is currently open.
    """
    finished = self._current_state()
    if finished.type != _StateEnum.InString:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(JsonTokenType.StringEnd)} "
            f"token when not in string"
        )

    stack = self._state_stack
    stack.pop()
    parent = stack[-1] if stack else None
    assert isinstance(finished.value, str)
    self._update_string_parent(finished.value, parent)
def handle_array_start(self) -> None:
    """Tokenizer callback for `[`; handled like any other value token."""
    token = JsonTokenType.ArrayStart
    self._handle_value_token(token, None)


def handle_array_end(self) -> None:
    """Tokenizer callback for `]`; pops the array state.

    Raises:
        ValueError: if the current state is not an array.
    """
    if self._current_state().type != _StateEnum.InArray:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(JsonTokenType.ArrayEnd)} token"
        )
    self._state_stack.pop()
def handle_object_start(self) -> None:
    """Tokenizer callback for `{`; handled like any other value token."""
    token = JsonTokenType.ObjectStart
    self._handle_value_token(token, None)


def handle_object_end(self) -> None:
    """Tokenizer callback for `}`; pops the object state.

    Raises:
        ValueError: if the current state is not one of the two object states.
    """
    object_states = (
        _StateEnum.InObjectExpectingKey,
        _StateEnum.InObjectExpectingValue,
    )
    if self._current_state().type not in object_states:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(JsonTokenType.ObjectEnd)} token"
        )
    self._state_stack.pop()
# Private helper methods
def _current_state(self) -> _State:
"""Get current parser state"""
if not self._state_stack:
raise ValueError("Unexpected trailing input")
return self._state_stack[-1]
def _handle_value_token(self, token_type: JsonTokenType, value: JsonValue) -> None:
    """Place a newly started or completed value where the current state expects it.

    Used for null, booleans, numbers, and array/object openers. Always counts
    as progress.

    Raises:
        ValueError: if the current state is a string or an object waiting
            for a key, neither of which can accept a value.
    """
    state = self._current_state()
    self._progressed = True
    kind = state.type

    if kind == _StateEnum.InString:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(token_type)} "
            f"token in the middle of string"
        )
    if kind == _StateEnum.InObjectExpectingKey:
        raise ValueError(
            f"Unexpected {json_token_type_to_string(token_type)} "
            f"token in the middle of object expecting key"
        )

    if kind == _StateEnum.Initial:
        self._state_stack.pop()
        self._toplevel_value = self._progress_value(token_type, value)
    elif kind == _StateEnum.InArray:
        produced = self._progress_value(token_type, value)
        items = cast(list[JsonValue], state.value)
        items.append(produced)
    elif kind == _StateEnum.InObjectExpectingValue:
        key, obj = cast(tuple[str, JsonObject], state.value)
        if token_type != JsonTokenType.StringStart:
            # The slot is about to be filled, so the object goes back to
            # waiting for a key *before* any container state is pushed on
            # top of it by _progress_value.
            self._state_stack.pop()
            awaiting_key = _InObjectExpectingKeyState()
            awaiting_key.value = obj
            self._state_stack.append(awaiting_key)
        obj[key] = self._progress_value(token_type, value)
def _update_string_parent(self, updated: str, parent_state: _State | None) -> None:
    """Propagate the current text of a string to the container that owns it.

    * No parent: the string is the top-level value.
    * Array parent: the string replaces the array's last element.
    * Object waiting for a value: the string is stored under the pending key.
    * Object waiting for a key: the string is the key itself.

    For the two object cases, when the parent is on top of the stack (the
    string state has already been popped), the object switches to its next
    state.
    """
    if parent_state is None:
        self._toplevel_value = updated
        return

    kind = parent_state.type
    stack = self._state_stack

    if kind == _StateEnum.InArray:
        items = cast(list[JsonValue], parent_state.value)
        items[-1] = updated
        return

    if kind == _StateEnum.InObjectExpectingValue:
        key, obj = cast(tuple[str, JsonObject], parent_state.value)
        obj[key] = updated
        if stack and stack[-1] == parent_state:
            # Value string finished: the object now waits for its next key.
            stack.pop()
            awaiting_key = _InObjectExpectingKeyState()
            awaiting_key.value = obj
            stack.append(awaiting_key)
        return

    if kind == _StateEnum.InObjectExpectingKey:
        if stack and stack[-1] == parent_state:
            # Key string finished: the object now waits for that key's value.
            stack.pop()
            target = cast(JsonObject, parent_state.value)
            stack.append(_InObjectExpectingValueState(updated, target))
def _progress_value(self, token_type: JsonTokenType, value: JsonValue) -> JsonValue:
    """Return the initial value for a token, pushing a state for open-ended ones.

    Scalars are returned directly. Strings, arrays and objects push a new
    state and return that state's (initially empty) value, so later tokens
    can keep filling it in.

    Raises:
        ValueError: for token types that do not start a value.
    """
    if token_type == JsonTokenType.Null:
        return None
    if token_type in (JsonTokenType.Boolean, JsonTokenType.Number):
        return value

    if token_type == JsonTokenType.StringStart:
        self._state_stack.append(_InStringState())
        return ""

    if token_type == JsonTokenType.ArrayStart:
        opened_array = _InArrayState()
        self._state_stack.append(opened_array)
        # The very list held by the state is returned so that appends made
        # through the state are visible in the parent.
        return opened_array.value

    if token_type == JsonTokenType.ObjectStart:
        opened_object = _InObjectExpectingKeyState()
        self._state_stack.append(opened_object)
        return opened_object.value

    raise ValueError(
        f"Unexpected token type: {json_token_type_to_string(token_type)}"
    )

View File

@@ -0,0 +1,514 @@
"""
JSON tokenizer for streaming incremental parsing
Copyright (c) 2023 Google LLC (original TypeScript implementation)
Copyright (c) 2024 jsonriver-python contributors (Python port)
SPDX-License-Identifier: BSD-3-Clause
"""
from __future__ import annotations
import re
from enum import IntEnum
from typing import Protocol
class TokenHandler(Protocol):
    """Callbacks the tokenizer invokes as it recognises JSON tokens."""

    def handle_null(self) -> None:
        """A `null` literal was read."""
        ...

    def handle_boolean(self, value: bool) -> None:
        """A `true` or `false` literal was read."""
        ...

    def handle_number(self, value: float) -> None:
        """A complete number was read."""
        ...

    def handle_string_start(self) -> None:
        """An opening quote was read."""
        ...

    def handle_string_middle(self, value: str) -> None:
        """A piece of decoded string content was read."""
        ...

    def handle_string_end(self) -> None:
        """A closing quote was read."""
        ...

    def handle_array_start(self) -> None:
        """A `[` was read."""
        ...

    def handle_array_end(self) -> None:
        """A `]` was read."""
        ...

    def handle_object_start(self) -> None:
        """A `{` was read."""
        ...

    def handle_object_end(self) -> None:
        """A `}` was read."""
        ...
class JsonTokenType(IntEnum):
    """Kinds of token the tokenizer reports to its handler."""

    # Scalars that arrive whole.
    Null = 0
    Boolean = 1
    Number = 2
    # Strings arrive as a start, any number of middles, then an end.
    StringStart = 3
    StringMiddle = 4
    StringEnd = 5
    # Container delimiters.
    ArrayStart = 6
    ArrayEnd = 7
    ObjectStart = 8
    ObjectEnd = 9
def json_token_type_to_string(token_type: JsonTokenType) -> str:
    """Return the lowercase, space-separated label used for a token type in error messages.

    Raises:
        KeyError: if `token_type` is not a known token type.
    """
    labels = dict(
        (
            (JsonTokenType.Null, "null"),
            (JsonTokenType.Boolean, "boolean"),
            (JsonTokenType.Number, "number"),
            (JsonTokenType.StringStart, "string start"),
            (JsonTokenType.StringMiddle, "string middle"),
            (JsonTokenType.StringEnd, "string end"),
            (JsonTokenType.ArrayStart, "array start"),
            (JsonTokenType.ArrayEnd, "array end"),
            (JsonTokenType.ObjectStart, "object start"),
            (JsonTokenType.ObjectEnd, "object end"),
        )
    )
    return labels[token_type]
class _State(IntEnum):
"""Internal tokenizer states"""
ExpectingValue = 0
InString = 1
StartArray = 2
AfterArrayValue = 3
StartObject = 4
AfterObjectKey = 5
AfterObjectValue = 6
BeforeObjectKey = 7
# Regex for validating JSON numbers
_JSON_NUMBER_PATTERN = re.compile(r"^-?(0|[1-9]\d*)(\.\d+)?([eE][+-]?\d+)?$")
def _parse_json_number(s: str) -> float:
"""Parse a JSON number string, validating format"""
if not _JSON_NUMBER_PATTERN.match(s):
raise ValueError("Invalid number")
return float(s)
class _Input:
"""
Input buffer for chunk-based JSON parsing
Manages buffering of input chunks and provides methods for
consuming and inspecting the buffer.
"""
def __init__(self) -> None:
self._buffer = ""
self._start_index = 0
self.buffer_complete = False
def feed(self, chunk: str) -> None:
"""Add a chunk of data to the buffer"""
self._buffer += chunk
def mark_complete(self) -> None:
"""Signal that no more chunks will be fed"""
self.buffer_complete = True
@property
def length(self) -> int:
"""Number of characters remaining in buffer"""
return len(self._buffer) - self._start_index
def advance(self, length: int) -> None:
"""Advance the start position by length characters"""
self._start_index += length
def peek(self, offset: int) -> str | None:
"""Peek at character at offset, or None if not available"""
idx = self._start_index + offset
if idx < len(self._buffer):
return self._buffer[idx]
return None
def peek_char_code(self, offset: int) -> int:
"""Get character code at offset"""
return ord(self._buffer[self._start_index + offset])
def slice(self, start: int, end: int) -> str:
"""Slice buffer from start to end (relative to current position)"""
return self._buffer[self._start_index + start : self._start_index + end]
def commit(self) -> None:
"""Commit consumed content, removing it from buffer"""
if self._start_index > 0:
self._buffer = self._buffer[self._start_index :]
self._start_index = 0
def remaining(self) -> str:
"""Get all remaining content in buffer"""
return self._buffer[self._start_index :]
def expect_end_of_content(self) -> None:
"""Verify no non-whitespace content remains"""
self.commit()
self.skip_past_whitespace()
if self.length != 0:
raise ValueError(f"Unexpected trailing content {self.remaining()!r}")
def skip_past_whitespace(self) -> None:
"""Skip whitespace characters"""
i = self._start_index
while i < len(self._buffer):
c = ord(self._buffer[i])
if c in (32, 9, 10, 13): # space, tab, \n, \r
i += 1
else:
break
self._start_index = i
def try_to_take_prefix(self, prefix: str) -> bool:
"""Try to consume prefix from buffer, return True if successful"""
if self._buffer.startswith(prefix, self._start_index):
self._start_index += len(prefix)
return True
return False
def try_to_take(self, length: int) -> str | None:
"""Try to take length characters, or None if not enough available"""
if self.length < length:
return None
result = self._buffer[self._start_index : self._start_index + length]
self._start_index += length
return result
def try_to_take_char_code(self) -> int | None:
"""Try to take a single character as char code, or None if buffer empty"""
if self.length == 0:
return None
code = ord(self._buffer[self._start_index])
self._start_index += 1
return code
def take_until_quote_or_backslash(self) -> tuple[str, bool]:
"""
Consume input up to first quote or backslash
Returns tuple of (consumed_content, pattern_found)
"""
buf = self._buffer
i = self._start_index
while i < len(buf):
c = ord(buf[i])
if c <= 0x1F:
raise ValueError("Unescaped control character in string")
if c == 34 or c == 92: # " or \
result = buf[self._start_index : i]
self._start_index = i
return (result, True)
i += 1
result = buf[self._start_index :]
self._start_index = len(buf)
return (result, False)
class Tokenizer:
"""
Tokenizer for chunk-based JSON parsing
Processes chunks fed into its input buffer and calls handler methods
as JSON tokens are recognized.
"""
def __init__(self, input: _Input, handler: TokenHandler) -> None:
    """Bind the tokenizer to its input buffer and to the handler receiving tokens."""
    self._handler = handler
    self.input = input
    # Incremented once per handler callback; pump() watches it for progress.
    self._emitted_tokens = 0
    # A document is a single value, so that is the only thing expected at first.
    self._stack: list[_State] = [_State.ExpectingValue]
def is_done(self) -> bool:
    """Return True when the state stack is empty and no input is left."""
    nothing_expected = len(self._stack) == 0
    return nothing_expected and self.input.length == 0
def pump(self) -> None:
    """Tokenize repeatedly until a step emits no token, then commit the input.

    The commit discards the text consumed so far from the buffer.
    """
    emitted_before = None
    while emitted_before != self._emitted_tokens:
        emitted_before = self._emitted_tokens
        self._tokenize_more()
    self.input.commit()
def _tokenize_more(self) -> None:
    """Run the tokenizing routine that matches the state on top of the stack.

    Does nothing when the stack is empty.
    """
    if not self._stack:
        return
    routines = {
        _State.ExpectingValue: self._tokenize_value,
        _State.InString: self._tokenize_string,
        _State.StartArray: self._tokenize_array_start,
        _State.AfterArrayValue: self._tokenize_after_array_value,
        _State.StartObject: self._tokenize_object_start,
        _State.AfterObjectKey: self._tokenize_after_object_key,
        _State.AfterObjectValue: self._tokenize_after_object_value,
        _State.BeforeObjectKey: self._tokenize_before_object_key,
    }
    routine = routines.get(self._stack[-1])
    if routine is not None:
        routine()
def _tokenize_value(self) -> None:
    """Try to read the start of one JSON value at the cursor.

    Literals and numbers are reported whole; strings, arrays and objects are
    opened and their dedicated routine is run immediately. If the available
    text does not (yet) form any of these, nothing is consumed or emitted.
    """
    src = self.input
    src.skip_past_whitespace()

    if src.try_to_take_prefix("null"):
        self._handler.handle_null()
        self._emitted_tokens += 1
        self._stack.pop()
        return

    for literal, flag in (("true", True), ("false", False)):
        if src.try_to_take_prefix(literal):
            self._handler.handle_boolean(flag)
            self._emitted_tokens += 1
            self._stack.pop()
            return

    if src.length > 0:
        first = src.peek_char_code(0)
        if first == 45 or 48 <= first <= 57:  # '-' or 0-9
            available = src.length
            end = 0
            # Gather every character that can appear in a number; the exact
            # grammar is checked by _parse_json_number afterwards.
            while end < available:
                code = src.peek_char_code(end)
                if not (48 <= code <= 57 or code in (45, 43, 46, 101, 69)):  # - + . e E
                    break
                end += 1
            if end == available and not src.buffer_complete:
                # The number runs to the end of the buffer; more digits may follow.
                return
            text = src.slice(0, end)
            src.advance(end)
            self._handler.handle_number(_parse_json_number(text))
            self._emitted_tokens += 1
            self._stack.pop()
            return

    openers = (
        ('"', _State.InString, self._handler.handle_string_start, self._tokenize_string),
        ("[", _State.StartArray, self._handler.handle_array_start, self._tokenize_array_start),
        ("{", _State.StartObject, self._handler.handle_object_start, self._tokenize_object_start),
    )
    for prefix, next_state, announce, resume in openers:
        if src.try_to_take_prefix(prefix):
            self._stack.pop()
            self._stack.append(next_state)
            announce()
            self._emitted_tokens += 1
            resume()
            return
def _tokenize_string(self) -> None:
"""Tokenize string content"""
while True:
chunk, interrupted = self.input.take_until_quote_or_backslash()
if chunk:
self._handler.handle_string_middle(chunk)
self._emitted_tokens += 1
elif not interrupted:
return
if interrupted:
if self.input.length == 0:
return
next_char = self.input.peek(0)
if next_char == '"':
self.input.advance(1)
self._handler.handle_string_end()
self._emitted_tokens += 1
self._stack.pop()
return
# Handle escape sequences
next_char2 = self.input.peek(1)
if next_char2 is None:
return
value: str
if next_char2 == "u":
# Unicode escape: need 4 hex digits
if self.input.length < 6:
return
code = 0
for j in range(2, 6):
c = self.input.peek_char_code(j)
if 48 <= c <= 57: # 0-9
digit = c - 48
elif 65 <= c <= 70: # A-F
digit = c - 55
elif 97 <= c <= 102: # a-f
digit = c - 87
else:
raise ValueError("Bad Unicode escape in JSON")
code = (code << 4) | digit
self.input.advance(6)
self._handler.handle_string_middle(chr(code))
self._emitted_tokens += 1
continue
elif next_char2 == "n":
value = "\n"
elif next_char2 == "r":
value = "\r"
elif next_char2 == "t":
value = "\t"
elif next_char2 == "b":
value = "\b"
elif next_char2 == "f":
value = "\f"
elif next_char2 == "\\":
value = "\\"
elif next_char2 == "/":
value = "/"
elif next_char2 == '"':
value = '"'
else:
raise ValueError("Bad escape in string")
self.input.advance(2)
self._handler.handle_string_middle(value)
self._emitted_tokens += 1
def _tokenize_array_start(self) -> None:
    """Handle the position right after `[`: close an empty array or start the first element."""
    src = self.input
    src.skip_past_whitespace()
    if src.length == 0:
        return

    if not src.try_to_take_prefix("]"):
        # There is an element: once it has been read, the tokenizer must be
        # in the "after array value" state.
        self._stack.pop()
        self._stack.append(_State.AfterArrayValue)
        self._stack.append(_State.ExpectingValue)
        self._tokenize_value()
        return

    self._handler.handle_array_end()
    self._emitted_tokens += 1
    self._stack.pop()
def _tokenize_after_array_value(self) -> None:
    """Handle the position after an array element: `,` continues, `]` closes.

    Raises:
        ValueError: if any other character follows.
    """
    self.input.skip_past_whitespace()
    code = self.input.try_to_take_char_code()
    if code is None:
        return

    if code == ord(","):
        self._stack.append(_State.ExpectingValue)
        self._tokenize_value()
        return

    if code == ord("]"):
        self._handler.handle_array_end()
        self._emitted_tokens += 1
        self._stack.pop()
        return

    raise ValueError(f"Expected , or ], got {chr(code)!r}")
def _tokenize_object_start(self) -> None:
    """Handle the position right after `{`: close an empty object or start the first key.

    Raises:
        ValueError: if the next character is neither `}` nor a quote.
    """
    self.input.skip_past_whitespace()
    code = self.input.try_to_take_char_code()
    if code is None:
        return

    if code == ord('"'):
        # First key: read it as a string, then expect a colon.
        self._stack.pop()
        self._stack.append(_State.AfterObjectKey)
        self._stack.append(_State.InString)
        self._handler.handle_string_start()
        self._emitted_tokens += 1
        self._tokenize_string()
        return

    if code == ord("}"):
        self._handler.handle_object_end()
        self._emitted_tokens += 1
        self._stack.pop()
        return

    raise ValueError(f"Expected start of object key, got {chr(code)!r}")
def _tokenize_after_object_key(self) -> None:
    """Handle the position after an object key: require `:` and start the value.

    Raises:
        ValueError: if the next character is not a colon.
    """
    self.input.skip_past_whitespace()
    code = self.input.try_to_take_char_code()
    if code is None:
        return
    if code != ord(":"):
        raise ValueError(f"Expected colon after object key, got {chr(code)!r}")

    self._stack.pop()
    self._stack.append(_State.AfterObjectValue)
    self._stack.append(_State.ExpectingValue)
    self._tokenize_value()
def _tokenize_after_object_value(self) -> None:
    """Handle the position after an object value: `,` continues, `}` closes.

    Raises:
        ValueError: if any other character follows.
    """
    self.input.skip_past_whitespace()
    code = self.input.try_to_take_char_code()
    if code is None:
        return

    if code == ord(","):
        self._stack.pop()
        self._stack.append(_State.BeforeObjectKey)
        self._tokenize_before_object_key()
        return

    if code == ord("}"):
        self._handler.handle_object_end()
        self._emitted_tokens += 1
        self._stack.pop()
        return

    raise ValueError(
        f"Expected , or }} after object value, got {chr(code)!r}"
    )
def _tokenize_before_object_key(self) -> None:
    """Handle the position after a `,` inside an object: a key string must follow.

    Raises:
        ValueError: if the next character is not a quote.
    """
    self.input.skip_past_whitespace()
    code = self.input.try_to_take_char_code()
    if code is None:
        return
    if code != ord('"'):
        raise ValueError(f"Expected start of object key, got {chr(code)!r}")

    self._stack.pop()
    self._stack.append(_State.AfterObjectKey)
    self._stack.append(_State.InString)
    self._handler.handle_string_start()
    self._emitted_tokens += 1
    self._tokenize_string()

View File

@@ -128,6 +128,8 @@ class SensitiveValue(Generic[T]):
value = self._decrypt()
if not apply_mask:
# Callers must not mutate the returned dict — doing so would
# desync the cache from the encrypted bytes and the DB.
return value
# Apply masking
@@ -174,18 +176,20 @@ class SensitiveValue(Generic[T]):
)
def __eq__(self, other: Any) -> bool:
"""Prevent direct comparison which might expose value."""
if isinstance(other, SensitiveValue):
# Compare encrypted bytes for equality check
return self._encrypted_bytes == other._encrypted_bytes
raise SensitiveAccessError(
"Cannot compare SensitiveValue with non-SensitiveValue. "
"Use .get_value(apply_mask=True/False) to access the value for comparison."
)
"""Compare SensitiveValues by their decrypted content."""
# NOTE: if you attempt to compare a string/dict to a SensitiveValue,
# this comparison will return NotImplemented, which then evaluates to False.
# This is the convention and required for SQLAlchemy's attribute tracking.
if not isinstance(other, SensitiveValue):
return NotImplemented
return self._decrypt() == other._decrypt()
def __hash__(self) -> int:
"""Allow hashing based on encrypted bytes."""
return hash(self._encrypted_bytes)
"""Hash based on decrypted content."""
value = self._decrypt()
if isinstance(value, dict):
return hash(json.dumps(value, sort_keys=True))
return hash(value)
# Prevent JSON serialization
def __json__(self) -> Any:

View File

@@ -24,6 +24,9 @@ class OnyxVersion:
def set_ee(self) -> None:
    """Turn the EE flag on."""
    self._is_ee = True


def unset_ee(self) -> None:
    """Turn the EE flag off again (the inverse of `set_ee`)."""
    self._is_ee = False


def is_ee_version(self) -> bool:
    """Return the current value of the EE flag."""
    flag = self._is_ee
    return flag

View File

@@ -1,48 +1,93 @@
"""Decrypt a raw hex-encoded credential value.
Usage:
python -m scripts.decrypt <hex_value>
python -m scripts.decrypt <hex_value> --key "my-encryption-key"
python -m scripts.decrypt <hex_value> --key ""
Pass --key "" to skip decryption and just decode the raw bytes as UTF-8.
Omit --key to use the current ENCRYPTION_KEY_SECRET from the environment.
"""
import argparse
import binascii
import json
import os
import sys
from onyx.utils.encryption import decrypt_bytes_to_string
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from onyx.utils.encryption import decrypt_bytes_to_string # noqa: E402
from onyx.utils.variable_functionality import global_version # noqa: E402
def decrypt_raw_credential(encrypted_value: str, key: str | None = None) -> None:
    """Decrypt and display a raw encrypted credential value.

    The result is pretty-printed as JSON when it parses as JSON and printed
    verbatim otherwise. Nothing else is written to stdout on success: in
    particular neither the key nor the input is echoed, so the key cannot
    leak into terminal scrollback or logs and the output can be piped into a
    JSON tool.

    Args:
        encrypted_value: The hex-encoded encrypted credential value. A
            leading "\\x" or "x" prefix is stripped.
        key: Encryption key to use. None means use ENCRYPTION_KEY_SECRET,
            empty string means just decode as UTF-8.

    Exits the process with status 1 on invalid hex, on undecodable bytes
    (empty key), or on any decryption error.
    """
    # Strip common hex prefixes
    if encrypted_value.startswith("\\x"):
        encrypted_value = encrypted_value[2:]
    elif encrypted_value.startswith("x"):
        encrypted_value = encrypted_value[1:]

    try:
        raw_bytes = binascii.unhexlify(encrypted_value)
    except binascii.Error:
        print("Error: Invalid hex-encoded string")
        sys.exit(1)

    if key == "":
        # Empty key → just decode as UTF-8, no decryption
        try:
            decrypted_str = raw_bytes.decode("utf-8")
        except UnicodeDecodeError as e:
            print(f"Error decoding bytes as UTF-8: {e}")
            sys.exit(1)
    else:
        try:
            decrypted_str = decrypt_bytes_to_string(raw_bytes, key=key)
        except Exception as e:
            print(f"Error decrypting value: {e}")
            sys.exit(1)

    # Try to pretty-print as JSON, otherwise print raw
    try:
        parsed = json.loads(decrypted_str)
        print(json.dumps(parsed, indent=2))
    except json.JSONDecodeError:
        print(decrypted_str)
def main() -> None:
    """Command-line entry point: parse arguments and decrypt the given value.

    The EE flag is switched on for the duration of the call and switched off
    again afterwards. The reset sits in a `finally` block because
    `decrypt_raw_credential` reports failures through `sys.exit`, which would
    otherwise skip it.
    """
    parser = argparse.ArgumentParser(
        description="Decrypt a hex-encoded credential value."
    )
    parser.add_argument(
        "value",
        help="Hex-encoded encrypted value to decrypt.",
    )
    parser.add_argument(
        "--key",
        default=None,
        help=(
            "Encryption key. Omit to use ENCRYPTION_KEY_SECRET from env. "
            'Pass "" (empty) to just decode as UTF-8 without decryption.'
        ),
    )
    args = parser.parse_args()

    global_version.set_ee()
    try:
        decrypt_raw_credential(args.value, key=args.key)
    finally:
        global_version.unset_ee()
if __name__ == "__main__":
if len(sys.argv) != 2:
print("Usage: python decrypt.py <hex_encoded_encrypted_value>")
sys.exit(1)
encrypted_value = sys.argv[1]
decrypt_raw_credential(encrypted_value)
main()

View File

@@ -0,0 +1,107 @@
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
Decrypts all encrypted columns using the old key (or raw decode if the old key
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
Usage (docker):
docker exec -it onyx-api_server-1 \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Usage (kubernetes):
kubectl exec -it <pod> -- \
python -m scripts.reencrypt_secrets --old-key "previous-key"
Omit --old-key (or pass "") if secrets were not previously encrypted.
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
or --all-tenants to iterate every tenant.
"""
import argparse
import os
import sys
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
sys.path.append(parent_dir)
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
from onyx.utils.variable_functionality import global_version # noqa: E402
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
    """Rotate encrypted columns for one tenant and print a per-column summary.

    Args:
        tenant_id: Tenant whose session is opened via ``get_session_with_tenant``.
        old_key: Forwarded to ``rotate_encryption_key`` as ``old_key``.
        dry_run: Forwarded to ``rotate_encryption_key``; also switches the
            summary wording to "would be re-encrypted".
    """
    print(f"Re-encrypting secrets for tenant: {tenant_id}")

    with get_session_with_tenant(tenant_id=tenant_id) as db_session:
        results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)

    if not results:
        print("No rows needed re-encryption.")
        return

    qualifier = "would be " if dry_run else ""
    for column, row_count in results.items():
        print(f" {column}: {row_count} row(s) {qualifier}re-encrypted")
def main() -> None:
    """Command-line entry point for re-encrypting secrets.

    Parses ``--old-key``, ``--dry-run`` and the mutually exclusive
    ``--tenant-id`` / ``--all-tenants`` options, enables EE mode, initialises
    the SQL engine and then runs ``_run_for_tenant`` for the selected
    tenant(s). With ``--all-tenants`` a failure in one tenant is reported and
    the remaining tenants are still processed; the process exits with status 1
    if any tenant failed.
    """
    parser = argparse.ArgumentParser(
        description="Re-encrypt secrets under the current encryption key."
    )
    parser.add_argument(
        "--old-key",
        default=None,
        help="Previous encryption key. Omit or pass empty string if not applicable.",
    )
    parser.add_argument(
        "--dry-run",
        action="store_true",
        help="Show what would be re-encrypted without making changes.",
    )
    tenant_group = parser.add_mutually_exclusive_group()
    tenant_group.add_argument(
        "--tenant-id",
        default=None,
        help="Target a specific tenant schema.",
    )
    tenant_group.add_argument(
        "--all-tenants",
        action="store_true",
        help="Iterate all tenants.",
    )
    args = parser.parse_args()

    # An empty --old-key is treated the same as an omitted one.
    old_key = args.old_key or None
    dry_run = args.dry_run

    global_version.set_ee()
    SqlEngine.init_engine(pool_size=5, max_overflow=2)

    if dry_run:
        print("DRY RUN — no changes will be made")

    if not args.all_tenants:
        # Single-tenant mode: fall back to the default schema when no tenant is named.
        _run_for_tenant(
            args.tenant_id or POSTGRES_DEFAULT_SCHEMA, old_key, dry_run=dry_run
        )
    else:
        tenant_ids = get_all_tenant_ids()
        print(f"Found {len(tenant_ids)} tenant(s)")

        failed_tenants: list[str] = []
        for current_tenant in tenant_ids:
            try:
                _run_for_tenant(current_tenant, old_key, dry_run=dry_run)
            except Exception as e:
                # Keep going so one broken tenant does not block the rest.
                print(f" ERROR for tenant {current_tenant}: {e}")
                failed_tenants.append(current_tenant)

        if failed_tenants:
            print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
            sys.exit(1)

    print("Done.")


if __name__ == "__main__":
    main()

71
backend/tests/README.md Normal file
View File

@@ -0,0 +1,71 @@
# Backend Tests
## Test Types
There are four test categories, ordered by increasing scope:
### Unit Tests (`tests/unit/`)
No external services. Mock all I/O with `unittest.mock`. Use for complex, isolated
logic (e.g. citation processing, encryption).
```bash
pytest -xv backend/tests/unit
```
### External Dependency Unit Tests (`tests/external_dependency_unit/`)
External services (Postgres, Redis, Vespa, OpenAI, etc.) are running, but Onyx
application containers are not. Tests call functions directly and can mock selectively.
Use when you need a real database or real API calls but want control over setup.
```bash
python -m dotenv -f .vscode/.env run -- pytest backend/tests/external_dependency_unit
```
### Integration Tests (`tests/integration/`)
Full Onyx deployment running. No mocking. Prefer this over other test types when possible.
```bash
python -m dotenv -f .vscode/.env run -- pytest backend/tests/integration
```
### Playwright / E2E Tests (`web/tests/e2e/`)
Full stack including web server. Use for frontend-backend coordination.
```bash
npx playwright test <TEST_NAME>
```
## Shared Fixtures
Shared fixtures live in `backend/tests/conftest.py`. Test subdirectories can define
their own `conftest.py` for directory-scoped fixtures.
## Best Practices
### Use `enable_ee` fixture instead of inlining
Enables EE mode for a test, with proper teardown and cache clearing.
```python
# Whole file (in a test module, NOT in conftest.py)
pytestmark = pytest.mark.usefixtures("enable_ee")
# Whole directory — add an autouse wrapper to the directory's conftest.py
@pytest.fixture(autouse=True)
def _enable_ee_for_directory(enable_ee: None) -> None: # noqa: ARG001
"""Wraps the shared enable_ee fixture with autouse for this directory."""
# Single test
def test_something(enable_ee: None) -> None: ...
```
**Note:** A `pytestmark` defined in a `conftest.py` does NOT apply its markers to the
tests in that directory — it only affects tests defined in the conftest module itself,
and a conftest normally defines none. Use the autouse fixture wrapper pattern shown above instead.
Do NOT inline `global_version.set_ee()` — always use the fixture.

24
backend/tests/conftest.py Normal file
View File

@@ -0,0 +1,24 @@
"""Root conftest — shared fixtures available to all test directories."""
from collections.abc import Generator
import pytest
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
@pytest.fixture()
def enable_ee() -> Generator[None, None, None]:
    """Switch EE mode on for the duration of one test.

    On teardown the EE flag is put back the way it was found and the
    ``fetch_versioned_implementation`` cache is emptied, so neither leaks
    into later tests.
    """
    previously_enabled = global_version.is_ee_version()

    global_version.set_ee()
    fetch_versioned_implementation.cache_clear()

    yield

    # Only undo EE mode if this fixture was the one that turned it on.
    if not previously_enabled:
        global_version.unset_ee()
    fetch_versioned_implementation.cache_clear()

View File

@@ -45,7 +45,7 @@ def confluence_connector() -> ConfluenceConnector:
def test_confluence_connector_permissions(
mock_get_api_key: MagicMock, # noqa: ARG001
confluence_connector: ConfluenceConnector,
set_ee_on: None, # noqa: ARG001
enable_ee: None, # noqa: ARG001
) -> None:
# Get all doc IDs from the full connector
all_full_doc_ids = set()
@@ -93,7 +93,7 @@ def test_confluence_connector_permissions(
def test_confluence_connector_restriction_handling(
mock_get_api_key: MagicMock, # noqa: ARG001
mock_db_provider_class: MagicMock,
set_ee_on: None, # noqa: ARG001
enable_ee: None, # noqa: ARG001
) -> None:
# Test space key
test_space_key = "DailyPermS"

View File

@@ -4,8 +4,6 @@ from unittest.mock import patch
import pytest
from onyx.utils.variable_functionality import global_version
@pytest.fixture
def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
@@ -14,14 +12,3 @@ def mock_get_unstructured_api_key() -> Generator[MagicMock, None, None]:
return_value=None,
) as mock:
yield mock
@pytest.fixture
def set_ee_on() -> Generator[None, None, None]:
"""Need EE to be enabled for these tests to work since
perm syncing is a an EE-only feature."""
global_version.set_ee()
yield
global_version._is_ee = False

View File

@@ -98,7 +98,7 @@ def _build_connector(
def test_gdrive_perm_sync_with_real_data(
google_drive_service_acct_connector_factory: Callable[..., GoogleDriveConnector],
set_ee_on: None, # noqa: ARG001
enable_ee: None, # noqa: ARG001
) -> None:
"""
Test gdrive_doc_sync and gdrive_group_sync with real data from the test drive.

View File

@@ -1,12 +1,10 @@
import time
from collections.abc import Generator
import pytest
from onyx.connectors.models import HierarchyNode
from onyx.connectors.models import SlimDocument
from onyx.connectors.slack.connector import SlackConnector
from onyx.utils.variable_functionality import global_version
from tests.daily.connectors.utils import load_all_from_connector
@@ -19,16 +17,7 @@ PRIVATE_CHANNEL_USERS = [
"test_user_2@onyx-test.com",
]
@pytest.fixture(autouse=True)
def set_ee_on() -> Generator[None, None, None]:
"""Need EE to be enabled for these tests to work since
perm syncing is a an EE-only feature."""
global_version.set_ee()
yield
global_version._is_ee = False
pytestmark = pytest.mark.usefixtures("enable_ee")
@pytest.mark.parametrize(

View File

@@ -1,13 +1,11 @@
import os
import time
from collections.abc import Generator
import pytest
from onyx.access.models import ExternalAccess
from onyx.connectors.models import HierarchyNode
from onyx.connectors.teams.connector import TeamsConnector
from onyx.utils.variable_functionality import global_version
from tests.daily.connectors.teams.models import TeamsThread
from tests.daily.connectors.utils import load_all_from_connector
@@ -168,18 +166,9 @@ def test_slim_docs_retrieval_from_teams_connector(
_assert_is_valid_external_access(external_access=slim_doc.external_access)
@pytest.fixture(autouse=False)
def set_ee_on() -> Generator[None, None, None]:
"""Need EE to be enabled for perm sync tests to work since
perm syncing is an EE-only feature."""
global_version.set_ee()
yield
global_version._is_ee = False
def test_load_from_checkpoint_with_perm_sync(
teams_connector: TeamsConnector,
set_ee_on: None, # noqa: ARG001
enable_ee: None, # noqa: ARG001
) -> None:
"""Test that load_from_checkpoint_with_perm_sync returns documents with external_access.

View File

@@ -1,5 +1,6 @@
from typing import Any
import pytest
from pydantic import BaseModel
from sqlalchemy.orm import Session
@@ -14,13 +15,14 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.utils.variable_functionality import global_version
# In order to get these tests to run, use the credentials from Bitwarden.
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
pytestmark = pytest.mark.usefixtures("enable_ee")
class DocExternalAccessSet(BaseModel):
"""A version of DocExternalAccess that uses sets for comparison."""
@@ -52,9 +54,6 @@ def test_jira_doc_sync(
This test uses the AS project which has applicationRole permission,
meaning all documents should be marked as public.
"""
# NOTE: must set EE on or else the connector will skip the perm syncing
global_version.set_ee()
try:
# Use AS project specifically for this test
connector_config = {
@@ -150,9 +149,6 @@ def test_jira_doc_sync_with_specific_permissions(
This test uses a project that has specific user permissions to verify
that specific users are correctly extracted.
"""
# NOTE: must set EE on or else the connector will skip the perm syncing
global_version.set_ee()
try:
# Use SUP project which has specific user permissions
connector_config = {

View File

@@ -1,5 +1,6 @@
from typing import Any
import pytest
from sqlalchemy.orm import Session
from ee.onyx.external_permissions.jira.group_sync import jira_group_sync
@@ -18,6 +19,8 @@ from tests.daily.connectors.confluence.models import ExternalUserGroupSet
# Search up "ENV vars for local and Github tests", and find the Jira relevant key-value pairs.
# Required env vars: JIRA_USER_EMAIL, JIRA_API_TOKEN
pytestmark = pytest.mark.usefixtures("enable_ee")
# Expected groups from the danswerai.atlassian.net Jira instance
# Note: These groups are shared with Confluence since they're both Atlassian products
# App accounts (bots, integrations) are filtered out

View File

@@ -0,0 +1,90 @@
"""Test that Credential with nested JSON round-trips through SensitiveValue correctly.
Exercises the full encrypt → store → read → decrypt → SensitiveValue path
with realistic nested OAuth credential data, and verifies SQLAlchemy dirty
tracking works with nested dict comparison.
Requires a running Postgres instance.
"""
from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.db.models import Credential
from onyx.utils.sensitive import SensitiveValue
# NOTE: this is not the real shape of a Drive credential,
# but it is intended to test nested JSON credential handling
_NESTED_CRED_JSON = {
"oauth_tokens": {
"access_token": "ya29.abc123",
"refresh_token": "1//xEg-def456",
},
"scopes": ["read", "write", "admin"],
"client_config": {
"client_id": "123.apps.googleusercontent.com",
"client_secret": "GOCSPX-secret",
},
}
def test_nested_credential_json_round_trip(db_session: Session) -> None:
    """A nested credential dict reads back unchanged, before and after a reload."""
    cred = Credential(
        source=DocumentSource.GOOGLE_DRIVE,
        credential_json=_NESTED_CRED_JSON,
    )
    db_session.add(cred)
    db_session.flush()

    # Read straight off the instance (no reload): covers the set-event wrapping.
    in_memory = cred.credential_json
    assert isinstance(in_memory, SensitiveValue)
    assert in_memory.get_value(apply_mask=False) == _NESTED_CRED_JSON

    # Expire so the next access reloads from the DB: covers process_result_value.
    db_session.expire(cred)
    from_db = cred.credential_json
    assert isinstance(from_db, SensitiveValue)
    assert from_db.get_value(apply_mask=False) == _NESTED_CRED_JSON

    db_session.rollback()
def test_reassign_same_nested_json_not_dirty(db_session: Session) -> None:
    """Writing back an identical nested dict must leave the instance unmodified."""
    cred = Credential(
        source=DocumentSource.GOOGLE_DRIVE,
        credential_json=_NESTED_CRED_JSON,
    )
    db_session.add(cred)
    db_session.flush()

    # Drop the state left over from the insert, then touch the attribute so it
    # is loaded again before the re-assignment.
    db_session.expire(cred)
    _ = cred.credential_json

    cred.credential_json = _NESTED_CRED_JSON  # type: ignore[assignment]
    is_dirty = db_session.is_modified(cred)
    assert not is_dirty

    db_session.rollback()
def test_assign_different_nested_json_is_dirty(db_session: Session) -> None:
    """Writing a nested dict with different content must flag the instance as modified."""
    cred = Credential(
        source=DocumentSource.GOOGLE_DRIVE,
        credential_json=_NESTED_CRED_JSON,
    )
    db_session.add(cred)
    db_session.flush()

    # Reload the attribute so the comparison is against the stored value.
    db_session.expire(cred)
    _ = cred.credential_json

    # Same payload except for a shortened "scopes" list.
    changed_payload = {**_NESTED_CRED_JSON, "scopes": ["read"]}
    cred.credential_json = changed_payload  # type: ignore[assignment]
    is_dirty = db_session.is_modified(cred)
    assert is_dirty

    db_session.rollback()

View File

@@ -0,0 +1,305 @@
"""Tests for rotate_encryption_key against real Postgres.
Uses real ORM models (Credential, InternetSearchProvider) and the actual
Postgres database. Discovery is mocked in rotation tests to scope mutations
to only the test rows — the real _discover_encrypted_columns walk is tested
separately in TestDiscoverEncryptedColumns.
Requires a running Postgres instance. Run with::
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
"""
import json
from collections.abc import Generator
from unittest.mock import patch
import pytest
from sqlalchemy import LargeBinary
from sqlalchemy import select
from sqlalchemy import text
from sqlalchemy.orm import Session
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from onyx.configs.constants import DocumentSource
from onyx.db.models import Credential
from onyx.db.models import EncryptedJson
from onyx.db.models import EncryptedString
from onyx.db.models import InternetSearchProvider
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
from onyx.db.rotate_encryption_key import rotate_encryption_key
from onyx.utils.variable_functionality import fetch_versioned_implementation
from onyx.utils.variable_functionality import global_version
EE_MODULE = "ee.onyx.utils.encryption"
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
OLD_KEY = "o" * 16
NEW_KEY = "n" * 16
@pytest.fixture(autouse=True)
def _enable_ee() -> Generator[None, None, None]:
    """Run every test in this module with EE mode enabled.

    Uses the public ``global_version`` API (``is_ee_version`` / ``set_ee`` /
    ``unset_ee``) instead of poking the private ``_is_ee`` attribute. The
    previous EE state is restored on teardown, and the
    ``fetch_versioned_implementation`` cache is cleared on both sides of the
    test.
    """
    was_ee = global_version.is_ee_version()
    global_version.set_ee()
    fetch_versioned_implementation.cache_clear()
    yield
    # Only switch EE off again if it was off before this fixture ran.
    if not was_ee:
        global_version.unset_ee()
    fetch_versioned_implementation.cache_clear()
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
    """Empty the ``_get_trimmed_key`` cache before each test in this module."""
    _get_trimmed_key.cache_clear()
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
    """Fetch ``credential.credential_json`` as raw bytes, skipping the TypeDecorator.

    The column is selected from the Core table and cast to ``LargeBinary``.
    """
    credential_table = Credential.__table__
    raw_column = credential_table.c.credential_json.cast(LargeBinary)
    query = select(raw_column).where(credential_table.c.id == credential_id)
    return db_session.execute(query).scalar()
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
    """Fetch ``internet_search_provider.api_key`` as raw bytes for one row.

    The column is selected from the Core table and cast to ``LargeBinary``.
    """
    isp_table = InternetSearchProvider.__table__
    raw_column = isp_table.c.api_key.cast(LargeBinary)
    query = select(raw_column).where(isp_table.c.id == isp_id)
    return db_session.execute(query).scalar()
class TestDiscoverEncryptedColumns:
    """Checks that ``_discover_encrypted_columns`` reports the real models."""

    @staticmethod
    def _discovered_triples() -> set:
        """Collapse discovery output to ``(table name, column name, is_json)`` triples."""
        triples = set()
        for model_cls, col_name, _, is_json in _discover_encrypted_columns():
            table_name = model_cls.__tablename__  # type: ignore[attr-defined]
            triples.add((table_name, col_name, is_json))
        return triples

    def test_discovers_credential_json(self) -> None:
        expected = ("credential", "credential_json", True)
        assert expected in self._discovered_triples()

    def test_discovers_internet_search_provider_api_key(self) -> None:
        expected = ("internet_search_provider", "api_key", False)
        assert expected in self._discovered_triples()

    def test_all_encrypted_string_columns_are_not_json(self) -> None:
        for model_cls, col_name, _, is_json in _discover_encrypted_columns():
            column = getattr(model_cls, col_name).property.columns[0]
            if not isinstance(column.type, EncryptedString):
                continue
            assert not is_json, (
                f"{model_cls.__tablename__}.{col_name} is EncryptedString "  # type: ignore[attr-defined]
                f"but is_json={is_json}"
            )

    def test_all_encrypted_json_columns_are_json(self) -> None:
        for model_cls, col_name, _, is_json in _discover_encrypted_columns():
            column = getattr(model_cls, col_name).property.columns[0]
            if not isinstance(column.type, EncryptedJson):
                continue
            assert is_json, (
                f"{model_cls.__tablename__}.{col_name} is EncryptedJson "  # type: ignore[attr-defined]
                f"but is_json={is_json}"
            )
class TestRotateCredential:
    """Rotation of the real ``credential.credential_json`` column (EncryptedJson).

    Column discovery is patched down to the Credential model alone so that no
    other table in the test database gets rewritten.
    """

    @pytest.fixture(autouse=True)
    def _limit_discovery(self) -> Generator[None, None, None]:
        only_credential = [(Credential, "credential_json", ["id"], True)]
        discovery_patch = patch(
            f"{ROTATE_MODULE}._discover_encrypted_columns",
            return_value=only_credential,
        )
        with discovery_patch:
            yield

    @pytest.fixture()
    def credential_id(
        self, db_session: Session, tenant_context: None  # noqa: ARG002
    ) -> Generator[int, None, None]:
        """Insert one credential row encrypted under OLD_KEY; delete it on teardown."""
        plaintext = json.dumps(
            {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
        )
        insert_stmt = text(
            "INSERT INTO credential "
            "(source, credential_json, admin_public, curator_public) "
            "VALUES (:source, :cred_json, true, false) "
            "RETURNING id"
        )
        insert_params = {
            "source": DocumentSource.INGESTION_API.value,
            "cred_json": _encrypt_string(plaintext, key=OLD_KEY),
        }
        new_id = db_session.execute(insert_stmt, insert_params).scalar_one()
        db_session.commit()

        yield new_id

        db_session.execute(
            text("DELETE FROM credential WHERE id = :id"), {"id": new_id}
        )
        db_session.commit()

    def test_rotates_credential_json(
        self, db_session: Session, credential_id: int
    ) -> None:
        with patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
            with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                totals = rotate_encryption_key(db_session, old_key=OLD_KEY)

        assert totals.get("credential.credential_json", 0) >= 1

        stored = _raw_credential_bytes(db_session, credential_id)
        assert stored is not None
        payload = json.loads(_decrypt_bytes(stored, key=NEW_KEY))
        assert payload["api_key"] == "sk-test-1234"
        assert payload["endpoint"] == "https://example.com"

    def test_skips_already_rotated(
        self, db_session: Session, credential_id: int
    ) -> None:
        with patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
            with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                # Run the rotation twice; the second pass must leave the row readable.
                rotate_encryption_key(db_session, old_key=OLD_KEY)
                _ = rotate_encryption_key(db_session, old_key=OLD_KEY)

        stored = _raw_credential_bytes(db_session, credential_id)
        assert stored is not None
        payload = json.loads(_decrypt_bytes(stored, key=NEW_KEY))
        assert payload["api_key"] == "sk-test-1234"

    def test_dry_run_does_not_modify(
        self, db_session: Session, credential_id: int
    ) -> None:
        before = _raw_credential_bytes(db_session, credential_id)

        with patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
            with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                totals = rotate_encryption_key(
                    db_session, old_key=OLD_KEY, dry_run=True
                )

        assert totals.get("credential.credential_json", 0) >= 1

        after = _raw_credential_bytes(db_session, credential_id)
        assert after == before
class TestRotateInternetSearchProvider:
    """Rotation of the real ``internet_search_provider.api_key`` column (EncryptedString).

    Column discovery is patched down to the InternetSearchProvider model alone
    so that no other table in the test database gets rewritten.
    """

    @pytest.fixture(autouse=True)
    def _limit_discovery(self) -> Generator[None, None, None]:
        only_isp = [
            (InternetSearchProvider, "api_key", ["id"], False),
        ]
        discovery_patch = patch(
            f"{ROTATE_MODULE}._discover_encrypted_columns",
            return_value=only_isp,
        )
        with discovery_patch:
            yield

    @pytest.fixture()
    def isp_id(
        self, db_session: Session, tenant_context: None  # noqa: ARG002
    ) -> Generator[int, None, None]:
        """Insert one provider row whose api_key is encrypted under OLD_KEY; delete it on teardown."""
        insert_stmt = text(
            "INSERT INTO internet_search_provider "
            "(name, provider_type, api_key, is_active) "
            "VALUES (:name, :ptype, :api_key, false) "
            "RETURNING id"
        )
        insert_params = {
            "name": f"test-rotation-{id(self)}",
            "ptype": "test",
            "api_key": _encrypt_string("sk-secret-api-key", key=OLD_KEY),
        }
        row_id = db_session.execute(insert_stmt, insert_params).scalar_one()
        db_session.commit()

        yield row_id

        db_session.execute(
            text("DELETE FROM internet_search_provider WHERE id = :id"),
            {"id": row_id},
        )
        db_session.commit()

    def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
        with patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
            with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                totals = rotate_encryption_key(db_session, old_key=OLD_KEY)

        assert totals.get("internet_search_provider.api_key", 0) >= 1

        stored = _raw_isp_bytes(db_session, isp_id)
        assert stored is not None
        assert _decrypt_bytes(stored, key=NEW_KEY) == "sk-secret-api-key"

    def test_rotates_from_unencrypted(
        self, db_session: Session, tenant_context: None  # noqa: ARG002
    ) -> None:
        """Rotate a row whose api_key was stored as plain bytes (no old key)."""
        insert_stmt = text(
            "INSERT INTO internet_search_provider "
            "(name, provider_type, api_key, is_active) "
            "VALUES (:name, :ptype, :api_key, false) "
            "RETURNING id"
        )
        insert_params = {
            "name": f"test-raw-{id(self)}",
            "ptype": "test",
            "api_key": b"raw-api-key",
        }
        row_id = db_session.execute(insert_stmt, insert_params).scalar_one()
        db_session.commit()

        try:
            with patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY):
                    totals = rotate_encryption_key(db_session, old_key=None)

            assert totals.get("internet_search_provider.api_key", 0) >= 1

            stored = _raw_isp_bytes(db_session, row_id)
            assert stored is not None
            assert _decrypt_bytes(stored, key=NEW_KEY) == "raw-api-key"
        finally:
            # Remove the row even when an assertion above fails.
            db_session.execute(
                text("DELETE FROM internet_search_provider WHERE id = :id"),
                {"id": row_id},
            )
            db_session.commit()

View File

@@ -0,0 +1,85 @@
"""Tests that SlackBot CRUD operations return properly typed SensitiveValue fields.
Regression test for the bug where insert_slack_bot/update_slack_bot returned
objects with raw string tokens instead of SensitiveValue wrappers, causing
'str object has no attribute get_value' errors in SlackBot.from_model().
"""
from uuid import uuid4
from sqlalchemy.orm import Session
from onyx.db.slack_bot import insert_slack_bot
from onyx.db.slack_bot import update_slack_bot
from onyx.server.manage.models import SlackBot
from onyx.utils.sensitive import SensitiveValue
def _unique(prefix: str) -> str:
return f"{prefix}-{uuid4().hex[:8]}"
def test_insert_slack_bot_returns_sensitive_values(db_session: Session) -> None:
    """The object returned by insert_slack_bot exposes its tokens as SensitiveValue."""
    expected_tokens = {
        "bot_token": _unique("xoxb-insert"),
        "app_token": _unique("xapp-insert"),
        "user_token": _unique("xoxp-insert"),
    }

    slack_bot = insert_slack_bot(
        db_session=db_session,
        name=_unique("test-bot-insert"),
        enabled=True,
        **expected_tokens,
    )

    # Every token must be wrapped ...
    for field in ("bot_token", "app_token", "user_token"):
        assert isinstance(getattr(slack_bot, field), SensitiveValue)
    # ... and must unwrap to exactly what was passed in.
    for field in ("bot_token", "app_token", "user_token"):
        unmasked = getattr(slack_bot, field).get_value(apply_mask=False)
        assert unmasked == expected_tokens[field]

    # from_model must accept the returned object without raising.
    pydantic_bot = SlackBot.from_model(slack_bot)
    assert pydantic_bot.bot_token  # masked, but not empty
    assert pydantic_bot.app_token
def test_update_slack_bot_returns_sensitive_values(db_session: Session) -> None:
    """The object returned by update_slack_bot exposes its tokens as SensitiveValue."""
    original = insert_slack_bot(
        db_session=db_session,
        name=_unique("test-bot-update"),
        enabled=True,
        bot_token=_unique("xoxb-update"),
        app_token=_unique("xapp-update"),
    )

    replacement_tokens = {
        "bot_token": _unique("xoxb-update-new"),
        "app_token": _unique("xapp-update-new"),
        "user_token": _unique("xoxp-update-new"),
    }

    updated = update_slack_bot(
        db_session=db_session,
        slack_bot_id=original.id,
        name=_unique("test-bot-updated"),
        enabled=False,
        **replacement_tokens,
    )

    # Every token must be wrapped ...
    for field in ("bot_token", "app_token", "user_token"):
        assert isinstance(getattr(updated, field), SensitiveValue)
    # ... and must unwrap to the replacement value.
    for field in ("bot_token", "app_token", "user_token"):
        unmasked = getattr(updated, field).get_value(apply_mask=False)
        assert unmasked == replacement_tokens[field]

    # from_model must accept the returned object without raising.
    pydantic_bot = SlackBot.from_model(updated)
    assert pydantic_bot.bot_token
    assert pydantic_bot.app_token
    assert pydantic_bot.user_token is not None

View File

@@ -148,8 +148,16 @@ class TestOAuthConfigCRUD:
)
# Secrets should be preserved
assert updated_config.client_id == original_client_id
assert updated_config.client_secret == original_client_secret
assert updated_config.client_id is not None
assert original_client_id is not None
assert updated_config.client_id.get_value(
apply_mask=False
) == original_client_id.get_value(apply_mask=False)
assert updated_config.client_secret is not None
assert original_client_secret is not None
assert updated_config.client_secret.get_value(
apply_mask=False
) == original_client_secret.get_value(apply_mask=False)
# But name should be updated
assert updated_config.name == new_name
@@ -173,9 +181,14 @@ class TestOAuthConfigCRUD:
)
# client_id should be cleared (empty string)
assert updated_config.client_id == ""
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == ""
# client_secret should be preserved
assert updated_config.client_secret == original_client_secret
assert updated_config.client_secret is not None
assert original_client_secret is not None
assert updated_config.client_secret.get_value(
apply_mask=False
) == original_client_secret.get_value(apply_mask=False)
def test_update_oauth_config_clear_client_secret(self, db_session: Session) -> None:
"""Test clearing client_secret while preserving client_id"""
@@ -190,9 +203,14 @@ class TestOAuthConfigCRUD:
)
# client_secret should be cleared (empty string)
assert updated_config.client_secret == ""
assert updated_config.client_secret is not None
assert updated_config.client_secret.get_value(apply_mask=False) == ""
# client_id should be preserved
assert updated_config.client_id == original_client_id
assert updated_config.client_id is not None
assert original_client_id is not None
assert updated_config.client_id.get_value(
apply_mask=False
) == original_client_id.get_value(apply_mask=False)
def test_update_oauth_config_clear_both_secrets(self, db_session: Session) -> None:
"""Test clearing both client_id and client_secret"""
@@ -207,8 +225,10 @@ class TestOAuthConfigCRUD:
)
# Both should be cleared (empty strings)
assert updated_config.client_id == ""
assert updated_config.client_secret == ""
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == ""
assert updated_config.client_secret is not None
assert updated_config.client_secret.get_value(apply_mask=False) == ""
def test_update_oauth_config_authorization_url(self, db_session: Session) -> None:
"""Test updating authorization_url"""
@@ -275,7 +295,8 @@ class TestOAuthConfigCRUD:
assert updated_config.token_url == new_token_url
assert updated_config.scopes == new_scopes
assert updated_config.additional_params == new_params
assert updated_config.client_id == new_client_id
assert updated_config.client_id is not None
assert updated_config.client_id.get_value(apply_mask=False) == new_client_id
def test_delete_oauth_config(self, db_session: Session) -> None:
"""Test deleting an OAuth configuration"""
@@ -416,7 +437,8 @@ class TestOAuthUserTokenCRUD:
assert user_token.id is not None
assert user_token.oauth_config_id == oauth_config.id
assert user_token.user_id == user.id
assert user_token.token_data == token_data
assert user_token.token_data is not None
assert user_token.token_data.get_value(apply_mask=False) == token_data
assert user_token.created_at is not None
assert user_token.updated_at is not None
@@ -446,8 +468,13 @@ class TestOAuthUserTokenCRUD:
# Should be the same token record (updated, not inserted)
assert updated_token.id == initial_token_id
assert updated_token.token_data == updated_token_data
assert updated_token.token_data != initial_token_data
assert updated_token.token_data is not None
assert (
updated_token.token_data.get_value(apply_mask=False) == updated_token_data
)
assert (
updated_token.token_data.get_value(apply_mask=False) != initial_token_data
)
def test_get_user_oauth_token(self, db_session: Session) -> None:
"""Test retrieving a user's OAuth token"""
@@ -463,7 +490,8 @@ class TestOAuthUserTokenCRUD:
assert retrieved_token is not None
assert retrieved_token.id == created_token.id
assert retrieved_token.token_data == token_data
assert retrieved_token.token_data is not None
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data
def test_get_user_oauth_token_not_found(self, db_session: Session) -> None:
"""Test retrieving a non-existent user token returns None"""
@@ -519,7 +547,8 @@ class TestOAuthUserTokenCRUD:
retrieved_token = get_user_oauth_token(oauth_config.id, user.id, db_session)
assert retrieved_token is not None
assert retrieved_token.id == updated_token.id
assert retrieved_token.token_data == token_data2
assert retrieved_token.token_data is not None
assert retrieved_token.token_data.get_value(apply_mask=False) == token_data2
def test_cascade_delete_user_tokens_on_config_deletion(
self, db_session: Session

View File

@@ -374,8 +374,14 @@ class TestOAuthTokenManagerCodeExchange:
assert call_args[0][0] == oauth_config.token_url
assert call_args[1]["data"]["grant_type"] == "authorization_code"
assert call_args[1]["data"]["code"] == "auth_code_123"
assert call_args[1]["data"]["client_id"] == oauth_config.client_id
assert call_args[1]["data"]["client_secret"] == oauth_config.client_secret
assert oauth_config.client_id is not None
assert oauth_config.client_secret is not None
assert call_args[1]["data"]["client_id"] == oauth_config.client_id.get_value(
apply_mask=False
)
assert call_args[1]["data"][
"client_secret"
] == oauth_config.client_secret.get_value(apply_mask=False)
assert call_args[1]["data"]["redirect_uri"] == "https://example.com/callback"
@patch("onyx.auth.oauth_token_manager.requests.post")

View File

@@ -950,6 +950,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import SectionEnd
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
@@ -1294,9 +1295,18 @@ def test_code_interpreter_replay_packets_include_code_and_output(
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
obj=ToolCallArgumentDelta(
tool_type="python",
argument_deltas={"code": code},
),
),
forward=2,
).expect(
Packet(
placement=create_placement(0),
obj=PythonToolStart(code=code),
),
forward=False,
).expect(
Packet(
placement=create_placement(0),

View File

@@ -64,7 +64,8 @@ class TestBotConfigAPI:
db_session.commit()
assert config is not None
assert config.bot_token == "test_token_123"
assert config.bot_token is not None
assert config.bot_token.get_value(apply_mask=False) == "test_token_123"
# Cleanup
delete_discord_bot_config(db_session)

View File

@@ -0,0 +1,8 @@
"""Auto-enable EE mode for all tests under tests/unit/ee/."""
import pytest
@pytest.fixture(autouse=True)
def _enable_ee_for_directory(
    enable_ee: None,  # noqa: ARG001
) -> None:
    """Autouse wrapper that requests the shared ``enable_ee`` fixture.

    The fixture has no body of its own; depending on ``enable_ee`` is the
    whole point, so every test collected under this directory gets it
    without asking for it explicitly.
    """

View File

@@ -0,0 +1,165 @@
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
to the EE implementations, so no patching of the MIT layer is needed.
"""
from unittest.mock import patch
import pytest
from ee.onyx.utils.encryption import _decrypt_bytes
from ee.onyx.utils.encryption import _encrypt_string
from ee.onyx.utils.encryption import _get_trimmed_key
from ee.onyx.utils.encryption import decrypt_bytes_to_string
from ee.onyx.utils.encryption import encrypt_string_to_bytes
# Dotted path of the module under test; used to build patch() targets such as
# f"{EE_MODULE}.ENCRYPTION_KEY_SECRET".
EE_MODULE = "ee.onyx.utils.encryption"
# Keys must be exactly 16, 24, or 32 bytes for AES
KEY_16 = "a" * 16
# Same length as KEY_16 but different content, for the wrong-key scenarios.
KEY_16_ALT = "b" * 16
KEY_24 = "d" * 24
KEY_32 = "c" * 32
@pytest.fixture(autouse=True)
def _clear_key_cache() -> None:
    """Reset the ``_get_trimmed_key`` cache before each test in this module."""
    reset_cache = _get_trimmed_key.cache_clear
    reset_cache()
class TestEncryptDecryptRoundTrip:
    """Encrypt-then-decrypt round trips across the different key sources."""

    def test_roundtrip_with_env_key(self) -> None:
        """With no explicit key, the patched module-level secret is used."""
        env_patch = patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16)
        with env_patch:
            ciphertext = _encrypt_string("hello world")
            assert ciphertext != b"hello world"
            assert _decrypt_bytes(ciphertext) == "hello world"

    def test_roundtrip_with_explicit_key(self) -> None:
        """An explicitly supplied key round-trips the plaintext."""
        ciphertext = _encrypt_string("secret data", key=KEY_32)
        assert ciphertext != b"secret data"
        recovered = _decrypt_bytes(ciphertext, key=KEY_32)
        assert recovered == "secret data"

    def test_roundtrip_no_key(self) -> None:
        """Without any key, data is raw-encoded (no encryption)."""
        env_patch = patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", "")
        with env_patch:
            stored = _encrypt_string("plain text")
            assert stored == b"plain text"
            assert _decrypt_bytes(stored) == "plain text"

    def test_explicit_key_overrides_env(self) -> None:
        """Data encrypted with an explicit key cannot be read with the env key."""
        with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
            ciphertext = _encrypt_string("data", key=KEY_16_ALT)
            with pytest.raises(ValueError):
                _decrypt_bytes(ciphertext, key=KEY_16)
            recovered = _decrypt_bytes(ciphertext, key=KEY_16_ALT)
            assert recovered == "data"

    def test_different_encryptions_produce_different_bytes(self) -> None:
        """Each encryption uses a random IV, so results differ."""
        first = _encrypt_string("same", key=KEY_16)
        second = _encrypt_string("same", key=KEY_16)
        assert first != second

    def test_roundtrip_empty_string(self) -> None:
        """An empty plaintext still yields non-empty ciphertext and round-trips."""
        ciphertext = _encrypt_string("", key=KEY_16)
        assert ciphertext != b""
        recovered = _decrypt_bytes(ciphertext, key=KEY_16)
        assert recovered == ""

    def test_roundtrip_unicode(self) -> None:
        """Non-ASCII text survives the round trip unchanged."""
        original = "日本語テスト 🔐 émojis"
        ciphertext = _encrypt_string(original, key=KEY_16)
        recovered = _decrypt_bytes(ciphertext, key=KEY_16)
        assert recovered == original
class TestDecryptFallbackBehavior:
    """What ``_decrypt_bytes`` does when AES decryption cannot succeed."""

    def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
        """Default key path: AES fails on non-AES data → fallback to raw decode."""
        plain_bytes = "readable text".encode()
        env_patch = patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16)
        with env_patch:
            decoded = _decrypt_bytes(plain_bytes)
            assert decoded == "readable text"

    def test_explicit_wrong_key_raises(self) -> None:
        """Explicit key path: AES fails → raises, no fallback."""
        ciphertext = _encrypt_string("secret", key=KEY_16)
        with pytest.raises(ValueError):
            _decrypt_bytes(ciphertext, key=KEY_16_ALT)

    def test_explicit_none_key_with_no_env(self) -> None:
        """key=None with empty env → raw decode."""
        env_patch = patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", "")
        with env_patch:
            decoded = _decrypt_bytes(b"hello", key=None)
            assert decoded == "hello"

    def test_explicit_empty_string_key(self) -> None:
        """key='' means no encryption."""
        stored = _encrypt_string("test", key="")
        assert stored == b"test"
        decoded = _decrypt_bytes(stored, key="")
        assert decoded == "test"
class TestKeyValidation:
    """Accepted key lengths and how over-long keys are trimmed."""

    @staticmethod
    def _roundtrip(key: str) -> bytes:
        """Encrypt "data" with ``key``, assert it decrypts, return the ciphertext."""
        ciphertext = _encrypt_string("data", key=key)
        assert _decrypt_bytes(ciphertext, key=key) == "data"
        return ciphertext

    def test_key_too_short_raises(self) -> None:
        """A key shorter than the smallest AES size is rejected."""
        with pytest.raises(RuntimeError, match="too short"):
            _encrypt_string("data", key="short")

    def test_16_byte_key(self) -> None:
        """A 16-byte key round-trips."""
        self._roundtrip(KEY_16)

    def test_24_byte_key(self) -> None:
        """A 24-byte key round-trips."""
        self._roundtrip(KEY_24)

    def test_32_byte_key(self) -> None:
        """A 32-byte key round-trips."""
        self._roundtrip(KEY_32)

    def test_long_key_truncated_to_32(self) -> None:
        """Keys longer than 32 bytes are truncated to 32."""
        self._roundtrip("e" * 64)

    def test_20_byte_key_trimmed_to_16(self) -> None:
        """A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
        ciphertext = self._roundtrip("f" * 20)
        # If trimming happened, the 16-character prefix alone must also
        # be able to decrypt the ciphertext.
        prefix_key = "f" * 16
        assert _decrypt_bytes(ciphertext, key=prefix_key) == "data"

    def test_25_byte_key_trimmed_to_24(self) -> None:
        """A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
        ciphertext = self._roundtrip("g" * 25)
        prefix_key = "g" * 24
        assert _decrypt_bytes(ciphertext, key=prefix_key) == "data"

    def test_30_byte_key_trimmed_to_24(self) -> None:
        """A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
        ciphertext = self._roundtrip("h" * 30)
        prefix_key = "h" * 24
        assert _decrypt_bytes(ciphertext, key=prefix_key) == "data"
class TestWrapperFunctions:
    """Checks that the public wrappers forward the ``key`` argument.

    Covers encrypt_string_to_bytes / decrypt_bytes_to_string. With EE mode
    enabled, the wrappers resolve to EE implementations automatically.
    """

    def test_wrapper_passes_key(self) -> None:
        """An explicit key given to both wrappers round-trips the payload."""
        ciphertext = encrypt_string_to_bytes("payload", key=KEY_16)
        recovered = decrypt_bytes_to_string(ciphertext, key=KEY_16)
        assert recovered == "payload"

    def test_wrapper_no_key_uses_env(self) -> None:
        """Without a key argument the patched module-level secret is used."""
        env_patch = patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32)
        with env_patch:
            ciphertext = encrypt_string_to_bytes("payload")
            recovered = decrypt_bytes_to_string(ciphertext)
            assert recovered == "payload"

View File

@@ -0,0 +1,630 @@
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
from onyx.utils.jsonriver import Parser
def _make_tool_call_delta(
index: int = 0,
name: str | None = None,
arguments: str | None = None,
function_is_none: bool = False,
) -> MagicMock:
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
delta = MagicMock()
delta.index = index
if function_is_none:
delta.function = None
else:
delta.function = MagicMock()
delta.function.name = name
delta.function.arguments = arguments
return delta
def _make_placement() -> Placement:
    """Return the default placement used by these tests (turn 0, tab 0)."""
    default_placement = Placement(turn_index=0, tab_index=0)
    return default_placement
def _mock_tool_class(emit: bool = True) -> MagicMock:
cls = MagicMock()
cls.should_emit_argument_deltas.return_value = emit
return cls
def _collect(
    tc_map: dict[int, dict[str, Any]],
    delta: MagicMock,
    placement: Placement | None = None,
    parsers: dict[int, Parser] | None = None,
) -> list[Any]:
    """Drain ``maybe_emit_argument_delta`` into a list of packets.

    A default placement is created when ``placement`` is falsy, and a fresh
    parser dict is used when ``parsers`` is ``None`` (an explicitly passed
    empty dict is kept so the caller can share parser state across calls).
    """
    effective_placement = placement or _make_placement()
    if parsers is None:
        parsers = {}
    packet_stream = maybe_emit_argument_delta(
        tc_map, delta, effective_placement, parsers
    )
    return [packet for packet in packet_stream]
def _stream_fragments(
    fragments: list[str],
    tc_map: dict[int, dict[str, Any]],
    placement: Placement | None = None,
) -> list[str]:
    """Stream ``fragments`` through ``maybe_emit_argument_delta`` one at a time.

    Each fragment is appended to ``tc_map[0]["arguments"]`` before the call,
    and a single parser dict is shared across all calls. Returns every
    emitted argument-delta value, in emission order, as one flat list.
    """
    target_placement = placement or _make_placement()
    parser_state: dict[int, Parser] = {}
    values: list[str] = []
    for fragment in fragments:
        tc_map[0]["arguments"] += fragment
        fragment_delta = _make_tool_call_delta(arguments=fragment)
        packets = maybe_emit_argument_delta(
            tc_map, fragment_delta, target_placement, parsers=parser_state
        )
        for packet in packets:
            payload = packet.obj
            assert isinstance(payload, ToolCallArgumentDelta)
            values.extend(payload.argument_deltas.values())
    return values
class TestMaybeEmitArgumentDeltaGuards:
    """Tests for conditions that cause no packet to be emitted."""

    @staticmethod
    def _single_call_map(name: str, arguments: str) -> dict[int, dict[str, Any]]:
        """Build a tool-call map holding exactly one call at index 0."""
        return {0: {"id": "tc_1", "name": name, "arguments": arguments}}

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_when_tool_does_not_opt_in(
        self, mock_get_tool: MagicMock
    ) -> None:
        """Tools that return False from should_emit_argument_deltas emit nothing."""
        mock_get_tool.return_value = _mock_tool_class(emit=False)
        state = self._single_call_map("python", '{"code": "x')
        packets = _collect(state, _make_tool_call_delta(arguments="x"))
        assert packets == []

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_when_tool_class_unknown(
        self, mock_get_tool: MagicMock
    ) -> None:
        """No packets when the tool-class lookup returns None."""
        mock_get_tool.return_value = None
        state = self._single_call_map("unknown", '{"code": "x')
        packets = _collect(state, _make_tool_call_delta(arguments="x"))
        assert packets == []

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_when_no_argument_fragment(
        self, mock_get_tool: MagicMock
    ) -> None:
        """No packets when the delta carries ``arguments=None``."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._single_call_map("python", '{"code": "x')
        packets = _collect(state, _make_tool_call_delta(arguments=None))
        assert packets == []

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_when_key_value_not_started(
        self, mock_get_tool: MagicMock
    ) -> None:
        """Key exists in JSON but its string value hasn't begun yet."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._single_call_map("python", '{"code":')
        packets = _collect(state, _make_tool_call_delta(arguments=":"))
        assert packets == []

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
        """Only the opening brace has arrived — no key to stream yet."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._single_call_map("python", "{")
        packets = _collect(state, _make_tool_call_delta(arguments="{"))
        assert packets == []
class TestMaybeEmitArgumentDeltaBasic:
    """Tests for correct packet content and incremental emission."""

    @staticmethod
    def _empty_python_call() -> dict[int, dict[str, Any]]:
        """Tool-call map with one ``python`` call at index 0 and no arguments yet."""
        return {0: {"id": "tc_1", "name": "python", "arguments": ""}}

    @staticmethod
    def _code_from(packets: list[Any]) -> str:
        """Type-check each packet and concatenate its ``code`` deltas in order."""
        pieces: list[str] = []
        for packet in packets:
            assert isinstance(packet.obj, ToolCallArgumentDelta)
            pieces.append(packet.obj.argument_deltas.get("code", ""))
        return "".join(pieces)

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
        """Emitted packets carry the tool type and reconstruct the full value."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._empty_python_call()
        placement = _make_placement()
        parser_state: dict[int, Parser] = {}
        received: list[Any] = []

        for piece in ['{"code": "', "print(1)", '"}']:
            state[0]["arguments"] += piece
            received += _collect(
                state, _make_tool_call_delta(arguments=piece), placement, parser_state
            )

        assert len(received) >= 1

        # Verify packet structure
        first = received[0].obj
        assert isinstance(first, ToolCallArgumentDelta)
        assert first.tool_type == "python"

        # All emitted content should reconstruct the value
        assert self._code_from(received) == "print(1)"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_emits_only_new_content_on_subsequent_call(
        self, mock_get_tool: MagicMock
    ) -> None:
        """After a first emission, subsequent calls emit only the diff."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._empty_python_call()
        parser_state: dict[int, Parser] = {}
        placement = _make_placement()

        # First fragment opens the string
        state[0]["arguments"] = '{"code": "abc'
        first_batch = _collect(
            state,
            _make_tool_call_delta(arguments='{"code": "abc'),
            placement,
            parser_state,
        )
        assert self._code_from(first_batch) == "abc"

        # Second fragment appends more
        state[0]["arguments"] = '{"code": "abcdef'
        second_batch = _collect(
            state, _make_tool_call_delta(arguments="def"), placement, parser_state
        )
        assert self._code_from(second_batch) == "def"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
        """When a second key starts, emissions switch to that key."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._empty_python_call()
        pieces = [
            '{"code": "x',
            '", "output": "hello',
            '"}',
        ]
        combined = "".join(_stream_fragments(pieces, state))
        assert "x" in combined
        assert "hello" in combined

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
        """A single delta contains the end of one value and the start of the next key."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._empty_python_call()
        pieces = [
            '{"code": "x',
            'y", "lang": "py',
            '"}',
        ]
        combined = "".join(_stream_fragments(pieces, state))
        assert "xy" in combined
        assert "py" in combined

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
        """An empty string value has nothing to emit."""
        mock_get_tool.return_value = _mock_tool_class()
        state = self._empty_python_call()

        # Opening quote just arrived, value is empty
        state[0]["arguments"] = '{"code": "'
        received = _collect(state, _make_tool_call_delta(arguments='{"code": "'))

        # No string content yet, so either no packet or empty deltas
        for packet in received:
            payload = packet.obj
            assert isinstance(payload, ToolCallArgumentDelta)
            assert payload.argument_deltas.get("code", "") == ""
class TestMaybeEmitArgumentDeltaDecoding:
    """Tests verifying that JSON escape sequences are properly decoded."""

    @staticmethod
    def _decoded(mock_get_tool: MagicMock, fragments: list[str]) -> str:
        """Stream ``fragments`` for an opted-in python tool; return the joined output."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {
            0: {"id": "tc_1", "name": "python", "arguments": ""}
        }
        return "".join(_stream_fragments(fragments, state))

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
        """An escaped newline in the JSON value comes out as a real newline."""
        output = self._decoded(mock_get_tool, ['{"code": "line1\\nline2"}'])
        assert output == "line1\nline2"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
        """An escaped tab in the JSON value comes out as a real tab."""
        output = self._decoded(mock_get_tool, ['{"code": "\\tindented"}'])
        assert output == "\tindented"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
        """Escaped double quotes are emitted as plain double quotes."""
        output = self._decoded(mock_get_tool, ['{"code": "say \\"hi\\""}'])
        assert output == 'say "hi"'

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
        """An escaped backslash is emitted as a single backslash."""
        output = self._decoded(mock_get_tool, ['{"code": "path\\\\dir"}'])
        assert output == "path\\dir"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
        """A four-digit unicode escape is emitted as the character it names."""
        output = self._decoded(mock_get_tool, ['{"code": "\\u0041"}'])
        assert output == "A"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_incomplete_escape_at_end_decoded_on_next_chunk(
        self, mock_get_tool: MagicMock
    ) -> None:
        """A trailing backslash (incomplete escape) is completed in the next chunk."""
        output = self._decoded(mock_get_tool, ['{"code": "hello\\', 'n"}'])
        assert output == "hello\n"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_incomplete_unicode_escape_completed_on_next_chunk(
        self, mock_get_tool: MagicMock
    ) -> None:
        """A unicode escape split across two chunks is completed by the second."""
        output = self._decoded(mock_get_tool, ['{"code": "hello\\u00', '41"}'])
        assert output == "helloA"
class TestArgumentDeltaStreamingE2E:
    """Simulates realistic sequences of LLM argument deltas to verify
    the full pipeline produces correct decoded output."""

    @staticmethod
    def _run(mock_get_tool: MagicMock, fragments: list[str]) -> str:
        """Stream ``fragments`` for an opted-in python tool; return the joined output."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {
            0: {"id": "tc_1", "name": "python", "arguments": ""}
        }
        return "".join(_stream_fragments(fragments, state))

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
        """Two print statements separated by an escaped newline, sent in small pieces."""
        pieces = [
            '{"',
            "code",
            '": "',
            "print(",
            "'hello')",
            "\\n",
            "print(",
            "'world')",
            '"}',
        ]
        assert self._run(mock_get_tool, pieces) == "print('hello')\nprint('world')"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
        """Streams code with tabs and newlines."""
        pieces = [
            '{"code": "',
            "if True:",
            "\\n",
            "\\t",
            "pass",
            '"}',
        ]
        assert self._run(mock_get_tool, pieces) == "if True:\n\tpass"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
        """An escape sequence split across two fragments (backslash in one,
        'n' in the next) should still decode correctly."""
        pieces = [
            '{"code": "hello',
            "\\",
            "n",
            'world"}',
        ]
        assert self._run(mock_get_tool, pieces) == "hello\nworld"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
        """Streams a multi-line function with multiple escape sequences."""
        pieces = [
            '{"code": "',
            "def foo():",
            "\\n",
            "\\t",
            "x = 1",
            "\\n",
            "\\t",
            "return x",
            '"}',
        ]
        assert self._run(mock_get_tool, pieces) == "def foo():\n\tx = 1\n\treturn x"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
        """Streams code first, then a second key (language) — both decoded."""
        pieces = [
            '{"code": "',
            "x = 1",
            '", "language": "',
            "python",
            '"}',
        ]
        combined = self._run(mock_get_tool, pieces)
        # Should have emissions for both keys
        assert "x = 1" in combined
        assert "python" in combined

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
        """Python code like `x = {"key": "val"}` contains JSON-like patterns.

        The escaped quotes inside the *outer* JSON value should prevent the
        inner `"key":` from being mistaken for a top-level JSON key."""
        # The LLM sends: {"code": "x = {\"key\": \"val\"}"}
        # The inner quotes are escaped as \" in the JSON value.
        pieces = [
            '{"code": "',
            "x = {",
            '\\"key\\"',
            ": ",
            '\\"val\\"',
            "}",
            '"}',
        ]
        assert self._run(mock_get_tool, pieces) == 'x = {"key": "val"}'

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
        """Colons inside the string value should not confuse key detection."""
        pieces = [
            '{"code": "',
            "url = ",
            '\\"https://example.com\\"',
            '"}',
        ]
        assert self._run(mock_get_tool, pieces) == 'url = "https://example.com"'
class TestMaybeEmitArgumentDeltaEdgeCases:
    """Edge cases not covered by the standard test classes."""

    @staticmethod
    def _python_call(call_id: str, arguments: str = "") -> dict[str, Any]:
        """Single tool-call entry for the ``python`` tool."""
        return {"id": call_id, "name": "python", "arguments": arguments}

    @staticmethod
    def _code_from(packets: list[Any]) -> str:
        """Type-check each packet and concatenate its ``code`` deltas in order."""
        pieces: list[str] = []
        for packet in packets:
            assert isinstance(packet.obj, ToolCallArgumentDelta)
            pieces.append(packet.obj.argument_deltas.get("code", ""))
        return "".join(pieces)

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
        """Some delta chunks have function=None (e.g. role-only deltas)."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {
            0: self._python_call("tc_1", '{"code": "x')
        }
        functionless = _make_tool_call_delta(arguments=None, function_is_none=True)
        assert _collect(state, functionless) == []

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
        """Two tool calls streaming at different indices in parallel."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {
            0: self._python_call("tc_1"),
            1: self._python_call("tc_2"),
        }
        parser_state: dict[int, Parser] = {}
        placement = _make_placement()

        # Feed full JSON to index 0
        state[0]["arguments"] = '{"code": "aaa"}'
        from_first = _collect(
            state,
            _make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
            placement,
            parser_state,
        )
        assert self._code_from(from_first) == "aaa"

        # Feed full JSON to index 1
        state[1]["arguments"] = '{"code": "bbb"}'
        from_second = _collect(
            state,
            _make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
            placement,
            parser_state,
        )
        assert self._code_from(from_second) == "bbb"

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
        """A single delta contains four complete key-value pairs."""
        mock_get_tool.return_value = _mock_tool_class()
        payload = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
        state: dict[int, dict[str, Any]] = {0: self._python_call("tc_1")}
        state[0]["arguments"] = payload
        parser_state: dict[int, Parser] = {}

        received = _collect(
            state, _make_tool_call_delta(arguments=payload), parsers=parser_state
        )

        # Collect all argument deltas across packets
        merged: dict[str, str] = {}
        for packet in received:
            assert isinstance(packet.obj, ToolCallArgumentDelta)
            for key, chunk in packet.obj.argument_deltas.items():
                merged[key] = merged.get(key, "") + chunk

        expected = {"a": "one", "b": "two", "c": "three", "d": "four"}
        assert merged == expected

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_delta_on_second_arg_after_first_complete(
        self, mock_get_tool: MagicMock
    ) -> None:
        """First argument is fully complete; delta only adds to the second."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {0: self._python_call("tc_1")}
        pieces = [
            '{"code": "print(1)", "lang": "py',
            '"}',
        ]
        combined = "".join(_stream_fragments(pieces, state))
        assert "print(1)" in combined
        assert "py" in combined

    @patch("onyx.chat.tool_call_args_streaming._get_tool_class")
    def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
        """Non-string values (numbers, booleans, null) are skipped — they are
        available in the final tool-call kickoff packet. String arguments
        following them are still emitted."""
        mock_get_tool.return_value = _mock_tool_class()
        state: dict[int, dict[str, Any]] = {0: self._python_call("tc_1")}
        pieces = ['{"timeout": 30, "code": "hello"}']
        combined = "".join(_stream_fragments(pieces, state))
        assert combined == "hello"

View File

@@ -9,6 +9,8 @@ from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.sensitive import make_mock_sensitive_value
pytestmark = pytest.mark.usefixtures("enable_ee")
@pytest.fixture
def mock_jira_cc_pair(

View File

@@ -0,0 +1,188 @@
from io import BytesIO
from unittest.mock import MagicMock
import pytest
from fastapi import UploadFile
from onyx.server.features.projects import projects_file_utils as utils
class _Tokenizer:
def encode(self, text: str) -> list[int]:
return [1] * len(text)
class _NonSeekableFile(BytesIO):
def tell(self) -> int:
raise OSError("tell not supported")
def seek(self, *_args: object, **_kwargs: object) -> int:
raise OSError("seek not supported")
def _make_upload(filename: str, size: int, content: bytes | None = None) -> UploadFile:
    """Build an ``UploadFile`` that declares ``size`` bytes.

    When ``content`` is omitted the stream is filled with ``size`` bytes of
    ``b"x"``; otherwise ``content`` is used as-is and the declared ``size``
    is still whatever the caller passed.
    """
    if content is None:
        content = b"x" * size
    stream = BytesIO(content)
    return UploadFile(filename=filename, file=stream, size=size)
def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
    """Build an ``UploadFile`` over ``content`` whose ``size`` is ``None``."""
    stream = BytesIO(content)
    return UploadFile(filename=filename, file=stream, size=None)
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
    """Stub the ``utils`` module lookups shared by the categorization tests.

    Replaces the default-model lookup (returns ``None``), the tokenizer
    factory (returns a ``_Tokenizer``) and the password-protection check
    (always ``False``).
    """
    stubs = (
        ("fetch_default_llm_model", lambda _db: None),
        ("get_tokenizer", lambda **_kwargs: _Tokenizer()),
        ("is_file_password_protected", lambda **_kwargs: False),
    )
    for attribute, replacement in stubs:
        monkeypatch.setattr(utils, attribute, replacement)
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
    """With ``size=None`` the stream length is reported and the position is kept."""
    stream = BytesIO(b"abcdef")
    upload = UploadFile(filename="example.txt", file=stream, size=None)
    upload.file.seek(2)

    measured = utils.get_upload_size_bytes(upload)

    assert measured == 6
    # The cursor must be back where it was before the measurement.
    assert upload.file.tell() == 2
def test_get_upload_size_bytes_logs_warning_when_stream_size_unavailable(
    caplog: pytest.LogCaptureFixture,
) -> None:
    """A stream that cannot seek/tell yields ``None`` and a warning naming the file."""
    stream = _NonSeekableFile()
    upload = UploadFile(filename="non_seekable.txt", file=stream, size=None)
    caplog.set_level("WARNING")

    measured = utils.get_upload_size_bytes(upload)

    assert measured is None
    captured = caplog.text
    assert "Could not determine upload size via stream seek" in captured
    assert "non_seekable.txt" in captured
def test_is_upload_too_large_logs_warning_when_size_unknown(
    monkeypatch: pytest.MonkeyPatch,
    caplog: pytest.LogCaptureFixture,
) -> None:
    """An unknown size is treated as not-too-large, with a warning naming the file."""
    upload = _make_upload("size_unknown.txt", size=1)
    # Force the "size could not be determined" path.
    monkeypatch.setattr(utils, "get_upload_size_bytes", lambda _upload: None)
    caplog.set_level("WARNING")

    verdict = utils.is_upload_too_large(upload, max_bytes=100)

    assert verdict is False
    captured = caplog.text
    assert "Could not determine upload size; skipping size-limit check" in captured
    assert "size_unknown.txt" in captured
def test_categorize_uploaded_files_accepts_size_under_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A 99-byte image is accepted when the limit is patched to 100 bytes."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("estimate_image_tokens_for_upload", lambda _upload: 10),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    outcome = utils.categorize_uploaded_files(
        [_make_upload("small.png", size=99)], MagicMock()
    )

    assert len(outcome.acceptable) == 1
    assert len(outcome.rejected) == 0
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An upload with ``size=None`` and 99 bytes of content is still accepted."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("estimate_image_tokens_for_upload", lambda _upload: 10),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    sizeless = _make_upload_no_size("small.png", content=b"x" * 99)
    outcome = utils.categorize_uploaded_files([sizeless], MagicMock())

    assert len(outcome.acceptable) == 1
    assert len(outcome.rejected) == 0
def test_categorize_uploaded_files_accepts_size_at_limit(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A file exactly at the patched 100-byte limit is accepted."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("estimate_image_tokens_for_upload", lambda _upload: 10),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    boundary_file = _make_upload("edge.png", size=100)
    outcome = utils.categorize_uploaded_files([boundary_file], MagicMock())

    assert len(outcome.acceptable) == 1
    assert len(outcome.rejected) == 0
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """A file one byte over the limit is rejected with the size-limit reason."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("estimate_image_tokens_for_upload", lambda _upload: 10),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    oversized = _make_upload("large.png", size=101)
    outcome = utils.categorize_uploaded_files([oversized], MagicMock())

    assert len(outcome.acceptable) == 0
    assert len(outcome.rejected) == 1
    assert outcome.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """In one batch, the small file is kept and only the oversized one is rejected."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("estimate_image_tokens_for_upload", lambda _upload: 10),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    batch = [
        _make_upload("small.png", size=50),
        _make_upload("large.png", size=101),
    ]
    outcome = utils.categorize_uploaded_files(batch, MagicMock())

    accepted_names = [accepted.filename for accepted in outcome.acceptable]
    assert accepted_names == ["small.png"]
    assert len(outcome.rejected) == 1
    only_rejection = outcome.rejected[0]
    assert only_rejection.filename == "large.png"
    assert only_rejection.reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With ``SKIP_USERFILE_THRESHOLD`` set, an oversized file is still rejected."""
    _patch_common_dependencies(monkeypatch)
    overrides = (
        ("SKIP_USERFILE_THRESHOLD", True),
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
    )
    for attribute, replacement in overrides:
        monkeypatch.setattr(utils, attribute, replacement)

    oversized = _make_upload("oversized.pdf", size=101)
    outcome = utils.categorize_uploaded_files([oversized], MagicMock())

    assert len(outcome.acceptable) == 0
    assert len(outcome.rejected) == 1
    assert outcome.rejected[0].reason == "Exceeds 1 MB file size limit"
def test_categorize_uploaded_files_checks_size_before_text_extraction(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """An oversized upload is rejected without ``extract_file_text`` being called."""
    _patch_common_dependencies(monkeypatch)

    # Spy that records whether text extraction was attempted at all.
    extraction_spy = MagicMock(return_value="this should not run")
    overrides = (
        ("USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100),
        ("USER_FILE_MAX_UPLOAD_SIZE_MB", 1),
        ("extract_file_text", extraction_spy),
    )
    for attr_name, replacement in overrides:
        monkeypatch.setattr(utils, attr_name, replacement)

    oversized = _make_upload("oversized.pdf", size=101)
    outcome = utils.categorize_uploaded_files([oversized], MagicMock())

    extraction_spy.assert_not_called()
    assert len(outcome.acceptable) == 0
    assert len(outcome.rejected) == 1
    first_rejection = outcome.rejected[0]
    assert first_rejection.reason == "Exceeds 1 MB file size limit"

View File

@@ -0,0 +1,32 @@
import pytest
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.server.settings import store as settings_store
class _FakeKvStore:
    """Key-value store stand-in whose ``load`` raises ``KvKeyNotFoundError`` for every key."""

    def load(self, _key: str) -> dict:
        """Always signal a missing key; the argument is ignored."""
        missing_key_error = KvKeyNotFoundError()
        raise missing_key_error
class _FakeCache:
def __init__(self) -> None:
self._vals: dict[str, bytes] = {}
def get(self, key: str) -> bytes | None:
return self._vals.get(key)
def set(self, key: str, value: str, ex: int | None = None) -> None: # noqa: ARG002
self._vals[key] = value.encode("utf-8")
def test_load_settings_includes_user_file_max_upload_size_mb(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """``load_settings`` surfaces the patched ``USER_FILE_MAX_UPLOAD_SIZE_MB`` value."""
    replacements = {
        "get_kv_store": lambda: _FakeKvStore(),
        "get_cache_backend": lambda: _FakeCache(),
        "USER_FILE_MAX_UPLOAD_SIZE_MB": 77,
    }
    for attr_name, replacement in replacements.items():
        monkeypatch.setattr(settings_store, attr_name, replacement)

    loaded = settings_store.load_settings()

    assert loaded.user_file_max_upload_size_mb == 77

View File

@@ -0,0 +1,394 @@
"""Tests for the jsonriver incremental JSON parser."""
import json
import pytest
from onyx.utils.jsonriver import JsonValue
from onyx.utils.jsonriver import Parser
def _all_deltas(chunks: list[str]) -> list[JsonValue]:
    """Run a fresh ``Parser`` over ``chunks`` in order and return every delta.

    Deltas from each ``feed`` call come first, followed by whatever
    ``finish`` emits.
    """
    parser = Parser()
    collected: list[JsonValue] = [
        delta for piece in chunks for delta in parser.feed(piece)
    ]
    collected += parser.finish()
    return collected
class TestParseComplete:
    """Parsing complete JSON in a single chunk."""

    def test_simple_object(self) -> None:
        deltas = _all_deltas(['{"a": 1}'])
        # 1 == 1.0 in Python, so a single comparison covers int and float results.
        assert any(r == {"a": 1} for r in deltas)

    def test_simple_array(self) -> None:
        deltas = _all_deltas(["[1, 2, 3]"])
        assert any(isinstance(r, list) for r in deltas)

    def test_simple_string(self) -> None:
        deltas = _all_deltas(['"hello"'])
        assert "hello" in deltas or any("hello" in str(r) for r in deltas)

    def test_null(self) -> None:
        deltas = _all_deltas(["null"])
        # Identity check, consistent with the boolean tests below.
        assert any(r is None for r in deltas)

    def test_boolean_true(self) -> None:
        deltas = _all_deltas(["true"])
        # `True in deltas` compares with ==, so 1 or 1.0 would also satisfy it.
        # Check identity instead, mirroring test_boolean_false.
        assert any(r is True for r in deltas)

    def test_boolean_false(self) -> None:
        deltas = _all_deltas(["false"])
        assert any(r is False for r in deltas)

    def test_number(self) -> None:
        deltas = _all_deltas(["42"])
        assert 42.0 in deltas

    def test_negative_number(self) -> None:
        deltas = _all_deltas(["-3.14"])
        assert any(abs(r - (-3.14)) < 1e-10 for r in deltas if isinstance(r, float))

    def test_empty_object(self) -> None:
        deltas = _all_deltas(["{}"])
        assert {} in deltas

    def test_empty_array(self) -> None:
        deltas = _all_deltas(["[]"])
        assert [] in deltas
class TestStreamingDeltas:
    """Feeding JSON in pieces yields deltas that reassemble to the full values."""

    def test_object_string_value_streamed_char_by_char(self) -> None:
        one_char_chunks = list('{"code": "abc"}')
        deltas = _all_deltas(one_char_chunks)
        code_fragments = [
            d["code"]
            for d in deltas
            if isinstance(d, dict) and "code" in d and isinstance(d["code"], str)
        ]
        assert "".join(code_fragments) == "abc"

    def test_object_streamed_in_two_halves(self) -> None:
        deltas = _all_deltas(['{"name": "Al', 'ice"}'])
        name_fragments = [
            d["name"]
            for d in deltas
            if isinstance(d, dict) and "name" in d and isinstance(d["name"], str)
        ]
        assert "".join(name_fragments) == "Alice"

    def test_multiple_keys_streamed(self) -> None:
        deltas = _all_deltas(['{"a": "x', '", "b": "y"}'])
        dict_deltas = [d for d in deltas if isinstance(d, dict)]
        a_parts: list[str] = [
            d["a"] for d in dict_deltas if "a" in d and isinstance(d["a"], str)
        ]
        b_parts: list[str] = [
            d["b"] for d in dict_deltas if "b" in d and isinstance(d["b"], str)
        ]
        assert "".join(a_parts) == "x"
        assert "".join(b_parts) == "y"

    def test_deltas_only_contain_new_string_content(self) -> None:
        parser = Parser()
        first_batch = parser.feed('{"msg": "hel')
        second_batch = parser.feed('lo"}')
        parser.finish()

        msg_parts = [
            d["msg"]
            for d in first_batch + second_batch
            if isinstance(d, dict) and "msg" in d and isinstance(d["msg"], str)
        ]
        assert "".join(msg_parts) == "hello"

        # Each delta should only contain new chars, not repeat previous ones
        if len(msg_parts) == 2:
            assert msg_parts[0] == "hel"
            assert msg_parts[1] == "lo"
class TestEscapeSequences:
    """JSON escape sequences are decoded correctly, even across chunk boundaries."""

    @staticmethod
    def _joined_string_field(chunks: list[str], key: str) -> str:
        """Concatenate every string delta emitted under ``key`` for ``chunks``."""
        return "".join(
            d[key]
            for d in _all_deltas(chunks)
            if isinstance(d, dict) and key in d and isinstance(d[key], str)
        )

    def test_newline_escape(self) -> None:
        decoded = self._joined_string_field(['{"text": "line1\\nline2"}'], "text")
        assert decoded == "line1\nline2"

    def test_tab_escape(self) -> None:
        decoded = self._joined_string_field(['{"t": "a\\tb"}'], "t")
        assert decoded == "a\tb"

    def test_escaped_quote(self) -> None:
        decoded = self._joined_string_field(['{"q": "say \\"hi\\""}'], "q")
        assert decoded == 'say "hi"'

    def test_unicode_escape(self) -> None:
        decoded = self._joined_string_field(['{"u": "\\u0041\\u0042"}'], "u")
        assert decoded == "AB"

    def test_escape_split_across_chunks(self) -> None:
        # The backslash ends the first chunk; its escape letter starts the second.
        decoded = self._joined_string_field(['{"x": "a\\', 'nb"}'], "x")
        assert decoded == "a\nb"

    def test_unicode_escape_split_across_chunks(self) -> None:
        # The \uXXXX escape is cut in the middle of its four hex digits.
        decoded = self._joined_string_field(['{"u": "\\u00', '41"}'], "u")
        assert decoded == "A"

    def test_backslash_escape(self) -> None:
        decoded = self._joined_string_field(['{"p": "c:\\\\dir"}'], "p")
        assert decoded == "c:\\dir"
class TestNestedStructures:
    """Deltas for nested objects, arrays and mixed-type objects."""

    def test_nested_object(self) -> None:
        deltas = _all_deltas(['{"outer": {"inner": "val"}}'])
        saw_inner = any(
            isinstance(d, dict)
            and "outer" in d
            and isinstance(d["outer"], dict)
            and "inner" in d["outer"]
            for d in deltas
        )
        assert saw_inner

    def test_array_of_strings(self) -> None:
        deltas = _all_deltas(['["a', '", "b"]'])
        fragments: list[str] = []
        for d in deltas:
            if isinstance(d, str):
                fragments.append(d)
            elif isinstance(d, list):
                fragments.extend(item for item in d if isinstance(item, str))
        combined = "".join(fragments)
        assert "a" in combined
        assert "b" in combined

    def test_object_with_number_and_bool(self) -> None:
        deltas = _all_deltas(['{"count": 42, "active": true}'])
        dict_deltas = [d for d in deltas if isinstance(d, dict)]
        has_count = any("count" in d and d["count"] == 42.0 for d in dict_deltas)
        has_active = any("active" in d and d["active"] is True for d in dict_deltas)
        assert has_count
        assert has_active

    def test_object_with_null_value(self) -> None:
        deltas = _all_deltas(['{"key": null}'])
        saw_null = any(
            isinstance(d, dict) and "key" in d and d["key"] is None for d in deltas
        )
        assert saw_null
class TestComputeDelta:
    """Direct tests for the ``Parser._compute_delta`` static method."""

    def test_none_prev_returns_current(self) -> None:
        current = {"a": "b"}
        delta = Parser._compute_delta(None, current)
        assert delta == {"a": "b"}

    def test_string_delta(self) -> None:
        delta = Parser._compute_delta("hel", "hello")
        assert delta == "lo"

    def test_string_no_change(self) -> None:
        delta = Parser._compute_delta("same", "same")
        assert delta is None

    def test_dict_new_key(self) -> None:
        delta = Parser._compute_delta({"a": "x"}, {"a": "x", "b": "y"})
        assert delta == {"b": "y"}

    def test_dict_string_append(self) -> None:
        delta = Parser._compute_delta({"code": "def"}, {"code": "def hello()"})
        assert delta == {"code": " hello()"}

    def test_dict_no_change(self) -> None:
        delta = Parser._compute_delta({"a": 1}, {"a": 1})
        assert delta is None

    def test_list_new_items(self) -> None:
        delta = Parser._compute_delta([1, 2], [1, 2, 3])
        assert delta == [3]

    def test_list_last_item_updated(self) -> None:
        delta = Parser._compute_delta(["a"], ["ab"])
        assert delta == ["ab"]

    def test_list_no_change(self) -> None:
        delta = Parser._compute_delta([1, 2], [1, 2])
        assert delta is None

    def test_primitive_change(self) -> None:
        delta = Parser._compute_delta(1, 2)
        assert delta == 2

    def test_primitive_no_change(self) -> None:
        delta = Parser._compute_delta(42, 42)
        assert delta is None
class TestParserLifecycle:
    """Parser behaviour around empty input, ``finish`` and trailing data."""

    def test_feed_after_finish_returns_empty(self) -> None:
        parser = Parser()
        parser.feed('{"a": 1}')
        parser.finish()
        late_deltas = parser.feed("more")
        assert late_deltas == []

    def test_empty_feed_returns_empty(self) -> None:
        emitted = Parser().feed("")
        assert emitted == []

    def test_whitespace_only_returns_empty(self) -> None:
        emitted = Parser().feed(" ")
        assert emitted == []

    def test_finish_with_trailing_whitespace(self) -> None:
        parser = Parser()
        # Trailing whitespace terminates the number, so feed() emits it
        emitted = parser.feed("42 ")
        assert 42.0 in emitted
        parser.finish()  # Should not raise

    def test_finish_with_trailing_content_raises(self) -> None:
        parser = Parser()
        # Feed a complete JSON value followed by non-whitespace in one chunk
        parser.feed('{"a": 1} extra')
        with pytest.raises(ValueError, match="Unexpected trailing"):
            parser.finish()

    def test_finish_flushes_pending_number(self) -> None:
        parser = Parser()
        before_finish = parser.feed("42")
        # Number has no terminator, so feed() can't emit it yet
        assert before_finish == []
        after_finish = parser.finish()
        assert 42.0 in after_finish
class TestToolCallSimulation:
    """Simulate the LLM tool-call streaming use case."""

    @staticmethod
    def _stream(chunks: "list[str] | str") -> list[JsonValue]:
        """Feed ``chunks`` one by one into a fresh Parser, then finish; return all deltas."""
        parser = Parser()
        emitted: list[JsonValue] = []
        for chunk in chunks:
            emitted.extend(parser.feed(chunk))
        emitted.extend(parser.finish())
        return emitted

    @staticmethod
    def _string_parts(deltas: list[JsonValue], key: str) -> list[str]:
        """Collect, in order, every string delta found under ``key`` in dict deltas."""
        return [
            d[key]
            for d in deltas
            if isinstance(d, dict) and key in d and isinstance(d[key], str)
        ]

    def test_python_tool_call_streaming(self) -> None:
        payload = json.dumps({"code": "print('hello world')"})
        step = 5
        chunks = [payload[start : start + step] for start in range(0, len(payload), step)]

        code_parts = self._string_parts(self._stream(chunks), "code")

        assert "".join(code_parts) == "print('hello world')"

    def test_multi_arg_tool_call(self) -> None:
        full = '{"query": "search term", "num_results": 5}'
        chunks = [full[:15], full[15:30], full[30:]]

        deltas = self._stream(chunks)
        query_parts = self._string_parts(deltas, "query")
        has_num_results = any(
            isinstance(d, dict) and "num_results" in d for d in deltas
        )

        assert "".join(query_parts) == "search term"
        assert has_num_results

    def test_code_with_newlines_and_escapes(self) -> None:
        code = 'def greet(name):\n print(f"Hello, {name}!")\n return True'
        payload = json.dumps({"code": code})
        step = 8
        chunks = [payload[start : start + step] for start in range(0, len(payload), step)]

        code_parts = self._string_parts(self._stream(chunks), "code")

        assert "".join(code_parts) == code

    def test_single_char_streaming(self) -> None:
        full = '{"key": "value"}'

        # Iterating the string feeds the parser one character per call.
        key_parts = self._string_parts(self._stream(full), "key")

        assert "".join(key_parts) == "value"

View File

@@ -147,15 +147,18 @@ class TestSensitiveValueString:
)
assert sensitive1 != sensitive2
def test_equality_with_non_sensitive_raises(self) -> None:
"""Test that comparing with non-SensitiveValue raises error."""
def test_equality_with_non_sensitive_returns_not_equal(self) -> None:
"""Test that comparing with non-SensitiveValue is always not-equal.
Returns NotImplemented so Python falls back to identity comparison.
This is required for compatibility with SQLAlchemy's attribute tracking.
"""
sensitive = SensitiveValue(
encrypted_bytes=_encrypt_string("secret"),
decrypt_fn=_decrypt_string,
is_json=False,
)
with pytest.raises(SensitiveAccessError):
_ = sensitive == "secret"
assert not (sensitive == "secret")
class TestSensitiveValueJson:

View File

@@ -20,8 +20,6 @@ from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_PROJECT
from shared_configs.configs import MULTI_TENANT
# Import the function under test
class TestBuildVespaFilters:
def test_empty_filters(self) -> None:
@@ -179,11 +177,27 @@ class TestBuildVespaFilters:
assert f"!({HIDDEN}=true) and " == result
def test_user_project_filter(self) -> None:
"""Test user project filtering (replacement for user folder IDs)."""
# Single project id
"""Test user project filtering.
project_id alone does NOT trigger a knowledge scope restriction
(an agent with no explicit knowledge should search everything).
It only participates when explicit knowledge filters are present.
"""
# project_id alone → no restriction
filters = IndexFilters(access_control_list=[], project_id=789)
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({USER_PROJECT} contains "789") and ' == result
assert f"!({HIDDEN}=true) and " == result
# project_id with user_file_ids → both OR'd
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=[], project_id=789, user_file_ids=[id1]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and (({DOCUMENT_ID} contains "{str(id1)}") or ({USER_PROJECT} contains "789")) and '
== result
)
# No project id
filters = IndexFilters(access_control_list=[], project_id=None)
@@ -217,7 +231,11 @@ class TestBuildVespaFilters:
)
def test_combined_filters(self) -> None:
"""Test combining multiple filter types."""
"""Test combining multiple filter types.
Knowledge-scope filters (document_set, user_file_ids, project_id,
persona_id) are OR'd together, while all other filters are AND'd.
"""
id1 = UUID("00000000-0000-0000-0000-000000000123")
filters = IndexFilters(
access_control_list=["user1", "group1"],
@@ -231,7 +249,6 @@ class TestBuildVespaFilters:
result = build_vespa_filters(filters)
# Build expected result piece by piece for readability
expected = f"!({HIDDEN}=true) and "
expected += (
'(access_control_list contains "user1" or '
@@ -239,9 +256,13 @@ class TestBuildVespaFilters:
)
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
expected += f'({DOCUMENT_SETS} contains "set1") and '
expected += f'({DOCUMENT_ID} contains "{str(id1)}") and '
expected += f'({USER_PROJECT} contains "789") and '
# Knowledge scope filters are OR'd together
expected += (
f'(({DOCUMENT_SETS} contains "set1")'
f' or ({DOCUMENT_ID} contains "{str(id1)}")'
f' or ({USER_PROJECT} contains "789")'
f") and "
)
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
@@ -251,6 +272,32 @@ class TestBuildVespaFilters:
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
assert expected[:-5] == result_no_trailing # Remove trailing " and "
def test_knowledge_scope_single_filter_not_wrapped(self) -> None:
    """A lone knowledge-scope filter is emitted as-is, without an
    enclosing OR group."""
    only_document_set = IndexFilters(access_control_list=[], document_set=["set1"])
    expected = f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and '
    actual = build_vespa_filters(only_document_set)
    assert expected == actual
def test_knowledge_scope_document_set_and_user_files_ored(self) -> None:
    """A document set filter combined with user file IDs must produce a
    single OR group, so documents matching either condition are found."""
    file_id = UUID("00000000-0000-0000-0000-000000000123")
    scoped_filters = IndexFilters(
        access_control_list=[],
        document_set=["engineering"],
        user_file_ids=[file_id],
    )

    expected_parts = (
        f"!({HIDDEN}=true) and ",
        f'(({DOCUMENT_SETS} contains "engineering")',
        f' or ({DOCUMENT_ID} contains "{str(file_id)}")',
        f") and ",
    )
    expected = "".join(expected_parts)

    actual = build_vespa_filters(scoped_filters)
    assert expected == actual
def test_empty_or_none_values(self) -> None:
"""Test with empty or None values in filter lists."""
# Empty strings in document set

22
cli/Dockerfile Normal file
View File

@@ -0,0 +1,22 @@
# Two-stage build for the onyx-cli image: compile in a Go builder stage,
# then copy only the binary and its runtime files into an empty final image.

# Builder stage; the base image is pinned by digest.
FROM golang:1.26-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder
WORKDIR /app
COPY ./ .
# TARGETARCH is expected to be supplied by the builder (BuildKit sets it per target platform).
ARG TARGETARCH
# CGO is disabled so the binary can run in the `scratch` stage, which has no libc.
# -ldflags="-s -w" omits the symbol table and debug information.
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o onyx-cli .
# Create the home/config directory here; `scratch` has no shell to run mkdir.
RUN mkdir -p /home/onyx/.config
# Final stage starts from an empty image.
FROM scratch
# CA bundle copied from the builder (presumably for outbound TLS from the CLI).
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
# Home directory owned by the runtime uid/gid set by USER below.
COPY --from=builder --chown=65534:65534 /home/onyx /home/onyx
COPY --from=builder /app/onyx-cli /onyx-cli
ENV HOME=/home/onyx
ENV XDG_CONFIG_HOME=/home/onyx/.config
# Run as an unprivileged numeric uid:gid (65534 is conventionally "nobody").
USER 65534:65534
ENTRYPOINT ["/onyx-cli"]

View File

@@ -1,5 +1,8 @@
# Onyx CLI
[![Release CLI](https://github.com/onyx-dot-app/onyx/actions/workflows/release-cli.yml/badge.svg)](https://github.com/onyx-dot-app/onyx/actions/workflows/release-cli.yml)
[![PyPI](https://img.shields.io/pypi/v/onyx-cli.svg)](https://pypi.org/project/onyx-cli/)
A terminal interface for chatting with your [Onyx](https://github.com/onyx-dot-app/onyx) agent. Built with Go using [Bubble Tea](https://github.com/charmbracelet/bubbletea) for the TUI framework.
## Installation
@@ -28,7 +31,7 @@ Environment variables override config file values:
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Server base URL (default: `http://localhost:3000`) |
| `ONYX_SERVER_URL` | No | Server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

View File

@@ -1261,3 +1261,5 @@ configMap:
SKIP_USERFILE_THRESHOLD: ""
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
# Maximum user upload file size in MB for chat/projects uploads
USER_FILE_MAX_UPLOAD_SIZE_MB: ""

View File

@@ -18,6 +18,10 @@ variable "INTEGRATION_REPOSITORY" {
default = "onyxdotapp/onyx-integration"
}
# Image repository used by the "cli" target for its tags and registry cache reference.
variable "CLI_REPOSITORY" {
default = "onyxdotapp/onyx-cli"
}
variable "TAG" {
default = "latest"
}
@@ -64,3 +68,13 @@ target "integration" {
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
}
# Builds the CLI image from the cli/ directory using the Dockerfile found there.
target "cli" {
context = "cli"
dockerfile = "Dockerfile"
# Reuse build cache from the published :latest image; export cache inline with the image.
cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"]
cache-to = ["type=inline"]
tags = ["${CLI_REPOSITORY}:${TAG}"]
}

View File

@@ -26,20 +26,16 @@ class CustomBuildHook(BuildHookInterface):
# Get config and environment
binary_name = self.config["binary_name"]
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{binary_name}/")
tag_prefix = self.config.get("tag_prefix", binary_name)
tag = os.getenv("GITHUB_REF_NAME", "dev").removeprefix(f"{tag_prefix}/")
commit = os.getenv("GITHUB_SHA", "none")
# Build the Go binary if it doesn't exist
if not os.path.exists(binary_name):
print(f"Building Go binary '{binary_name}'...")
ldflags = f"-X main.version={tag} -X main.commit={commit} -s -w"
subprocess.check_call( # noqa: S603
[
"go",
"build",
f"-ldflags=-X main.version={tag} -X main.commit={commit} -s -w",
"-o",
binary_name,
],
["go", "build", f"-ldflags={ldflags}", "-o", binary_name],
)
build_data["shared_scripts"] = {binary_name: binary_name}

View File

@@ -3,6 +3,9 @@ from __future__ import annotations
import os
import re
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix("ods/")
# Must match tag_prefix in pyproject.toml [tool.hatch.build.targets.wheel.hooks.custom]
TAG_PREFIX: str = "ods"
_tag = os.environ.get("GITHUB_REF_NAME", "v0.0.0-dev").removeprefix(f"{TAG_PREFIX}/")
_match = re.search(r"v?(\d+\.\d+\.\d+)", _tag)
__version__ = _match.group(1) if _match else "0.0.0"

View File

@@ -14,7 +14,9 @@ keywords = [
classifiers = [
"Programming Language :: Go",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Operating System :: POSIX :: Linux",
"Operating System :: MacOS",
"Operating System :: Microsoft :: Windows",
]
dynamic = ["version"]
dependencies = [
@@ -27,7 +29,7 @@ dependencies = [
Repository = "https://github.com/onyx-dot-app/onyx"
[tool.hatch.build]
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py"]
include = ["go.mod", "go.sum", "main.go", "**/*.go", "**/*.py", "README.md"]
[tool.hatch.version]
source = "code"
@@ -36,6 +38,7 @@ path = "internal/_version.py"
[tool.hatch.build.targets.wheel.hooks.custom]
path = "hatch_build.py"
binary_name = "ods"
tag_prefix = "ods"
[tool.uv]
managed = false

View File

@@ -156,6 +156,7 @@ module.exports = {
"**/src/app/**/*.test.tsx",
"**/src/components/**/*.test.tsx",
"**/src/lib/**/*.test.tsx",
"**/src/providers/**/*.test.tsx",
"**/src/refresh-components/**/*.test.tsx",
"**/src/hooks/**/*.test.tsx",
"**/src/sections/**/*.test.tsx",

View File

@@ -8,7 +8,8 @@ const cspHeader = `
base-uri 'self';
form-action 'self';
${
process.env.NEXT_PUBLIC_CLOUD_ENABLED === "true"
process.env.NEXT_PUBLIC_CLOUD_ENABLED === "true" &&
process.env.NODE_ENV !== "development"
? "upgrade-insecure-requests;"
: ""
}

View File

@@ -6,7 +6,7 @@ import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputComboBox from "@/refresh-components/inputs/InputComboBox";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { FormikField } from "@/refresh-components/form/FormikField";
import { FormField } from "@/refresh-components/form/FormField";
import { USER_ROLE_LABELS, UserRole } from "@/lib/types";
@@ -107,26 +107,25 @@ export default function OnyxApiKeyForm({
<FormField name="role" state={state} className="w-full">
<FormField.Label>Role:</FormField.Label>
<FormField.Control>
<InputComboBox
<InputSelect
value={field.value}
onValueChange={(value) => helper.setValue(value)}
options={[
{
label: USER_ROLE_LABELS[UserRole.LIMITED],
value: UserRole.LIMITED.toString(),
},
{
label: USER_ROLE_LABELS[UserRole.BASIC],
value: UserRole.BASIC.toString(),
},
{
label: USER_ROLE_LABELS[UserRole.ADMIN],
value: UserRole.ADMIN.toString(),
},
]}
placeholder="Select a role"
strict
/>
>
<InputSelect.Trigger placeholder="Select a role" />
<InputSelect.Content>
<InputSelect.Item
value={UserRole.LIMITED.toString()}
>
{USER_ROLE_LABELS[UserRole.LIMITED]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.BASIC.toString()}>
{USER_ROLE_LABELS[UserRole.BASIC]}
</InputSelect.Item>
<InputSelect.Item value={UserRole.ADMIN.toString()}>
{USER_ROLE_LABELS[UserRole.ADMIN]}
</InputSelect.Item>
</InputSelect.Content>
</InputSelect>
</FormField.Control>
<FormField.Description>
Select the role for this API key. Limited has access to

View File

@@ -7,7 +7,7 @@ import { SlackTokensForm } from "./SlackTokensForm";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
export const NewSlackBotForm = () => {
export function NewSlackBotForm() {
const [formValues] = useState({
name: "",
enabled: true,
@@ -19,7 +19,12 @@ export const NewSlackBotForm = () => {
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header icon={SvgSlack} title="New Slack Bot" separator />
<SettingsLayouts.Header
icon={SvgSlack}
title="New Slack Bot"
separator
backButton
/>
<SettingsLayouts.Body>
<CardSection>
<div className="p-4">
@@ -33,4 +38,4 @@ export const NewSlackBotForm = () => {
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
};
}

View File

@@ -1,12 +1,12 @@
"use client";
import { toast } from "@/hooks/useToast";
import { SlackBot, ValidSources } from "@/lib/types";
import { SlackBot } from "@/lib/types";
import { useRouter } from "next/navigation";
import { useState, useEffect, useRef } from "react";
import { updateSlackBotField } from "@/lib/updateSlackBotField";
import { SlackTokensForm } from "./SlackTokensForm";
import { SourceIcon } from "@/components/SourceIcon";
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
import { deleteSlackBot } from "./new/lib";
import GenericConfirmModal from "@/components/modals/GenericConfirmModal";
@@ -90,10 +90,7 @@ export const ExistingSlackBotForm = ({
<div>
<div className="flex items-center justify-between h-14">
<div className="flex items-center gap-2">
<div className="my-auto">
<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />
</div>
<div className="ml-1">
<div>
<EditableStringFieldDisplay
value={formValues.name}
isEditable={true}

View File

@@ -1,100 +1,122 @@
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
"use client";
import { use } from "react";
import { SlackChannelConfigCreationForm } from "@/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSetSummary, SlackChannelConfig } from "@/lib/types";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
import { FetchAgentsResponse, fetchAgentsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { useSlackChannelConfigs } from "@/app/admin/bots/[bot-id]/hooks";
import { useDocumentSets } from "@/app/admin/documents/sets/hooks";
import { useAgents } from "@/hooks/useAgents";
import { useStandardAnswerCategories } from "@/app/ee/admin/standard-answer/hooks";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import type { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
async function EditslackChannelConfigPage(props: {
params: Promise<{ id: number }>;
}) {
const params = await props.params;
const tasks = [
fetchSS("/manage/admin/slack-app/channel"),
fetchSS("/manage/document-set"),
fetchAgentsSS(),
];
function EditSlackChannelConfigContent({ id }: { id: string }) {
const isPaidEnterprise = usePaidEnterpriseFeaturesEnabled();
const [
slackChannelsResponse,
documentSetsResponse,
[assistants, agentsFetchError],
] = (await Promise.all(tasks)) as [Response, Response, FetchAgentsResponse];
const {
data: slackChannelConfigs,
isLoading: isChannelsLoading,
error: channelsError,
} = useSlackChannelConfigs();
const eeStandardAnswerCategoryResponse =
await getStandardAnswerCategoriesIfEE();
const {
data: documentSets,
isLoading: isDocSetsLoading,
error: docSetsError,
} = useDocumentSets();
if (!slackChannelsResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch Slack Channels - ${await slackChannelsResponse.text()}`}
/>
);
}
const allslackChannelConfigs =
(await slackChannelsResponse.json()) as SlackChannelConfig[];
const {
agents,
isLoading: isAgentsLoading,
error: agentsError,
} = useAgents();
const slackChannelConfig = allslackChannelConfigs.find(
(config) => config.id === Number(params.id)
const {
data: standardAnswerCategories,
isLoading: isStdAnswerLoading,
error: stdAnswerError,
} = useStandardAnswerCategories();
const isLoading =
isChannelsLoading ||
isDocSetsLoading ||
isAgentsLoading ||
(isPaidEnterprise && isStdAnswerLoading);
const slackChannelConfig = slackChannelConfigs?.find(
(config) => config.id === Number(id)
);
if (!slackChannelConfig) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Did not find Slack Channel config with ID: ${params.id}`}
/>
);
}
if (!documentSetsResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
/>
);
}
const response = await documentSetsResponse.json();
const documentSets = response as DocumentSetSummary[];
if (agentsFetchError) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch personas - ${agentsFetchError}`}
/>
);
}
const title = slackChannelConfig?.is_default
? "Edit Default Slack Config"
: "Edit Slack Channel Config";
return (
<SettingsLayouts.Root>
<InstantSSRAutoRefresh />
<SettingsLayouts.Header
icon={SvgSlack}
title={
slackChannelConfig.is_default
? "Edit Default Slack Config"
: "Edit Slack Channel Config"
}
title={title}
separator
backButton
/>
<SettingsLayouts.Body>
<SlackChannelConfigCreationForm
slack_bot_id={slackChannelConfig.slack_bot_id}
documentSets={documentSets}
personas={assistants}
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
existingSlackChannelConfig={slackChannelConfig}
/>
{isLoading ? (
<SimpleLoader />
) : channelsError || !slackChannelConfigs ? (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch Slack Channels - ${
channelsError?.message ?? "unknown error"
}`}
/>
) : !slackChannelConfig ? (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Did not find Slack Channel config with ID: ${id}`}
/>
) : docSetsError || !documentSets ? (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch document sets - ${
docSetsError?.message ?? "unknown error"
}`}
/>
) : agentsError ? (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch agents - ${
agentsError?.message ?? "unknown error"
}`}
/>
) : (
<SlackChannelConfigCreationForm
slack_bot_id={slackChannelConfig.slack_bot_id}
documentSets={documentSets}
personas={agents}
standardAnswerCategoryResponse={
isPaidEnterprise
? {
paidEnterpriseFeaturesEnabled: true,
categories: standardAnswerCategories ?? [],
...(stdAnswerError
? { error: { message: String(stdAnswerError) } }
: {}),
}
: { paidEnterpriseFeaturesEnabled: false }
}
existingSlackChannelConfig={slackChannelConfig}
/>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
export default EditslackChannelConfigPage;
export default function Page(props: { params: Promise<{ id: string }> }) {
const params = use(props.params);
return <EditSlackChannelConfigContent id={params.id} />;
}

View File

@@ -1,53 +1,109 @@
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
import { fetchSS } from "@/lib/utilsSS";
"use client";
import { use, useEffect } from "react";
import { SlackChannelConfigCreationForm } from "@/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm";
import { ErrorCallout } from "@/components/ErrorCallout";
import { DocumentSetSummary } from "@/lib/types";
import { fetchAgentsSS } from "@/lib/agentsSS";
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { redirect } from "next/navigation";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
import { useDocumentSets } from "@/app/admin/documents/sets/hooks";
import { useAgents } from "@/hooks/useAgents";
import { useStandardAnswerCategories } from "@/app/ee/admin/standard-answer/hooks";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import type { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
import { useRouter } from "next/navigation";
/**
 * Client-side body of the "new channel config" page.
 *
 * Reads document sets, agents and standard-answer categories through their
 * hooks. Renders a loader while required data is pending, an error callout
 * when document sets or agents could not be loaded, and otherwise the
 * creation form for the Slack bot identified by `slackBotId`.
 */
function NewChannelConfigContent({ slackBotId }: { slackBotId: number }) {
  const eeEnabled = usePaidEnterpriseFeaturesEnabled();
  const {
    data: documentSets,
    isLoading: docSetsPending,
    error: docSetsFailure,
  } = useDocumentSets();
  const {
    agents,
    isLoading: agentsPending,
    error: agentsFailure,
  } = useAgents();
  const {
    data: categories,
    isLoading: categoriesPending,
    error: categoriesFailure,
  } = useStandardAnswerCategories();

  // Standard-answer categories only hold up rendering when the paid
  // enterprise features are enabled.
  const categoriesBlocking = eeEnabled && categoriesPending;
  if (docSetsPending || agentsPending || categoriesBlocking) {
    return <SimpleLoader />;
  }

  if (docSetsFailure || !documentSets) {
    const reason = docSetsFailure?.message ?? "unknown error";
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch document sets - ${reason}`}
      />
    );
  }

  if (agentsFailure) {
    const reason = agentsFailure?.message ?? "unknown error";
    return (
      <ErrorCallout
        errorTitle="Something went wrong :("
        errorMsg={`Failed to fetch agents - ${reason}`}
      />
    );
  }

  let standardAnswerCategoryResponse: StandardAnswerCategoryResponse;
  if (eeEnabled) {
    // A failed category fetch is forwarded to the form as an `error` field
    // instead of replacing the whole page with an error callout.
    const errorField = categoriesFailure
      ? { error: { message: String(categoriesFailure) } }
      : {};
    standardAnswerCategoryResponse = {
      paidEnterpriseFeaturesEnabled: true,
      categories: categories ?? [],
      ...errorField,
    };
  } else {
    standardAnswerCategoryResponse = { paidEnterpriseFeaturesEnabled: false };
  }

  return (
    <SlackChannelConfigCreationForm
      slack_bot_id={slackBotId}
      documentSets={documentSets}
      personas={agents}
      standardAnswerCategoryResponse={standardAnswerCategoryResponse}
    />
  );
}
export default function Page(props: { params: Promise<{ "bot-id": string }> }) {
const unwrappedParams = use(props.params);
const router = useRouter();
async function NewChannelConfigPage(props: {
params: Promise<{ "bot-id": string }>;
}) {
const unwrappedParams = await props.params;
const slack_bot_id_raw = unwrappedParams?.["bot-id"] || null;
const slack_bot_id = slack_bot_id_raw
? parseInt(slack_bot_id_raw as string, 10)
: null;
useEffect(() => {
if (!slack_bot_id || isNaN(slack_bot_id)) {
router.replace("/admin/bots");
}
}, [slack_bot_id, router]);
if (!slack_bot_id || isNaN(slack_bot_id)) {
redirect("/admin/bots");
return null;
}
const [documentSetsResponse, agentsResponse, standardAnswerCategoryResponse] =
await Promise.all([
fetchSS("/manage/document-set") as Promise<Response>,
fetchAgentsSS(),
getStandardAnswerCategoriesIfEE(),
]);
if (!documentSetsResponse.ok) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
/>
);
}
const documentSets =
(await documentSetsResponse.json()) as DocumentSetSummary[];
if (agentsResponse[1]) {
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch agents - ${agentsResponse[1]}`}
/>
);
}
return (
<SettingsLayouts.Root>
<SettingsLayouts.Header
@@ -57,15 +113,8 @@ async function NewChannelConfigPage(props: {
backButton
/>
<SettingsLayouts.Body>
<SlackChannelConfigCreationForm
slack_bot_id={slack_bot_id}
documentSets={documentSets}
personas={agentsResponse[0]}
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
/>
<NewChannelConfigContent slackBotId={slack_bot_id} />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
export default NewChannelConfigPage;

View File

@@ -1,82 +1,62 @@
"use client";
import { use } from "react";
import BackButton from "@/refresh-components/buttons/BackButton";
import { ErrorCallout } from "@/components/ErrorCallout";
import { ThreeDotsLoader } from "@/components/Loading";
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import SlackChannelConfigsTable from "./SlackChannelConfigsTable";
import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks";
import { ExistingSlackBotForm } from "../SlackBotUpdateForm";
import Separator from "@/refresh-components/Separator";
function SlackBotEditPage({
params,
}: {
params: Promise<{ "bot-id": string }>;
}) {
// Unwrap the params promise
const unwrappedParams = use(params);
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { SvgSlack } from "@opal/icons";
import { getErrorMsg } from "@/lib/error";
function SlackBotEditContent({ botId }: { botId: string }) {
const {
data: slackBot,
isLoading: isSlackBotLoading,
error: slackBotError,
refreshSlackBot,
} = useSlackBot(Number(unwrappedParams["bot-id"]));
} = useSlackBot(Number(botId));
const {
data: slackChannelConfigs,
isLoading: isSlackChannelConfigsLoading,
error: slackChannelConfigsError,
refreshSlackChannelConfigs,
} = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"]));
} = useSlackChannelConfigsByBot(Number(botId));
if (isSlackBotLoading || isSlackChannelConfigsLoading) {
return (
<div className="flex justify-center items-center h-screen">
<ThreeDotsLoader />
</div>
);
return <SimpleLoader />;
}
if (slackBotError || !slackBot) {
const errorMsg =
slackBotError?.info?.message ||
slackBotError?.info?.detail ||
"An unknown error occurred";
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch Slack Bot ${unwrappedParams["bot-id"]}: ${errorMsg}`}
errorMsg={`Failed to fetch Slack Bot ${botId}: ${getErrorMsg(
slackBotError
)}`}
/>
);
}
if (slackChannelConfigsError || !slackChannelConfigs) {
const errorMsg =
slackChannelConfigsError?.info?.message ||
slackChannelConfigsError?.info?.detail ||
"An unknown error occurred";
return (
<ErrorCallout
errorTitle="Something went wrong :("
errorMsg={`Failed to fetch Slack Bot ${unwrappedParams["bot-id"]}: ${errorMsg}`}
errorMsg={`Failed to fetch Slack Bot ${botId}: ${getErrorMsg(
slackChannelConfigsError
)}`}
/>
);
}
return (
<>
<InstantSSRAutoRefresh />
<BackButton routerOverride="/admin/bots" />
<ExistingSlackBotForm
existingSlackBot={slackBot}
refreshSlackBot={refreshSlackBot}
/>
<Separator />
<div className="mt-8">
<SlackChannelConfigsTable
@@ -94,9 +74,19 @@ export default function Page({
}: {
params: Promise<{ "bot-id": string }>;
}) {
const unwrappedParams = use(params);
return (
<>
<SlackBotEditPage params={params} />
</>
<SettingsLayouts.Root>
<SettingsLayouts.Header
icon={SvgSlack}
title="Edit Slack Bot"
backButton
separator
/>
<SettingsLayouts.Body>
<SlackBotEditContent botId={unwrappedParams["bot-id"]} />
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}

View File

@@ -1,12 +1,7 @@
import BackButton from "@/refresh-components/buttons/BackButton";
"use client";
import { NewSlackBotForm } from "../SlackBotCreationForm";
export default async function NewSlackBotPage() {
return (
<>
<BackButton routerOverride="/admin/bots" />
<NewSlackBotForm />
</>
);
// Route component for the new-Slack-bot page; it renders only NewSlackBotForm.
export default function Page() {
return <NewSlackBotForm />;
}

View File

@@ -1 +0,0 @@
export { default } from "@/refresh-pages/admin/UsersPage";

View File

@@ -1,11 +1,14 @@
import React, { JSX, memo } from "react";
import {
ChatPacket,
CODE_INTERPRETER_TOOL_TYPES,
ImageGenerationToolPacket,
Packet,
PacketType,
ReasoningPacket,
SearchToolStart,
StopReason,
ToolCallArgumentDelta,
} from "../../services/streamingModels";
import {
FullChatState,
@@ -26,7 +29,6 @@ import { DeepResearchPlanRenderer } from "./timeline/renderers/deepresearch/Deep
import { ResearchAgentRenderer } from "./timeline/renderers/deepresearch/ResearchAgentRenderer";
import { WebSearchToolRenderer } from "./timeline/renderers/search/WebSearchToolRenderer";
import { InternalSearchToolRenderer } from "./timeline/renderers/search/InternalSearchToolRenderer";
import { SearchToolStart } from "../../services/streamingModels";
// Different types of chat packets using discriminated unions
interface GroupedPackets {
@@ -56,7 +58,12 @@ function isImageToolPacket(packet: Packet) {
}
function isPythonToolPacket(packet: Packet) {
return packet.obj.type === PacketType.PYTHON_TOOL_START;
return (
packet.obj.type === PacketType.PYTHON_TOOL_START ||
(packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
(packet.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
);
}
function isCustomToolPacket(packet: Packet) {

View File

@@ -10,6 +10,8 @@ import {
Stop,
ImageGenerationToolDelta,
MessageStart,
ToolCallArgumentDelta,
CODE_INTERPRETER_TOOL_TYPES,
} from "@/app/app/services/streamingModels";
import { CitationMap } from "@/app/app/interfaces";
import { OnyxDocument } from "@/lib/search/interfaces";
@@ -138,6 +140,7 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
PacketType.SEARCH_TOOL_START,
PacketType.IMAGE_GENERATION_TOOL_START,
PacketType.PYTHON_TOOL_START,
PacketType.TOOL_CALL_ARGUMENT_DELTA,
PacketType.CUSTOM_TOOL_START,
PacketType.FILE_READER_START,
PacketType.FETCH_TOOL_START,
@@ -149,9 +152,16 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
]);
function hasContentPackets(packets: Packet[]): boolean {
return packets.some((packet) =>
CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
);
return packets.some((packet) => {
const type = packet.obj.type as PacketType;
if (type === PacketType.TOOL_CALL_ARGUMENT_DELTA) {
return (
(packet.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON
);
}
return CONTENT_PACKET_TYPES_SET.has(type);
});
}
/**

View File

@@ -1,6 +1,13 @@
import { Packet, PacketType } from "@/app/app/services/streamingModels";
import {
CODE_INTERPRETER_TOOL_TYPES,
Packet,
PacketType,
ToolCallArgumentDelta,
} from "@/app/app/services/streamingModels";
// Packet types with renderers supporting collapsed streaming mode
// Packet types with renderers supporting collapsed streaming mode.
// TOOL_CALL_ARGUMENT_DELTA is intentionally excluded here because it requires
// a tool_type check — it's handled separately in stepSupportsCollapsedStreaming.
export const COLLAPSED_STREAMING_PACKET_TYPES = new Set<PacketType>([
PacketType.SEARCH_TOOL_START,
PacketType.FETCH_TOOL_START,
@@ -21,7 +28,13 @@ export const isSearchToolPackets = (packets: Packet[]): boolean =>
// Check if packets belong to a python tool
export const isPythonToolPackets = (packets: Packet[]): boolean =>
packets.some((p) => p.obj.type === PacketType.PYTHON_TOOL_START);
packets.some(
(p) =>
p.obj.type === PacketType.PYTHON_TOOL_START ||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
(p.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
);
// Check if packets belong to reasoning
export const isReasoningPackets = (packets: Packet[]): boolean =>
@@ -29,8 +42,12 @@ export const isReasoningPackets = (packets: Packet[]): boolean =>
// Check if step supports collapsed streaming rendering mode
export const stepSupportsCollapsedStreaming = (packets: Packet[]): boolean =>
packets.some((p) =>
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType)
packets.some(
(p) =>
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType) ||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
(p.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
);
// Check if packets have content worth rendering in collapsed streaming mode.
@@ -67,7 +84,13 @@ export const stepHasCollapsedStreamingContent = (
// Python tool renders code/output from the start packet onward
if (
packetTypes.has(PacketType.PYTHON_TOOL_START) ||
packetTypes.has(PacketType.PYTHON_TOOL_DELTA)
packetTypes.has(PacketType.PYTHON_TOOL_DELTA) ||
packets.some(
(p) =>
p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
(p.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON
)
) {
return true;
}

View File

@@ -4,7 +4,9 @@ import {
PythonToolPacket,
PythonToolStart,
PythonToolDelta,
ToolCallArgumentDelta,
SectionEnd,
CODE_INTERPRETER_TOOL_TYPES,
} from "@/app/app/services/streamingModels";
import {
MessageRenderer,
@@ -39,6 +41,18 @@ function HighlightedPythonCode({ code }: { code: string }) {
// Helper function to construct current Python execution state
function constructCurrentPythonState(packets: PythonToolPacket[]) {
// Accumulate streaming code from argument deltas (arrives before PythonToolStart)
const streamingCode = packets
.filter(
(packet) =>
packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
(packet.obj as ToolCallArgumentDelta).tool_type ===
CODE_INTERPRETER_TOOL_TYPES.PYTHON
)
.map((packet) =>
String((packet.obj as ToolCallArgumentDelta).argument_deltas.code ?? "")
)
.join("");
const pythonStart = packets.find(
(packet) => packet.obj.type === PacketType.PYTHON_TOOL_START
)?.obj as PythonToolStart | null;
@@ -51,7 +65,8 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
packet.obj.type === PacketType.ERROR
)?.obj as SectionEnd | null;
const code = pythonStart?.code || "";
// Use complete code from PythonToolStart if available, else use streamed code.
const code = pythonStart?.code || streamingCode;
const stdout = pythonDeltas
.map((delta) => delta?.stdout || "")
.filter((s) => s)
@@ -61,6 +76,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
.filter((s) => s)
.join("");
const fileIds = pythonDeltas.flatMap((delta) => delta?.file_ids || []);
const isStreaming = !pythonStart && streamingCode.length > 0;
const isExecuting = pythonStart && !pythonEnd;
const isComplete = pythonStart && pythonEnd;
const hasError = stderr.length > 0;
@@ -70,6 +86,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
stdout,
stderr,
fileIds,
isStreaming,
isExecuting,
isComplete,
hasError,
@@ -82,8 +99,16 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
renderType,
children,
}) => {
const { code, stdout, stderr, fileIds, isExecuting, isComplete, hasError } =
constructCurrentPythonState(packets);
const {
code,
stdout,
stderr,
fileIds,
isStreaming,
isExecuting,
isComplete,
hasError,
} = constructCurrentPythonState(packets);
useEffect(() => {
if (isComplete) {
@@ -92,6 +117,9 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
}, [isComplete, onComplete]);
const status = useMemo(() => {
if (isStreaming) {
return "Writing code...";
}
if (isExecuting) {
return "Executing Python code...";
}
@@ -102,13 +130,13 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
return "Python execution completed";
}
return "Python execution";
}, [isComplete, isExecuting, hasError]);
}, [isStreaming, isComplete, isExecuting, hasError]);
// Shared content for all states - used by both FULL and compact modes
const content = (
<div className="flex flex-col mb-1 space-y-2">
{/* Loading indicator when executing */}
{isExecuting && (
{/* Loading indicator when streaming or executing */}
{(isStreaming || isExecuting) && (
<div className="flex items-center gap-2 text-sm text-muted-foreground">
<div className="flex gap-0.5">
<div className="w-1 h-1 bg-current rounded-full animate-pulse"></div>
@@ -121,7 +149,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
style={{ animationDelay: "0.2s" }}
></div>
</div>
<span>Running code...</span>
<span>{isStreaming ? "Writing code..." : "Running code..."}</span>
</div>
)}

View File

@@ -308,15 +308,15 @@ export async function getProjectTokenCount(projectId: number): Promise<number> {
export async function getMaxSelectedDocumentTokens(
personaId: number
): Promise<number> {
): Promise<number | null> {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${personaId}`
);
if (!response.ok) {
return 128_000;
return null;
}
const json = await response.json();
return (json?.max_tokens as number) ?? 128_000;
return (json?.max_tokens as number) ?? null;
}
export async function moveChatSession(

View File

@@ -288,15 +288,15 @@ export async function deleteAllChatSessions() {
export async function getAvailableContextTokens(
chatSessionId: string
): Promise<number> {
): Promise<number | null> {
const response = await fetch(
`/api/chat/available-context-tokens/${chatSessionId}`
);
if (!response.ok) {
return 0;
return null;
}
const data = (await response.json()) as { available_tokens: number };
return data?.available_tokens ?? 0;
return data?.available_tokens ?? null;
}
export function processRawChatHistory(

View File

@@ -16,6 +16,7 @@ export function isToolPacket(
PacketType.SEARCH_TOOL_DOCUMENTS_DELTA,
PacketType.PYTHON_TOOL_START,
PacketType.PYTHON_TOOL_DELTA,
PacketType.TOOL_CALL_ARGUMENT_DELTA,
PacketType.CUSTOM_TOOL_START,
PacketType.CUSTOM_TOOL_DELTA,
PacketType.FILE_READER_START,

View File

@@ -27,6 +27,9 @@ export enum PacketType {
FETCH_TOOL_URLS = "open_url_urls",
FETCH_TOOL_DOCUMENTS = "open_url_documents",
// Tool call argument delta (streams tool args before tool executes)
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta",
// Custom tool packets
CUSTOM_TOOL_START = "custom_tool_start",
CUSTOM_TOOL_DELTA = "custom_tool_delta",
@@ -59,6 +62,10 @@ export enum PacketType {
INTERMEDIATE_REPORT_CITED_DOCS = "intermediate_report_cited_docs",
}
// `tool_type` values for code-interpreter tools. Packet helpers compare
// ToolCallArgumentDelta.tool_type against PYTHON to treat argument deltas as
// Python tool packets.
export const CODE_INTERPRETER_TOOL_TYPES = {
PYTHON: "python",
} as const;
// Basic Message Packets
export interface MessageStart extends BaseObj {
id: string;
@@ -149,6 +156,13 @@ export interface PythonToolDelta extends BaseObj {
file_ids: string[];
}
// Packet object of type "tool_call_argument_delta". Per the PacketType enum
// comment, these stream tool arguments before the tool executes.
export interface ToolCallArgumentDelta extends BaseObj {
type: "tool_call_argument_delta";
// Tool kind; compared against CODE_INTERPRETER_TOOL_TYPES.PYTHON by the
// Python packet helpers.
tool_type: string;
// NOTE(review): presumably identifies the individual tool call — confirm
// against the backend model.
tool_id: string;
// Partial argument values keyed by argument name. The Python renderer reads
// the `code` entry of each packet and concatenates them.
argument_deltas: Record<string, unknown>;
}
export interface FetchToolStart extends BaseObj {
type: "open_url_start";
}
@@ -294,6 +308,7 @@ export type ImageGenerationToolObj =
export type PythonToolObj =
| PythonToolStart
| PythonToolDelta
| ToolCallArgumentDelta
| SectionEnd
| PacketError;
export type FetchToolObj =

View File

@@ -31,7 +31,6 @@ const SETTINGS_LAYOUT_PREFIXES = [
ADMIN_PATHS.LLM_MODELS,
ADMIN_PATHS.AGENTS,
ADMIN_PATHS.USERS,
ADMIN_PATHS.USERS_V2,
ADMIN_PATHS.TOKEN_RATE_LIMITS,
ADMIN_PATHS.SEARCH_SETTINGS,
ADMIN_PATHS.DOCUMENT_PROCESSING,

View File

@@ -1,46 +0,0 @@
"use client";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import type { UserRole } from "@/lib/types";
import type { PaginatedUsersResponse } from "@/refresh-pages/admin/UsersPage/interfaces";
/** Options accepted by `useAdminUsers`. */
interface UseAdminUsersParams {
// Sent as the `page_num` query parameter.
pageIndex: number;
// Sent as the `page_size` query parameter.
pageSize: number;
// Sent as `q` when truthy; omitted otherwise.
searchTerm?: string;
// Each entry is appended as a repeated `roles` query parameter.
roles?: UserRole[];
// `true` / `false` is sent as `is_active`; `undefined` omits the filter.
isActive?: boolean | undefined;
}
/**
 * SWR hook that loads one page of accepted users.
 *
 * Builds the query string for `/api/manage/users/accepted` from the paging,
 * search, role and active-state options and fetches it with
 * `errorHandlingFetcher`.
 *
 * @returns `users` (empty array until data arrives), `totalItems` (0 until
 *   data arrives), the SWR `isLoading` / `error` values, and `refresh`,
 *   which is SWR's `mutate` for this key.
 */
export default function useAdminUsers({
  pageIndex,
  pageSize,
  searchTerm,
  roles,
  isActive,
}: UseAdminUsersParams) {
  const query = new URLSearchParams();
  query.set("page_num", String(pageIndex));
  query.set("page_size", String(pageSize));
  if (searchTerm) {
    query.set("q", searchTerm);
  }
  // `undefined` means "no filter", so only explicit booleans are sent.
  if (isActive === true) {
    query.set("is_active", "true");
  } else if (isActive === false) {
    query.set("is_active", "false");
  }
  (roles ?? []).forEach((role) => {
    query.append("roles", role);
  });

  const {
    data: page,
    isLoading,
    error,
    mutate: refresh,
  } = useSWR<PaginatedUsersResponse>(
    `/api/manage/users/accepted?${query.toString()}`,
    errorHandlingFetcher
  );

  const users = page?.items ?? [];
  const totalItems = page?.total_items ?? 0;
  return { users, totalItems, isLoading, error, refresh };
}

View File

@@ -2,9 +2,12 @@
import {
buildChatUrl,
getAvailableContextTokens,
nameChatSession,
updateLlmOverrideForChatSession,
} from "@/app/app/services/lib";
import { getMaxSelectedDocumentTokens } from "@/app/app/projects/projectsService";
import { DEFAULT_CONTEXT_TOKENS } from "@/lib/constants";
import { StreamStopInfo } from "@/lib/search/interfaces";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import type { Route } from "next";
@@ -194,9 +197,6 @@ export default function useChatController({
const navigatingAway = useRef(false);
// Local state that doesn't need to be in the store
const [_maxTokens, setMaxTokens] = useState<number>(4096);
// Sync store state changes
useEffect(() => {
if (currentSessionId) {
@@ -1067,21 +1067,59 @@ export default function useChatController({
handleSlackChatRedirect();
}, [searchParams, router]);
// fetch # of allowed document tokens for the selected Persona
useEffect(() => {
if (!liveAgent?.id) return; // avoid calling with undefined persona id
// Available context tokens: if a chat session exists, fetch from the session
// API (dynamic per session/model). Otherwise derive from the persona's max
// document tokens. The backend already accounts for system prompt, tools,
// and user-message reservations.
const [availableContextTokens, setAvailableContextTokens] = useState<number>(
DEFAULT_CONTEXT_TOKENS
);
async function fetchMaxTokens() {
const response = await fetch(
`/api/chat/max-selected-document-tokens?persona_id=${liveAgent?.id}`
);
if (response.ok) {
const maxTokens = (await response.json()).max_tokens as number;
setMaxTokens(maxTokens);
useEffect(() => {
if (!llmManager.hasAnyProvider) return;
let cancelled = false;
const setIfActive = (tokens: number) => {
if (!cancelled) setAvailableContextTokens(tokens);
};
// Prefer the Zustand session ID, but fall back to the URL-derived prop
// so we don't incorrectly take the persona path while the store is
// still initialising on navigation to an existing chat.
const sessionId = currentSessionId || existingChatSessionId;
(async () => {
try {
if (sessionId) {
const available = await getAvailableContextTokens(sessionId);
setIfActive(available ?? DEFAULT_CONTEXT_TOKENS);
return;
}
const personaId = liveAgent?.id;
if (personaId == null) {
setIfActive(DEFAULT_CONTEXT_TOKENS);
return;
}
const maxTokens = await getMaxSelectedDocumentTokens(personaId);
setIfActive(maxTokens ?? DEFAULT_CONTEXT_TOKENS);
} catch (e) {
console.error("Failed to fetch available context tokens:", e);
setIfActive(DEFAULT_CONTEXT_TOKENS);
}
}
fetchMaxTokens();
}, [liveAgent]);
})();
return () => {
cancelled = true;
};
}, [
currentSessionId,
existingChatSessionId,
liveAgent?.id,
llmManager.hasAnyProvider,
]);
// check if there's an image file in the message history so that we know
// which LLMs are available to use
@@ -1110,5 +1148,7 @@ export default function useChatController({
onSubmit,
stopGenerating,
handleMessageSpecificFileUpload,
// data
availableContextTokens,
};
}

View File

@@ -1,34 +0,0 @@
"use client";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import type { InvitedUserSnapshot } from "@/lib/types";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
// Subset of the paginated accepted-users response that useUserCounts reads.
interface PaginatedCountResponse {
// Surfaced by useUserCounts as `activeCount`.
total_items: number;
}
/**
 * Returns the number of active, invited and pending users.
 *
 * Each count is `null` while its request has not produced data. The pending
 * request is only issued when `NEXT_PUBLIC_CLOUD_ENABLED` is truthy.
 */
export default function useUserCounts() {
  // page_size=1 keeps the payload minimal; only `total_items` is read.
  const { data: acceptedPage } = useSWR<PaginatedCountResponse>(
    "/api/manage/users/accepted?page_num=0&page_size=1",
    errorHandlingFetcher
  );

  const { data: invited } = useSWR<InvitedUserSnapshot[]>(
    "/api/manage/users/invited",
    errorHandlingFetcher
  );

  // A null key is passed when cloud is disabled, so no pending-users URL is requested.
  const pendingKey = NEXT_PUBLIC_CLOUD_ENABLED
    ? "/api/tenants/users/pending"
    : null;
  const { data: pending } = useSWR<InvitedUserSnapshot[]>(
    pendingKey,
    errorHandlingFetcher
  );

  const activeCount = acceptedPage?.total_items ?? null;
  const invitedCount = invited?.length ?? null;
  const pendingCount = pending?.length ?? null;
  return { activeCount, invitedCount, pendingCount };
}

View File

@@ -36,6 +36,7 @@ export interface Settings {
// User Knowledge settings
user_knowledge_enabled?: boolean;
user_file_max_upload_size_mb?: number | null;
// Connector settings
show_extra_connectors?: boolean;

View File

@@ -230,7 +230,7 @@ function SettingsHeader({
</div>
)}
<Spacer vertical rem={2.5} />
<Spacer vertical rem={1} />
<div className="flex flex-col gap-6 px-4">
<div className="flex w-full justify-between">

View File

@@ -58,7 +58,6 @@ export const ADMIN_PATHS = {
DOCUMENT_PROCESSING: "/admin/configuration/document-processing",
KNOWLEDGE_GRAPH: "/admin/kg",
USERS: "/admin/users",
USERS_V2: "/admin/users2",
API_KEYS: "/admin/api-key",
TOKEN_RATE_LIMITS: "/admin/token-rate-limits",
USAGE: "/admin/performance/usage",
@@ -191,11 +190,6 @@ export const ADMIN_ROUTE_CONFIG: Record<string, AdminRouteConfig> = {
title: "Manage Users",
sidebarLabel: "Users",
},
[ADMIN_PATHS.USERS_V2]: {
icon: SvgUser,
title: "Users & Requests",
sidebarLabel: "Users v2",
},
[ADMIN_PATHS.API_KEYS]: {
icon: SvgKey,
title: "API Keys",

10
web/src/lib/error.ts Normal file
View File

@@ -0,0 +1,10 @@
/**
 * Resolve a human-readable message for an SWR error object.
 *
 * SWR errors from `errorHandlingFetcher` attach `info.message` or
 * `info.detail`. The first of those that is non-empty wins, with
 * `info.message` checked first.
 *
 * @param error - Error object to inspect; `null` / `undefined` are accepted.
 * @param fallback - Returned when neither field holds a non-empty string.
 * @returns The selected message text.
 */
export function getErrorMsg(
  error: { info?: { message?: string; detail?: string } } | null | undefined,
  fallback = "An unknown error occurred"
): string {
  const info = error?.info;
  if (info?.message) {
    return info.message;
  }
  if (info?.detail) {
    return info.detail;
  }
  return fallback;
}

View File

@@ -43,6 +43,7 @@ import { useAppRouter } from "@/hooks/appNavigation";
import { ChatFileType } from "@/app/app/interfaces";
import { toast } from "@/hooks/useToast";
import { useProjects } from "@/lib/hooks/useProjects";
import { SettingsContext } from "@/providers/SettingsProvider";
export type { Project, ProjectFile } from "@/app/app/projects/projectsService";
@@ -84,6 +85,8 @@ function buildFileKey(file: File): string {
return `${file.size}|${namePrefix}`;
}
// Upload size limit in MB used by the upload precheck when the settings value
// `user_file_max_upload_size_mb` is missing or not greater than zero.
const DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = 50;
interface ProjectsContextType {
projects: Project[];
recentFiles: ProjectFile[];
@@ -157,6 +160,7 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
new Map()
);
const route = useAppRouter();
const settingsContext = useContext(SettingsContext);
// Use SWR's mutate to refresh projects - returns the new data
const fetchProjects = useCallback(async (): Promise<Project[]> => {
@@ -336,16 +340,40 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
onSuccess?: (uploaded: CategorizedFiles) => void,
onFailure?: (failedTempIds: string[]) => void
): Promise<ProjectFile[]> => {
const optimisticFiles = files.map((f) =>
const rawMax = settingsContext?.settings?.user_file_max_upload_size_mb;
const maxUploadSizeMb =
rawMax && rawMax > 0 ? rawMax : DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB;
const maxUploadSizeBytes = maxUploadSizeMb * 1024 * 1024;
const oversizedFiles = files.filter(
(file) => file.size > maxUploadSizeBytes
);
const validFiles = files.filter(
(file) => file.size <= maxUploadSizeBytes
);
if (oversizedFiles.length > 0) {
const skippedNames = oversizedFiles.map((file) => file.name).join(", ");
toast.warning(
`Skipped ${oversizedFiles.length} oversized file(s) (>${maxUploadSizeMb} MB): ${skippedNames}`
);
}
if (validFiles.length === 0) {
onFailure?.([]);
return [];
}
const optimisticFiles = validFiles.map((f) =>
createOptimisticFile(f, projectId)
);
const tempIdMap = getTempIdMap(files, optimisticFiles);
const tempIdMap = getTempIdMap(validFiles, optimisticFiles);
setAllRecentFiles((prev) => [...optimisticFiles, ...prev]);
if (projectId) {
setAllCurrentProjectFiles((prev) => [...optimisticFiles, ...prev]);
projectToUploadFilesMapRef.current.set(projectId, optimisticFiles);
}
svcUploadFiles(files, projectId, tempIdMap)
svcUploadFiles(validFiles, projectId, tempIdMap)
.then((uploaded) => {
const uploadedFiles = uploaded.user_files || [];
const tempIdToUploadedFileMap = new Map(
@@ -445,6 +473,7 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
refreshCurrentProjectDetails,
refreshRecentFiles,
removeOptimisticFilesByTempIds,
settingsContext,
]
);

View File

@@ -0,0 +1,166 @@
import React, { PropsWithChildren } from "react";
import { act, renderHook } from "@testing-library/react";
import {
ProjectsProvider,
useProjectsContext,
} from "@/providers/ProjectsContext";
import { SettingsContext } from "@/providers/SettingsProvider";
import { CombinedSettings } from "@/interfaces/settings";
import type { ProjectFile } from "@/app/app/projects/projectsService";
// Shared spies the tests inspect and reset in beforeEach.
// NOTE(review): the `mock` prefix is understood to be required for variables
// referenced inside jest.mock factories — confirm before renaming.
const mockUploadFiles = jest.fn();
const mockGetRecentFiles = jest.fn();
const mockToastWarning = jest.fn();
// Search params always report "no value" for any key.
jest.mock("next/navigation", () => ({
useSearchParams: () => ({
get: () => null,
}),
}));
// The app router is replaced with a bare mock function.
jest.mock("@/hooks/appNavigation", () => ({
useAppRouter: () => jest.fn(),
}));
// No projects exist; refreshing resolves to an empty list.
jest.mock("@/lib/hooks/useProjects", () => ({
useProjects: () => ({
projects: [],
refreshProjects: jest.fn().mockResolvedValue([]),
}),
}));
// toast.warning forwards to the shared spy so tests can count warnings.
jest.mock("@/hooks/useToast", () => ({
toast: {
warning: (...args: unknown[]) => mockToastWarning(...args),
error: jest.fn(),
success: jest.fn(),
},
}));
// Keep the real module's other exports (spread of `actual`) and stub the
// service calls listed below. uploadFiles and getRecentFiles delegate to the
// shared spies through arrow wrappers, so the spy is looked up at call time.
jest.mock("@/app/app/projects/projectsService", () => {
const actual = jest.requireActual("@/app/app/projects/projectsService");
return {
...actual,
fetchProjects: jest.fn().mockResolvedValue([]),
createProject: jest.fn(),
uploadFiles: (...args: unknown[]) => mockUploadFiles(...args),
getRecentFiles: (...args: unknown[]) => mockGetRecentFiles(...args),
getFilesInProject: jest.fn().mockResolvedValue([]),
getProject: jest.fn(),
getProjectInstructions: jest.fn(),
upsertProjectInstructions: jest.fn(),
getProjectDetails: jest.fn(),
renameProject: jest.fn(),
deleteProject: jest.fn(),
deleteUserFile: jest.fn(),
getUserFileStatuses: jest.fn().mockResolvedValue([]),
unlinkFileFromProject: jest.fn(),
linkFileToProject: jest.fn(),
};
});
// Settings fixture provided through SettingsContext. The 1 MB upload limit
// makes the 2 MiB files created in the tests count as oversized.
const settingsValue: CombinedSettings = {
settings: {
user_file_max_upload_size_mb: 1,
} as CombinedSettings["settings"],
enterpriseSettings: null,
customAnalyticsScript: null,
webVersion: null,
webDomain: null,
isSearchModeAvailable: true,
};
/**
 * renderHook wrapper: mounts the hook under ProjectsProvider, which is itself
 * nested in a SettingsContext provider carrying `settingsValue`.
 */
const wrapper = ({ children }: PropsWithChildren) => {
  return (
    <SettingsContext.Provider value={settingsValue}>
      <ProjectsProvider>{children}</ProjectsProvider>
    </SettingsContext.Provider>
  );
};
// Verifies that beginUpload filters files by the configured size limit
// (1 MB in `settingsValue`) before calling the upload service.
describe("ProjectsContext beginUpload size precheck", () => {
// Reset the shared spies and give the service mocks empty successful results.
beforeEach(() => {
mockUploadFiles.mockReset();
mockGetRecentFiles.mockReset();
mockToastWarning.mockReset();
mockUploadFiles.mockResolvedValue({
user_files: [],
rejected_files: [],
});
mockGetRecentFiles.mockResolvedValue([]);
});
// Mixed batch: the oversized file is dropped and a single warning is shown.
it("only sends valid files to the upload API when oversized files are present", async () => {
const { result } = renderHook(() => useProjectsContext(), { wrapper });
const valid = new File(["small"], "small.txt", { type: "text/plain" });
// 2 MiB payload, above the 1 MB limit from settingsValue.
const oversized = new File([new Uint8Array(2 * 1024 * 1024)], "big.txt", {
type: "text/plain",
});
let optimisticFiles: ProjectFile[] = [];
await act(async () => {
optimisticFiles = await result.current.beginUpload(
[valid, oversized],
null
);
});
expect(mockUploadFiles).toHaveBeenCalledTimes(1);
// First positional argument of the upload call is the list of files sent.
const [uploadedFiles] = mockUploadFiles.mock.calls[0];
expect((uploadedFiles as File[]).map((f) => f.name)).toEqual(["small.txt"]);
expect(optimisticFiles.map((f) => f.name)).toEqual(["small.txt"]);
expect(mockToastWarning).toHaveBeenCalledTimes(1);
});
// All files within the limit: everything is uploaded and no warning is shown.
it("uploads all files when none are oversized", async () => {
const { result } = renderHook(() => useProjectsContext(), { wrapper });
const first = new File(["small"], "first.txt", { type: "text/plain" });
const second = new File(["small"], "second.txt", { type: "text/plain" });
let optimisticFiles: ProjectFile[] = [];
await act(async () => {
optimisticFiles = await result.current.beginUpload([first, second], null);
});
expect(mockUploadFiles).toHaveBeenCalledTimes(1);
const [uploadedFiles] = mockUploadFiles.mock.calls[0];
expect((uploadedFiles as File[]).map((f) => f.name)).toEqual([
"first.txt",
"second.txt",
]);
expect(mockToastWarning).not.toHaveBeenCalled();
expect(optimisticFiles.map((f) => f.name)).toEqual([
"first.txt",
"second.txt",
]);
});
// Nothing valid: the upload service is never called and onFailure receives
// an empty list of temp ids.
it("does not call upload API when all files are oversized", async () => {
const { result } = renderHook(() => useProjectsContext(), { wrapper });
const oversized = new File(
[new Uint8Array(2 * 1024 * 1024)],
"too-big.txt",
{ type: "text/plain" }
);
const onSuccess = jest.fn();
const onFailure = jest.fn();
let optimisticFiles: ProjectFile[] = [];
await act(async () => {
optimisticFiles = await result.current.beginUpload(
[oversized],
null,
onSuccess,
onFailure
);
});
expect(mockUploadFiles).not.toHaveBeenCalled();
expect(optimisticFiles).toEqual([]);
expect(mockToastWarning).toHaveBeenCalledTimes(1);
expect(onSuccess).not.toHaveBeenCalled();
expect(onFailure).toHaveBeenCalledWith([]);
});
});

View File

@@ -57,6 +57,12 @@ export interface LineItemProps
WithoutStyles<React.HTMLAttributes<HTMLDivElement>>,
"children"
> {
/**
* Whether the row should behave like a standalone interactive button.
* Set to false when nested inside another interactive primitive
* (e.g. Radix Select.Item) to avoid nested focus targets.
*/
interactive?: boolean;
// line-item variants
strikethrough?: boolean;
danger?: boolean;
@@ -131,6 +137,7 @@ export interface LineItemProps
* - The component automatically adds a `data-selected="true"` attribute for custom styling
*/
export default function LineItem({
interactive = true,
selected,
strikethrough,
danger,
@@ -164,6 +171,11 @@ export default function LineItem({
const emphasisKey = emphasized ? "emphasized" : "normal";
const handleKeyDown = (e: React.KeyboardEvent<HTMLDivElement>) => {
if (!interactive) {
props.onKeyDown?.(e);
return;
}
if (e.key === "Enter") {
e.preventDefault();
(e.currentTarget as HTMLDivElement).click();
@@ -174,6 +186,11 @@ export default function LineItem({
};
const handleKeyUp = (e: React.KeyboardEvent<HTMLDivElement>) => {
if (!interactive) {
props.onKeyUp?.(e);
return;
}
if (e.key === " ") {
e.preventDefault();
(e.currentTarget as HTMLDivElement).click();
@@ -184,8 +201,8 @@ export default function LineItem({
const content = (
<div
ref={ref}
role="button"
tabIndex={0}
role={interactive ? "button" : undefined}
tabIndex={interactive ? 0 : undefined}
className={cn(
"flex flex-row w-full items-start p-2 rounded-08 group/LineItem gap-2",
!!(children && description) ? "items-start" : "items-center",

View File

@@ -369,7 +369,7 @@ function InputSelectItem({
<SelectPrimitive.Item
ref={ref}
value={value}
className="outline-none focus:outline-none"
className="outline-none focus:outline-none rounded-08 data-[highlighted]:bg-background-tint-02"
onSelect={onClick}
>
{/* Hidden ItemText for Radix to track selection */}
@@ -383,7 +383,7 @@ function InputSelectItem({
selected={isSelected}
emphasized
description={description}
onClick={noProp((event) => event.preventDefault())}
interactive={false}
>
{children}
</LineItem>

View File

@@ -1,10 +1,7 @@
"use client";
import { redirect, useRouter, useSearchParams } from "next/navigation";
import {
personaIncludesRetrieval,
getAvailableContextTokens,
} from "@/app/app/services/lib";
import { personaIncludesRetrieval } from "@/app/app/services/lib";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { toast } from "@/hooks/useToast";
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
@@ -56,10 +53,7 @@ import ChatScrollContainer, {
} from "@/sections/chat/ChatScrollContainer";
import ProjectContextPanel from "@/app/app/components/projects/ProjectContextPanel";
import { useProjectsContext } from "@/providers/ProjectsContext";
import {
getProjectTokenCount,
getMaxSelectedDocumentTokens,
} from "@/app/app/projects/projectsService";
import { getProjectTokenCount } from "@/app/app/projects/projectsService";
import ProjectChatSessionList from "@/app/app/components/projects/ProjectChatSessionList";
import { cn } from "@/lib/utils";
import Suggestions from "@/sections/Suggestions";
@@ -70,7 +64,6 @@ import * as AppLayouts from "@/layouts/app-layouts";
import { SvgChevronDown, SvgFileText } from "@opal/icons";
import { Button } from "@opal/components";
import Spacer from "@/refresh-components/Spacer";
import { DEFAULT_CONTEXT_TOKENS } from "@/lib/constants";
import useAppFocus from "@/hooks/useAppFocus";
import { useQueryController } from "@/providers/QueryControllerProvider";
import WelcomeMessage from "@/app/app/components/WelcomeMessage";
@@ -193,7 +186,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message,
currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabledForCurrentWorkflow,
});
}
}
@@ -218,6 +211,8 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
chatSessionId: currentChatSessionId,
agentId: selectedAgent?.id,
});
const deepResearchEnabledForCurrentWorkflow =
currentProjectId === null && deepResearchEnabled;
const [presentingDocument, setPresentingDocument] =
useState<MinimalOnyxDocument | null>(null);
@@ -365,18 +360,22 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
const autoScrollEnabled = user?.preferences?.auto_scroll !== false;
const isStreaming = currentChatState === "streaming";
const { onSubmit, stopGenerating, handleMessageSpecificFileUpload } =
useChatController({
filterManager,
llmManager,
availableAgents: agents,
liveAgent,
existingChatSessionId: currentChatSessionId,
selectedDocuments,
searchParams,
resetInputBar,
setSelectedAgentFromId,
});
const {
onSubmit,
stopGenerating,
handleMessageSpecificFileUpload,
availableContextTokens,
} = useChatController({
filterManager,
llmManager,
availableAgents: agents,
liveAgent,
existingChatSessionId: currentChatSessionId,
selectedDocuments,
searchParams,
resetInputBar,
setSelectedAgentFromId,
});
const { onMessageSelection, currentSessionFileTokenCount } =
useChatSessionController({
@@ -435,10 +434,15 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message: lastUserMsg.message,
currentMessageFiles: currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabledForCurrentWorkflow,
messageIdToResend: lastUserMsg.messageId,
});
}, [messageHistory, onSubmit, currentMessageFiles, deepResearchEnabled]);
}, [
messageHistory,
onSubmit,
currentMessageFiles,
deepResearchEnabledForCurrentWorkflow,
]);
const toggleDocumentSidebar = useCallback(() => {
if (!documentSidebarVisible) {
@@ -458,7 +462,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message,
currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabledForCurrentWorkflow,
});
if (showOnboarding || !onboardingDismissed) {
finishOnboarding();
@@ -468,7 +472,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
resetInputBar,
onSubmit,
currentMessageFiles,
deepResearchEnabled,
deepResearchEnabledForCurrentWorkflow,
showOnboarding,
onboardingDismissed,
finishOnboarding,
@@ -503,7 +507,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
onSubmit({
message,
currentMessageFiles,
deepResearch: deepResearchEnabled,
deepResearch: deepResearchEnabledForCurrentWorkflow,
});
if (showOnboarding || !onboardingDismissed) {
finishOnboarding();
@@ -524,7 +528,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
resetInputBar,
onSubmit,
currentMessageFiles,
deepResearchEnabled,
deepResearchEnabledForCurrentWorkflow,
showOnboarding,
onboardingDismissed,
finishOnboarding,
@@ -588,43 +592,6 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
};
}, [currentChatSessionId, currentProjectId, currentProjectDetails?.files]);
// Available context tokens source of truth:
// - If a chat session exists, fetch from session API (dynamic per session/model)
// - If no session, derive from the default/current persona's max document tokens
const [availableContextTokens, setAvailableContextTokens] = useState<number>(
DEFAULT_CONTEXT_TOKENS * 0.5
);
useEffect(() => {
let cancelled = false;
async function run() {
try {
if (currentChatSessionId) {
const available =
await getAvailableContextTokens(currentChatSessionId);
const capped_context_tokens =
(available ?? DEFAULT_CONTEXT_TOKENS) * 0.5;
if (!cancelled) setAvailableContextTokens(capped_context_tokens);
} else {
const personaId = (selectedAgent || liveAgent)?.id;
if (personaId !== undefined && personaId !== null) {
const maxTokens = await getMaxSelectedDocumentTokens(personaId);
const capped_context_tokens =
(maxTokens ?? DEFAULT_CONTEXT_TOKENS) * 0.5;
if (!cancelled) setAvailableContextTokens(capped_context_tokens);
} else if (!cancelled) {
setAvailableContextTokens(DEFAULT_CONTEXT_TOKENS * 0.5);
}
}
} catch (e) {
if (!cancelled) setAvailableContextTokens(DEFAULT_CONTEXT_TOKENS * 0.5);
}
}
run();
return () => {
cancelled = true;
};
}, [currentChatSessionId, selectedAgent?.id, liveAgent?.id]);
// handle error case where no assistants are available
// Only show this after agents have loaded to prevent flash during initial load
if (noAgents && !isLoadingAgents) {
@@ -709,7 +676,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
>
{/* Main content grid — 3 rows, animated */}
<div
className="flex-1 w-full grid min-h-0 transition-[grid-template-rows] duration-150 ease-in-out"
className="flex-1 w-full grid min-h-0 px-4 transition-[grid-template-rows] duration-150 ease-in-out"
style={gridStyle}
>
{/* ── Top row: ChatUI / WelcomeMessage / ProjectUI ── */}
@@ -732,7 +699,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
<ChatUI
liveAgent={liveAgent!}
llmManager={llmManager}
deepResearchEnabled={deepResearchEnabled}
deepResearchEnabled={
deepResearchEnabledForCurrentWorkflow
}
currentMessageFiles={currentMessageFiles}
setPresentingDocument={setPresentingDocument}
onSubmit={onSubmit}
@@ -828,7 +797,9 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
/>
<AppInputBar
ref={chatInputBarRef}
deepResearchEnabled={deepResearchEnabled}
deepResearchEnabled={
deepResearchEnabledForCurrentWorkflow
}
toggleDeepResearch={toggleDeepResearch}
filterManager={filterManager}
llmManager={llmManager}

View File

@@ -1,66 +0,0 @@
"use client";
import { useState } from "react";
import { SvgUser, SvgUserPlus } from "@opal/icons";
import { Button } from "@opal/components";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { useScimToken } from "@/hooks/useScimToken";
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import useUserCounts from "@/hooks/useUserCounts";
import UsersSummary from "./UsersPage/UsersSummary";
import UsersTable from "./UsersPage/UsersTable";
import InviteUsersModal from "./UsersPage/InviteUsersModal";
// ---------------------------------------------------------------------------
// Users page content
// ---------------------------------------------------------------------------
/**
 * Body of the users page: a summary header followed by the users table.
 *
 * Reads the paid-enterprise flag, the SCIM token query and the user counts
 * from their respective hooks and forwards them to `UsersSummary`.
 */
function UsersContent() {
  // Hooks are called unconditionally and in a fixed order.
  const paidEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();
  const scimTokenQuery = useScimToken();
  const userCounts = useUserCounts();

  // `showScim` is set only when enterprise features are on AND the token
  // query returned truthy data.
  const showScim = paidEnterpriseEnabled && Boolean(scimTokenQuery.data);

  return (
    <>
      <UsersSummary
        activeUsers={userCounts.activeCount}
        pendingInvites={userCounts.invitedCount}
        requests={userCounts.pendingCount}
        showScim={showScim}
      />
      <UsersTable />
    </>
  );
}
// ---------------------------------------------------------------------------
// Page
// ---------------------------------------------------------------------------
/**
 * Settings page titled "Users & Requests".
 *
 * Renders the page header with an "Invite Users" button, the users content,
 * and the invite modal. The modal's open state is owned here and toggled by
 * the header button and by the modal's own `onOpenChange` callback.
 */
export default function UsersPage() {
  const [isInviteModalOpen, setIsInviteModalOpen] = useState(false);

  // Header action that opens the invite modal.
  const inviteButton = (
    <Button icon={SvgUserPlus} onClick={() => setIsInviteModalOpen(true)}>
      Invite Users
    </Button>
  );

  return (
    <SettingsLayouts.Root width="lg">
      <SettingsLayouts.Header
        title="Users & Requests"
        icon={SvgUser}
        rightChildren={inviteButton}
      />
      <SettingsLayouts.Body>
        <UsersContent />
      </SettingsLayouts.Body>
      <InviteUsersModal
        open={isInviteModalOpen}
        onOpenChange={setIsInviteModalOpen}
      />
    </SettingsLayouts.Root>
  );
}

Some files were not shown because too many files have changed in this diff Show More