mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-09 09:42:39 +00:00
Compare commits
10 Commits
v3.0.0-enc
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9d7dc3da21 | ||
|
|
2899be4c5e | ||
|
|
64ee7fc23f | ||
|
|
e07764285d | ||
|
|
cc2e6ffa8a | ||
|
|
d3ee5c9b59 | ||
|
|
dfa0efc093 | ||
|
|
9aad4077f1 | ||
|
|
29d9ebf7b3 | ||
|
|
f1df36e306 |
177
.github/workflows/release-cli.yml
vendored
177
.github/workflows/release-cli.yml
vendored
@@ -26,7 +26,7 @@ jobs:
|
||||
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
- uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
- uses: astral-sh/setup-uv@5a095e7a2014a4212f075830d4f7277575a9d098 # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
@@ -37,3 +37,178 @@ jobs:
|
||||
working-directory: cli
|
||||
- run: uv publish
|
||||
working-directory: cli
|
||||
|
||||
docker-amd64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-amd64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push AMD64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/amd64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
docker-arm64:
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-arm64
|
||||
- run-id=${{ github.run_id }}-cli-arm64
|
||||
- extras=ecr-cache
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 30
|
||||
outputs:
|
||||
digest: ${{ steps.build.outputs.digest }}
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Checkout
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
persist-credentials: false
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Build and push ARM64
|
||||
id: build
|
||||
uses: docker/build-push-action@d08e5c354a6adb9ed34480a06d141179aa583294 # ratchet:docker/build-push-action@v7
|
||||
with:
|
||||
context: ./cli
|
||||
file: ./cli/Dockerfile
|
||||
platforms: linux/arm64
|
||||
cache-from: type=registry,ref=${{ env.REGISTRY_IMAGE }}:latest
|
||||
cache-to: type=inline
|
||||
outputs: type=image,name=${{ env.REGISTRY_IMAGE }},push-by-digest=true,name-canonical=true,push=true
|
||||
|
||||
merge-docker:
|
||||
needs:
|
||||
- docker-amd64
|
||||
- docker-arm64
|
||||
runs-on:
|
||||
- runs-on
|
||||
- runner=2cpu-linux-x64
|
||||
- run-id=${{ github.run_id }}-cli-merge
|
||||
environment: deploy
|
||||
permissions:
|
||||
id-token: write
|
||||
timeout-minutes: 10
|
||||
env:
|
||||
REGISTRY_IMAGE: onyxdotapp/onyx-cli
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
- name: Configure AWS credentials
|
||||
uses: aws-actions/configure-aws-credentials@8df5847569e6427dd6c4fb1cf565c83acfa8afa7 # ratchet:aws-actions/configure-aws-credentials@v6.0.0
|
||||
with:
|
||||
role-to-assume: ${{ secrets.AWS_OIDC_ROLE_ARN }}
|
||||
aws-region: us-east-2
|
||||
|
||||
- name: Get AWS Secrets
|
||||
uses: aws-actions/aws-secretsmanager-get-secrets@a9a7eb4e2f2871d30dc5b892576fde60a2ecc802 # ratchet:aws-actions/aws-secretsmanager-get-secrets@v2.0.10
|
||||
with:
|
||||
secret-ids: |
|
||||
DOCKER_USERNAME, deploy/docker-username
|
||||
DOCKER_TOKEN, deploy/docker-token
|
||||
parse-json-secrets: true
|
||||
|
||||
- name: Set up Docker Buildx
|
||||
uses: docker/setup-buildx-action@4d04d5d9486b7bd6fa91e7baf45bbb4f8b9deedd # ratchet:docker/setup-buildx-action@v4
|
||||
|
||||
- name: Login to Docker Hub
|
||||
uses: docker/login-action@b45d80f862d83dbcd57f89517bcf500b2ab88fb2 # ratchet:docker/login-action@v4
|
||||
with:
|
||||
username: ${{ env.DOCKER_USERNAME }}
|
||||
password: ${{ env.DOCKER_TOKEN }}
|
||||
|
||||
- name: Create and push manifest
|
||||
env:
|
||||
AMD64_DIGEST: ${{ needs.docker-amd64.outputs.digest }}
|
||||
ARM64_DIGEST: ${{ needs.docker-arm64.outputs.digest }}
|
||||
TAG: ${{ github.ref_name }}
|
||||
run: |
|
||||
SANITIZED_TAG="${TAG#cli/}"
|
||||
IMAGES=(
|
||||
"${REGISTRY_IMAGE}@${AMD64_DIGEST}"
|
||||
"${REGISTRY_IMAGE}@${ARM64_DIGEST}"
|
||||
)
|
||||
|
||||
if [[ "$TAG" =~ ^cli/v[0-9]+\.[0-9]+\.[0-9]+$ ]]; then
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
-t "${REGISTRY_IMAGE}:latest" \
|
||||
"${IMAGES[@]}"
|
||||
else
|
||||
docker buildx imagetools create \
|
||||
-t "${REGISTRY_IMAGE}:${SANITIZED_TAG}" \
|
||||
"${IMAGES[@]}"
|
||||
fi
|
||||
|
||||
@@ -141,6 +141,7 @@ COPY --chown=onyx:onyx ./scripts/debugging /app/scripts/debugging
|
||||
COPY --chown=onyx:onyx ./scripts/force_delete_connector_by_id.py /app/scripts/force_delete_connector_by_id.py
|
||||
COPY --chown=onyx:onyx ./scripts/supervisord_entrypoint.sh /app/scripts/supervisord_entrypoint.sh
|
||||
COPY --chown=onyx:onyx ./scripts/setup_craft_templates.sh /app/scripts/setup_craft_templates.sh
|
||||
COPY --chown=onyx:onyx ./scripts/reencrypt_secrets.py /app/scripts/reencrypt_secrets.py
|
||||
RUN chmod +x /app/scripts/supervisord_entrypoint.sh /app/scripts/setup_craft_templates.sh
|
||||
|
||||
# Run Craft template setup at build time when ENABLE_CRAFT=true
|
||||
|
||||
@@ -11,7 +11,6 @@ from sqlalchemy import text
|
||||
from alembic import op
|
||||
from onyx.configs.app_configs import DB_READONLY_PASSWORD
|
||||
from onyx.configs.app_configs import DB_READONLY_USER
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
@@ -22,59 +21,52 @@ depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
|
||||
# Enable pg_trgm extension if not already enabled
|
||||
op.execute("CREATE EXTENSION IF NOT EXISTS pg_trgm")
|
||||
# Create the read-only db user if it does not already exist.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
# Create read-only db user here only in multi-tenant mode. For single-tenant mode,
|
||||
# the user is created in the standard migration.
|
||||
if not (DB_READONLY_USER and DB_READONLY_PASSWORD):
|
||||
raise Exception("DB_READONLY_USER or DB_READONLY_PASSWORD is not set")
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
if MULTI_TENANT:
|
||||
# Drop read-only db user here only in single tenant mode. For multi-tenant mode,
|
||||
# the user is dropped in the alembic_tenants migration.
|
||||
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
-- Check if the read-only user already exists
|
||||
IF NOT EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- Create the read-only user with the specified password
|
||||
EXECUTE format('CREATE USER %I WITH PASSWORD %L', '{DB_READONLY_USER}', '{DB_READONLY_PASSWORD}');
|
||||
-- First revoke all privileges to ensure a clean slate
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
-- Grant only the CONNECT privilege to allow the user to connect to the database
|
||||
-- but not perform any operations without additional specific grants
|
||||
EXECUTE format('GRANT CONNECT ON DATABASE %I TO %I', current_database(), '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
"""
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
text(
|
||||
f"""
|
||||
DO $$
|
||||
BEGIN
|
||||
IF EXISTS (SELECT FROM pg_catalog.pg_roles WHERE rolname = '{DB_READONLY_USER}') THEN
|
||||
-- First revoke all privileges from the database
|
||||
EXECUTE format('REVOKE ALL ON DATABASE %I FROM %I', current_database(), '{DB_READONLY_USER}');
|
||||
-- Then revoke all privileges from the public schema
|
||||
EXECUTE format('REVOKE ALL ON SCHEMA public FROM %I', '{DB_READONLY_USER}');
|
||||
-- Then drop the user
|
||||
EXECUTE format('DROP USER %I', '{DB_READONLY_USER}');
|
||||
END IF;
|
||||
END
|
||||
$$;
|
||||
"""
|
||||
)
|
||||
)
|
||||
op.execute(text("DROP EXTENSION IF EXISTS pg_trgm"))
|
||||
|
||||
@@ -14,67 +14,91 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@lru_cache(maxsize=2)
|
||||
def _get_trimmed_key(key: str) -> bytes:
|
||||
encoded_key = key.encode()
|
||||
key_length = len(encoded_key)
|
||||
if key_length < 16:
|
||||
raise RuntimeError("Invalid ENCRYPTION_KEY_SECRET - too short")
|
||||
elif key_length > 32:
|
||||
key = key[:32]
|
||||
elif key_length not in (16, 24, 32):
|
||||
valid_lengths = [16, 24, 32]
|
||||
key = key[: min(valid_lengths, key=lambda x: abs(x - key_length))]
|
||||
|
||||
return encoded_key
|
||||
# Trim to the largest valid AES key size that fits
|
||||
valid_lengths = [32, 24, 16]
|
||||
for size in valid_lengths:
|
||||
if key_length >= size:
|
||||
return encoded_key[:size]
|
||||
|
||||
raise AssertionError("unreachable")
|
||||
|
||||
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_str.encode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
iv = urandom(16)
|
||||
padder = padding.PKCS7(algorithms.AES.block_size).padder()
|
||||
padded_data = padder.update(input_str.encode()) + padder.finalize()
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
cipher = Cipher(algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend())
|
||||
encryptor = cipher.encryptor()
|
||||
encrypted_data = encryptor.update(padded_data) + encryptor.finalize()
|
||||
|
||||
return iv + encrypted_data
|
||||
|
||||
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str:
|
||||
effective_key = key if key is not None else ENCRYPTION_KEY_SECRET
|
||||
if not effective_key:
|
||||
return input_bytes.decode()
|
||||
|
||||
key = _get_trimmed_key(ENCRYPTION_KEY_SECRET)
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
trimmed = _get_trimmed_key(effective_key)
|
||||
try:
|
||||
iv = input_bytes[:16]
|
||||
encrypted_data = input_bytes[16:]
|
||||
|
||||
cipher = Cipher(algorithms.AES(key), modes.CBC(iv), backend=default_backend())
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
cipher = Cipher(
|
||||
algorithms.AES(trimmed), modes.CBC(iv), backend=default_backend()
|
||||
)
|
||||
decryptor = cipher.decryptor()
|
||||
decrypted_padded_data = decryptor.update(encrypted_data) + decryptor.finalize()
|
||||
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
unpadder = padding.PKCS7(algorithms.AES.block_size).unpadder()
|
||||
decrypted_data = unpadder.update(decrypted_padded_data) + unpadder.finalize()
|
||||
|
||||
return decrypted_data.decode()
|
||||
return decrypted_data.decode()
|
||||
except (ValueError, UnicodeDecodeError):
|
||||
if key is not None:
|
||||
# Explicit key was provided — don't fall back silently
|
||||
raise
|
||||
# Read path: attempt raw UTF-8 decode as a fallback for legacy data.
|
||||
# Does NOT handle data encrypted with a different key — that
|
||||
# ciphertext is not valid UTF-8 and will raise below.
|
||||
logger.warning(
|
||||
"AES decryption failed — falling back to raw decode. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
)
|
||||
try:
|
||||
return input_bytes.decode()
|
||||
except UnicodeDecodeError:
|
||||
raise ValueError(
|
||||
"Data is not valid UTF-8 — likely encrypted with a different key. "
|
||||
"Run the re-encrypt secrets script to rotate to the current key."
|
||||
) from None
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(input_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(input_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(input_str)
|
||||
return versioned_encryption_fn(input_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(input_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(input_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(input_bytes)
|
||||
return versioned_decryption_fn(input_bytes, key=key)
|
||||
|
||||
|
||||
def test_encryption() -> None:
|
||||
|
||||
@@ -15,6 +15,7 @@ from onyx.chat.citation_processor import DynamicCitationProcessor
|
||||
from onyx.chat.emitter import Emitter
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.chat.models import LlmStepResult
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.configs.app_configs import LOG_ONYX_MODEL_INTERACTIONS
|
||||
from onyx.configs.app_configs import PROMPT_CACHE_CHAT_HISTORY
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -54,6 +55,7 @@ from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tracing.framework.create import generation_span
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.jsonriver import Parser
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.postgres_sanitization import sanitize_string
|
||||
from onyx.utils.text_processing import find_all_json_objects
|
||||
@@ -1009,6 +1011,7 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
|
||||
id_to_tool_call_map: dict[int, dict[str, Any]] = {}
|
||||
arg_parsers: dict[int, Parser] = {}
|
||||
reasoning_start = False
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
@@ -1215,7 +1218,14 @@ def run_llm_step_pkt_generator(
|
||||
yield from _close_reasoning_if_active()
|
||||
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
# maybe_emit depends and update being called first and attaching the delta
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
yield from maybe_emit_argument_delta(
|
||||
tool_calls_in_progress=id_to_tool_call_map,
|
||||
tool_call_delta=tool_call_delta,
|
||||
placement=_current_placement(),
|
||||
parsers=arg_parsers,
|
||||
)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
|
||||
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
77
backend/onyx/chat/tool_call_args_streaming.py
Normal file
@@ -0,0 +1,77 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
from typing import Type
|
||||
|
||||
from onyx.llm.model_response import ChatCompletionDeltaToolCall
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.built_in_tools import TOOL_NAME_TO_CLASS
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _get_tool_class(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
) -> Type[Tool] | None:
|
||||
"""Look up the Tool subclass for a streaming tool call delta."""
|
||||
tool_name = tool_calls_in_progress.get(tool_call_delta.index, {}).get("name")
|
||||
if not tool_name:
|
||||
return None
|
||||
return TOOL_NAME_TO_CLASS.get(tool_name)
|
||||
|
||||
|
||||
def maybe_emit_argument_delta(
|
||||
tool_calls_in_progress: Mapping[int, Mapping[str, Any]],
|
||||
tool_call_delta: ChatCompletionDeltaToolCall,
|
||||
placement: Placement,
|
||||
parsers: dict[int, Parser],
|
||||
) -> Generator[Packet, None, None]:
|
||||
"""Emit decoded tool-call argument deltas to the frontend.
|
||||
|
||||
Uses a ``jsonriver.Parser`` per tool-call index to incrementally parse
|
||||
the JSON argument string and extract only the newly-appended content
|
||||
for each string-valued argument.
|
||||
|
||||
NOTE: Non-string arguments (numbers, booleans, null, arrays, objects)
|
||||
are skipped — they are available in the final tool-call kickoff packet.
|
||||
|
||||
``parsers`` is a mutable dict keyed by tool-call index. A new
|
||||
``Parser`` is created automatically for each new index.
|
||||
"""
|
||||
tool_cls = _get_tool_class(tool_calls_in_progress, tool_call_delta)
|
||||
if not tool_cls or not tool_cls.should_emit_argument_deltas():
|
||||
return
|
||||
|
||||
fn = tool_call_delta.function
|
||||
delta_fragment = fn.arguments if fn else None
|
||||
if not delta_fragment:
|
||||
return
|
||||
|
||||
idx = tool_call_delta.index
|
||||
if idx not in parsers:
|
||||
parsers[idx] = Parser()
|
||||
parser = parsers[idx]
|
||||
|
||||
deltas = parser.feed(delta_fragment)
|
||||
|
||||
argument_deltas: dict[str, str] = {}
|
||||
for delta in deltas:
|
||||
if isinstance(delta, dict):
|
||||
for key, value in delta.items():
|
||||
if isinstance(value, str):
|
||||
argument_deltas[key] = argument_deltas.get(key, "") + value
|
||||
|
||||
if not argument_deltas:
|
||||
return
|
||||
|
||||
tc_data = tool_calls_in_progress[tool_call_delta.index]
|
||||
yield Packet(
|
||||
placement=placement,
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type=tc_data.get("name", ""),
|
||||
argument_deltas=argument_deltas,
|
||||
),
|
||||
)
|
||||
@@ -68,6 +68,10 @@ FILE_TOKEN_COUNT_THRESHOLD = int(
|
||||
os.environ.get("FILE_TOKEN_COUNT_THRESHOLD", str(_DEFAULT_FILE_TOKEN_LIMIT))
|
||||
)
|
||||
|
||||
# Maximum upload size for a single user file (chat/projects) in MB.
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB = int(os.environ.get("USER_FILE_MAX_UPLOAD_SIZE_MB") or 50)
|
||||
USER_FILE_MAX_UPLOAD_SIZE_BYTES = USER_FILE_MAX_UPLOAD_SIZE_MB * 1024 * 1024
|
||||
|
||||
# If set to true, will show extra/uncommon connectors in the "Other" category
|
||||
SHOW_EXTRA_CONNECTORS = os.environ.get("SHOW_EXTRA_CONNECTORS", "").lower() == "true"
|
||||
|
||||
|
||||
161
backend/onyx/db/rotate_encryption_key.py
Normal file
161
backend/onyx/db/rotate_encryption_key.py
Normal file
@@ -0,0 +1,161 @@
|
||||
"""Rotate encryption key for all encrypted columns.
|
||||
|
||||
Dynamically discovers all columns using EncryptedString / EncryptedJson,
|
||||
decrypts each value with the old key, and re-encrypts with the current
|
||||
ENCRYPTION_KEY_SECRET.
|
||||
|
||||
The operation is idempotent: rows already encrypted with the current key
|
||||
are skipped. Commits are made in batches so a crash mid-rotation can be
|
||||
safely resumed by re-running.
|
||||
"""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.db.models import Base
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_BATCH_SIZE = 500
|
||||
|
||||
|
||||
def _can_decrypt_with_current_key(data: bytes) -> bool:
|
||||
"""Check if data is already encrypted with the current key.
|
||||
|
||||
Passes the key explicitly so the fallback-to-raw-decode path in
|
||||
_decrypt_bytes is NOT triggered — a clean success/failure signal.
|
||||
"""
|
||||
try:
|
||||
decrypt_bytes_to_string(data, key=ENCRYPTION_KEY_SECRET)
|
||||
return True
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
|
||||
def _discover_encrypted_columns() -> list[tuple[type, str, list[str], bool]]:
|
||||
"""Walk all ORM models and find columns using EncryptedString/EncryptedJson.
|
||||
|
||||
Returns list of (ModelClass, column_attr_name, [pk_attr_names], is_json).
|
||||
"""
|
||||
results: list[tuple[type, str, list[str], bool]] = []
|
||||
|
||||
for mapper in Base.registry.mappers:
|
||||
model_cls = mapper.class_
|
||||
pk_names = [col.key for col in mapper.primary_key]
|
||||
|
||||
for prop in mapper.column_attrs:
|
||||
for col in prop.columns:
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
results.append((model_cls, prop.key, pk_names, True))
|
||||
elif isinstance(col.type, EncryptedString):
|
||||
results.append((model_cls, prop.key, pk_names, False))
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def rotate_encryption_key(
|
||||
db_session: Session,
|
||||
old_key: str | None,
|
||||
dry_run: bool = False,
|
||||
) -> dict[str, int]:
|
||||
"""Decrypt all encrypted columns with old_key and re-encrypt with the current key.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
old_key: The previous encryption key. Pass None or "" if values were
|
||||
not previously encrypted with a key.
|
||||
dry_run: If True, count rows that need rotation without modifying data.
|
||||
|
||||
Returns:
|
||||
Dict of "table.column" -> number of rows re-encrypted (or would be).
|
||||
|
||||
Commits every _BATCH_SIZE rows so that locks are held briefly and progress
|
||||
is preserved on crash. Already-rotated rows are detected and skipped,
|
||||
making the operation safe to re-run.
|
||||
"""
|
||||
if not global_version.is_ee_version():
|
||||
raise RuntimeError("EE mode is not enabled — rotation requires EE encryption.")
|
||||
|
||||
if not ENCRYPTION_KEY_SECRET:
|
||||
raise RuntimeError(
|
||||
"ENCRYPTION_KEY_SECRET is not set — cannot rotate. "
|
||||
"Set the target encryption key in the environment before running."
|
||||
)
|
||||
|
||||
encrypted_columns = _discover_encrypted_columns()
|
||||
totals: dict[str, int] = {}
|
||||
|
||||
for model_cls, col_name, pk_names, is_json in encrypted_columns:
|
||||
table_name: str = model_cls.__tablename__ # type: ignore[attr-defined]
|
||||
col_attr = getattr(model_cls, col_name)
|
||||
pk_attrs = [getattr(model_cls, pk) for pk in pk_names]
|
||||
|
||||
# Read raw bytes directly, bypassing the TypeDecorator
|
||||
raw_col = col_attr.property.columns[0]
|
||||
|
||||
stmt = select(*pk_attrs, raw_col.cast(LargeBinary)).where(col_attr.is_not(None))
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
reencrypted = 0
|
||||
batch_pending = 0
|
||||
for row in rows:
|
||||
raw_bytes: bytes | None = row[-1]
|
||||
if raw_bytes is None:
|
||||
continue
|
||||
|
||||
if _can_decrypt_with_current_key(raw_bytes):
|
||||
continue
|
||||
|
||||
try:
|
||||
if not old_key:
|
||||
decrypted_str = raw_bytes.decode("utf-8")
|
||||
else:
|
||||
decrypted_str = decrypt_bytes_to_string(raw_bytes, key=old_key)
|
||||
|
||||
# For EncryptedJson, parse back to dict so the TypeDecorator
|
||||
# can json.dumps() it cleanly (avoids double-encoding).
|
||||
value: Any = json.loads(decrypted_str) if is_json else decrypted_str
|
||||
except (ValueError, UnicodeDecodeError) as e:
|
||||
pk_vals = [row[i] for i in range(len(pk_names))]
|
||||
logger.warning(
|
||||
f"Could not decrypt/parse {table_name}.{col_name} "
|
||||
f"row {pk_vals} — skipping: {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
if not dry_run:
|
||||
pk_filters = [pk_attr == row[i] for i, pk_attr in enumerate(pk_attrs)]
|
||||
update_stmt = (
|
||||
update(model_cls).where(*pk_filters).values({col_name: value})
|
||||
)
|
||||
db_session.execute(update_stmt)
|
||||
batch_pending += 1
|
||||
|
||||
if batch_pending >= _BATCH_SIZE:
|
||||
db_session.commit()
|
||||
batch_pending = 0
|
||||
reencrypted += 1
|
||||
|
||||
# Flush remaining rows in this column
|
||||
if batch_pending > 0:
|
||||
db_session.commit()
|
||||
|
||||
if reencrypted > 0:
|
||||
totals[f"{table_name}.{col_name}"] = reencrypted
|
||||
logger.info(
|
||||
f"{'[DRY RUN] Would re-encrypt' if dry_run else 'Re-encrypted'} "
|
||||
f"{reencrypted} value(s) in {table_name}.{col_name}"
|
||||
)
|
||||
|
||||
return totals
|
||||
@@ -10,6 +10,8 @@ from pydantic import Field
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import FILE_TOKEN_COUNT_THRESHOLD
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_BYTES
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.db.llm import fetch_default_llm_model
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -35,6 +37,38 @@ def get_safe_filename(upload: UploadFile) -> str:
|
||||
return upload.filename
|
||||
|
||||
|
||||
def get_upload_size_bytes(upload: UploadFile) -> int | None:
|
||||
"""Best-effort file size in bytes without consuming the stream."""
|
||||
if upload.size is not None:
|
||||
return upload.size
|
||||
|
||||
try:
|
||||
current_pos = upload.file.tell()
|
||||
upload.file.seek(0, 2)
|
||||
size = upload.file.tell()
|
||||
upload.file.seek(current_pos)
|
||||
return size
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Could not determine upload size via stream seek "
|
||||
f"(filename='{get_safe_filename(upload)}', "
|
||||
f"error_type={type(e).__name__}, error={e})"
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
def is_upload_too_large(upload: UploadFile, max_bytes: int) -> bool:
|
||||
"""Return True when upload size is known and exceeds max_bytes."""
|
||||
size_bytes = get_upload_size_bytes(upload)
|
||||
if size_bytes is None:
|
||||
logger.warning(
|
||||
"Could not determine upload size; skipping size-limit check for "
|
||||
f"'{get_safe_filename(upload)}'"
|
||||
)
|
||||
return False
|
||||
return size_bytes > max_bytes
|
||||
|
||||
|
||||
# Guard against extremely large images
|
||||
Image.MAX_IMAGE_PIXELS = 12000 * 12000
|
||||
|
||||
@@ -159,6 +193,18 @@ def categorize_uploaded_files(
|
||||
for upload in files:
|
||||
try:
|
||||
filename = get_safe_filename(upload)
|
||||
|
||||
# Size limit is a hard safety cap and is enforced even when token
|
||||
# threshold checks are skipped via SKIP_USERFILE_THRESHOLD settings.
|
||||
if is_upload_too_large(upload, USER_FILE_MAX_UPLOAD_SIZE_BYTES):
|
||||
results.rejected.append(
|
||||
RejectedFile(
|
||||
filename=filename,
|
||||
reason=f"Exceeds {USER_FILE_MAX_UPLOAD_SIZE_MB} MB file size limit",
|
||||
)
|
||||
)
|
||||
continue
|
||||
|
||||
extension = get_file_ext(filename)
|
||||
|
||||
# If image, estimate tokens via dedicated method first
|
||||
|
||||
@@ -41,6 +41,7 @@ class StreamingType(Enum):
|
||||
REASONING_DONE = "reasoning_done"
|
||||
CITATION_INFO = "citation_info"
|
||||
TOOL_CALL_DEBUG = "tool_call_debug"
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta"
|
||||
|
||||
MEMORY_TOOL_START = "memory_tool_start"
|
||||
MEMORY_TOOL_DELTA = "memory_tool_delta"
|
||||
@@ -259,6 +260,15 @@ class CustomToolDelta(BaseObj):
|
||||
file_ids: list[str] | None = None
|
||||
|
||||
|
||||
class ToolCallArgumentDelta(BaseObj):
|
||||
type: Literal["tool_call_argument_delta"] = (
|
||||
StreamingType.TOOL_CALL_ARGUMENT_DELTA.value
|
||||
)
|
||||
|
||||
tool_type: str
|
||||
argument_deltas: dict[str, Any]
|
||||
|
||||
|
||||
################################################
|
||||
# File Reader Packets
|
||||
################################################
|
||||
@@ -379,6 +389,7 @@ PacketObj = Union[
|
||||
# Citation Packets
|
||||
CitationInfo,
|
||||
ToolCallDebug,
|
||||
ToolCallArgumentDelta,
|
||||
# Deep Research Packets
|
||||
DeepResearchPlanStart,
|
||||
DeepResearchPlanDelta,
|
||||
|
||||
@@ -78,6 +78,7 @@ class Settings(BaseModel):
|
||||
|
||||
# User Knowledge settings
|
||||
user_knowledge_enabled: bool | None = True
|
||||
user_file_max_upload_size_mb: int | None = None
|
||||
|
||||
# Connector settings
|
||||
show_extra_connectors: bool | None = True
|
||||
|
||||
@@ -3,6 +3,7 @@ from onyx.configs.app_configs import DISABLE_USER_KNOWLEDGE
|
||||
from onyx.configs.app_configs import ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.app_configs import SHOW_EXTRA_CONNECTORS
|
||||
from onyx.configs.app_configs import USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -50,6 +51,7 @@ def load_settings() -> Settings:
|
||||
if DISABLE_USER_KNOWLEDGE:
|
||||
settings.user_knowledge_enabled = False
|
||||
|
||||
settings.user_file_max_upload_size_mb = USER_FILE_MAX_UPLOAD_SIZE_MB
|
||||
settings.show_extra_connectors = SHOW_EXTRA_CONNECTORS
|
||||
settings.opensearch_indexing_enabled = ENABLE_OPENSEARCH_INDEXING_FOR_ONYX
|
||||
return settings
|
||||
|
||||
@@ -56,3 +56,23 @@ def get_built_in_tool_ids() -> list[str]:
|
||||
|
||||
def get_built_in_tool_by_id(in_code_tool_id: str) -> Type[BUILT_IN_TOOL_TYPES]:
|
||||
return BUILT_IN_TOOL_MAP[in_code_tool_id]
|
||||
|
||||
|
||||
def _build_tool_name_to_class() -> dict[str, Type[BUILT_IN_TOOL_TYPES]]:
|
||||
"""Build a mapping from LLM-facing tool name to tool class."""
|
||||
result: dict[str, Type[BUILT_IN_TOOL_TYPES]] = {}
|
||||
for cls in BUILT_IN_TOOL_MAP.values():
|
||||
name_attr = cls.__dict__.get("name")
|
||||
if isinstance(name_attr, property) and name_attr.fget is not None:
|
||||
tool_name = name_attr.fget(cls)
|
||||
elif isinstance(name_attr, str):
|
||||
tool_name = name_attr
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Built-in tool {cls.__name__} must define a valid LLM-facing tool name"
|
||||
)
|
||||
result[tool_name] = cls
|
||||
return result
|
||||
|
||||
|
||||
TOOL_NAME_TO_CLASS: dict[str, Type[BUILT_IN_TOOL_TYPES]] = _build_tool_name_to_class()
|
||||
|
||||
@@ -92,3 +92,7 @@ class Tool(abc.ABC, Generic[TOverride]):
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return False
|
||||
|
||||
@@ -376,3 +376,8 @@ class PythonTool(Tool[PythonToolOverrideKwargs]):
|
||||
rich_response=None,
|
||||
llm_facing_response=llm_response,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@override
|
||||
def should_emit_argument_deltas(cls) -> bool:
|
||||
return True
|
||||
|
||||
@@ -11,16 +11,20 @@ logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
def _encrypt_string(input_str: str, key: str | None = None) -> bytes: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT encrypt called with explicit key — key ignored.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
def _decrypt_bytes(input_bytes: bytes, key: str | None = None) -> str: # noqa: ARG001
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support decryption of secrets.")
|
||||
elif key is not None:
|
||||
logger.debug("MIT decrypt called with explicit key — key ignored.")
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
@@ -86,15 +90,15 @@ def _mask_list(items: list[Any]) -> list[Any]:
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
def encrypt_string_to_bytes(intput_str: str, key: str | None = None) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
)
|
||||
return versioned_encryption_fn(intput_str)
|
||||
return versioned_encryption_fn(intput_str, key=key)
|
||||
|
||||
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes) -> str:
|
||||
def decrypt_bytes_to_string(intput_bytes: bytes, key: str | None = None) -> str:
|
||||
versioned_decryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_decrypt_bytes"
|
||||
)
|
||||
return versioned_decryption_fn(intput_bytes)
|
||||
return versioned_decryption_fn(intput_bytes, key=key)
|
||||
|
||||
107
backend/scripts/reencrypt_secrets.py
Normal file
107
backend/scripts/reencrypt_secrets.py
Normal file
@@ -0,0 +1,107 @@
|
||||
"""Re-encrypt secrets under the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Decrypts all encrypted columns using the old key (or raw decode if the old key
|
||||
is empty), then re-encrypts them with the current ENCRYPTION_KEY_SECRET.
|
||||
|
||||
Usage (docker):
|
||||
docker exec -it onyx-api_server-1 \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Usage (kubernetes):
|
||||
kubectl exec -it <pod> -- \
|
||||
python -m scripts.reencrypt_secrets --old-key "previous-key"
|
||||
|
||||
Omit --old-key (or pass "") if secrets were not previously encrypted.
|
||||
|
||||
For multi-tenant deployments, pass --tenant-id to target a specific tenant,
|
||||
or --all-tenants to iterate every tenant.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import os
|
||||
import sys
|
||||
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
sys.path.append(parent_dir)
|
||||
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key # noqa: E402
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant # noqa: E402
|
||||
from onyx.db.engine.sql_engine import SqlEngine # noqa: E402
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids # noqa: E402
|
||||
from onyx.utils.variable_functionality import global_version # noqa: E402
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA # noqa: E402
|
||||
|
||||
|
||||
def _run_for_tenant(tenant_id: str, old_key: str | None, dry_run: bool = False) -> None:
|
||||
print(f"Re-encrypting secrets for tenant: {tenant_id}")
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
results = rotate_encryption_key(db_session, old_key=old_key, dry_run=dry_run)
|
||||
|
||||
if results:
|
||||
for col, count in results.items():
|
||||
print(
|
||||
f" {col}: {count} row(s) {'would be ' if dry_run else ''}re-encrypted"
|
||||
)
|
||||
else:
|
||||
print("No rows needed re-encryption.")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Re-encrypt secrets under the current encryption key."
|
||||
)
|
||||
parser.add_argument(
|
||||
"--old-key",
|
||||
default=None,
|
||||
help="Previous encryption key. Omit or pass empty string if not applicable.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--dry-run",
|
||||
action="store_true",
|
||||
help="Show what would be re-encrypted without making changes.",
|
||||
)
|
||||
|
||||
tenant_group = parser.add_mutually_exclusive_group()
|
||||
tenant_group.add_argument(
|
||||
"--tenant-id",
|
||||
default=None,
|
||||
help="Target a specific tenant schema.",
|
||||
)
|
||||
tenant_group.add_argument(
|
||||
"--all-tenants",
|
||||
action="store_true",
|
||||
help="Iterate all tenants.",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
old_key = args.old_key if args.old_key else None
|
||||
|
||||
global_version.set_ee()
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
if args.dry_run:
|
||||
print("DRY RUN — no changes will be made")
|
||||
|
||||
if args.all_tenants:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
print(f"Found {len(tenant_ids)} tenant(s)")
|
||||
failed_tenants: list[str] = []
|
||||
for tid in tenant_ids:
|
||||
try:
|
||||
_run_for_tenant(tid, old_key, dry_run=args.dry_run)
|
||||
except Exception as e:
|
||||
print(f" ERROR for tenant {tid}: {e}")
|
||||
failed_tenants.append(tid)
|
||||
if failed_tenants:
|
||||
print(f"FAILED tenants ({len(failed_tenants)}): {failed_tenants}")
|
||||
sys.exit(1)
|
||||
else:
|
||||
tenant_id = args.tenant_id or POSTGRES_DEFAULT_SCHEMA
|
||||
_run_for_tenant(tenant_id, old_key, dry_run=args.dry_run)
|
||||
|
||||
print("Done.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,305 @@
|
||||
"""Tests for rotate_encryption_key against real Postgres.
|
||||
|
||||
Uses real ORM models (Credential, InternetSearchProvider) and the actual
|
||||
Postgres database. Discovery is mocked in rotation tests to scope mutations
|
||||
to only the test rows — the real _discover_encrypted_columns walk is tested
|
||||
separately in TestDiscoverEncryptedColumns.
|
||||
|
||||
Requires a running Postgres instance. Run with::
|
||||
|
||||
python -m dotenv -f .vscode/.env run -- pytest tests/external_dependency_unit/db/test_rotate_encryption_key.py
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import LargeBinary
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import EncryptedJson
|
||||
from onyx.db.models import EncryptedString
|
||||
from onyx.db.models import InternetSearchProvider
|
||||
from onyx.db.rotate_encryption_key import _discover_encrypted_columns
|
||||
from onyx.db.rotate_encryption_key import rotate_encryption_key
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
ROTATE_MODULE = "onyx.db.rotate_encryption_key"
|
||||
|
||||
OLD_KEY = "o" * 16
|
||||
NEW_KEY = "n" * 16
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _enable_ee() -> Generator[None, None, None]:
|
||||
prev = global_version._is_ee
|
||||
global_version.set_ee()
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
yield
|
||||
global_version._is_ee = prev
|
||||
fetch_versioned_implementation.cache_clear()
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
def _raw_credential_bytes(db_session: Session, credential_id: int) -> bytes | None:
|
||||
"""Read raw bytes from credential_json, bypassing the TypeDecorator."""
|
||||
col = Credential.__table__.c.credential_json
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
Credential.__table__.c.id == credential_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
def _raw_isp_bytes(db_session: Session, isp_id: int) -> bytes | None:
|
||||
"""Read raw bytes from InternetSearchProvider.api_key."""
|
||||
col = InternetSearchProvider.__table__.c.api_key
|
||||
stmt = select(col.cast(LargeBinary)).where(
|
||||
InternetSearchProvider.__table__.c.id == isp_id
|
||||
)
|
||||
return db_session.execute(stmt).scalar()
|
||||
|
||||
|
||||
class TestDiscoverEncryptedColumns:
|
||||
"""Verify _discover_encrypted_columns finds real production models."""
|
||||
|
||||
def test_discovers_credential_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("credential", "credential_json", True) in found
|
||||
|
||||
def test_discovers_internet_search_provider_api_key(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
found = {
|
||||
(model_cls.__tablename__, col_name, is_json) # type: ignore[attr-defined]
|
||||
for model_cls, col_name, _, is_json in results
|
||||
}
|
||||
assert ("internet_search_provider", "api_key", False) in found
|
||||
|
||||
def test_all_encrypted_string_columns_are_not_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedString):
|
||||
assert not is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedString " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
def test_all_encrypted_json_columns_are_json(self) -> None:
|
||||
results = _discover_encrypted_columns()
|
||||
for model_cls, col_name, _, is_json in results:
|
||||
col = getattr(model_cls, col_name).property.columns[0]
|
||||
if isinstance(col.type, EncryptedJson):
|
||||
assert is_json, (
|
||||
f"{model_cls.__tablename__}.{col_name} is EncryptedJson " # type: ignore[attr-defined]
|
||||
f"but is_json={is_json}"
|
||||
)
|
||||
|
||||
|
||||
class TestRotateCredential:
|
||||
"""Test rotation against the real Credential table (EncryptedJson).
|
||||
|
||||
Discovery is scoped to only the Credential model to avoid mutating
|
||||
other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[(Credential, "credential_json", ["id"], True)],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def credential_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert a Credential row with raw encrypted bytes, clean up after."""
|
||||
config = {"api_key": "sk-test-1234", "endpoint": "https://example.com"}
|
||||
encrypted = _encrypt_string(json.dumps(config), key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO credential "
|
||||
"(source, credential_json, admin_public, curator_public) "
|
||||
"VALUES (:source, :cred_json, true, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{"source": DocumentSource.INGESTION_API.value, "cred_json": encrypted},
|
||||
)
|
||||
cred_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield cred_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM credential WHERE id = :id"), {"id": cred_id}
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_credential_json(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
assert decrypted["endpoint"] == "https://example.com"
|
||||
|
||||
def test_skips_already_rotated(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
_ = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
raw = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw is not None
|
||||
decrypted = json.loads(_decrypt_bytes(raw, key=NEW_KEY))
|
||||
assert decrypted["api_key"] == "sk-test-1234"
|
||||
|
||||
def test_dry_run_does_not_modify(
|
||||
self, db_session: Session, credential_id: int
|
||||
) -> None:
|
||||
original = _raw_credential_bytes(db_session, credential_id)
|
||||
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY, dry_run=True)
|
||||
|
||||
assert totals.get("credential.credential_json", 0) >= 1
|
||||
|
||||
raw_after = _raw_credential_bytes(db_session, credential_id)
|
||||
assert raw_after == original
|
||||
|
||||
|
||||
class TestRotateInternetSearchProvider:
|
||||
"""Test rotation against the real InternetSearchProvider table (EncryptedString).
|
||||
|
||||
Discovery is scoped to only the InternetSearchProvider model to avoid
|
||||
mutating other tables in the test database.
|
||||
"""
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _limit_discovery(self) -> Generator[None, None, None]:
|
||||
with patch(
|
||||
f"{ROTATE_MODULE}._discover_encrypted_columns",
|
||||
return_value=[
|
||||
(InternetSearchProvider, "api_key", ["id"], False),
|
||||
],
|
||||
):
|
||||
yield
|
||||
|
||||
@pytest.fixture()
|
||||
def isp_id(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> Generator[int, None, None]:
|
||||
"""Insert an InternetSearchProvider row with raw encrypted bytes."""
|
||||
encrypted = _encrypt_string("sk-secret-api-key", key=OLD_KEY)
|
||||
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-rotation-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": encrypted,
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
yield isp_id
|
||||
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
def test_rotates_api_key(self, db_session: Session, isp_id: int) -> None:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=OLD_KEY)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "sk-secret-api-key"
|
||||
|
||||
def test_rotates_from_unencrypted(
|
||||
self, db_session: Session, tenant_context: None # noqa: ARG002
|
||||
) -> None:
|
||||
"""Test rotating data that was stored without any encryption key."""
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"INSERT INTO internet_search_provider "
|
||||
"(name, provider_type, api_key, is_active) "
|
||||
"VALUES (:name, :ptype, :api_key, false) "
|
||||
"RETURNING id"
|
||||
),
|
||||
{
|
||||
"name": f"test-raw-{id(self)}",
|
||||
"ptype": "test",
|
||||
"api_key": b"raw-api-key",
|
||||
},
|
||||
)
|
||||
isp_id = result.scalar_one()
|
||||
db_session.commit()
|
||||
|
||||
try:
|
||||
with (
|
||||
patch(f"{ROTATE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", NEW_KEY),
|
||||
):
|
||||
totals = rotate_encryption_key(db_session, old_key=None)
|
||||
|
||||
assert totals.get("internet_search_provider.api_key", 0) >= 1
|
||||
|
||||
raw = _raw_isp_bytes(db_session, isp_id)
|
||||
assert raw is not None
|
||||
assert _decrypt_bytes(raw, key=NEW_KEY) == "raw-api-key"
|
||||
finally:
|
||||
db_session.execute(
|
||||
text("DELETE FROM internet_search_provider WHERE id = :id"),
|
||||
{"id": isp_id},
|
||||
)
|
||||
db_session.commit()
|
||||
@@ -950,6 +950,7 @@ from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from tests.external_dependency_unit.answer.stream_test_builder import StreamTestBuilder
|
||||
from tests.external_dependency_unit.answer.stream_test_utils import create_chat_session
|
||||
@@ -1294,9 +1295,18 @@ def test_code_interpreter_replay_packets_include_code_and_output(
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
obj=ToolCallArgumentDelta(
|
||||
tool_type="python",
|
||||
argument_deltas={"code": code},
|
||||
),
|
||||
),
|
||||
forward=2,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
obj=PythonToolStart(code=code),
|
||||
),
|
||||
forward=False,
|
||||
).expect(
|
||||
Packet(
|
||||
placement=create_placement(0),
|
||||
|
||||
165
backend/tests/unit/ee/onyx/utils/test_encryption.py
Normal file
165
backend/tests/unit/ee/onyx/utils/test_encryption.py
Normal file
@@ -0,0 +1,165 @@
|
||||
"""Tests for EE AES-CBC encryption/decryption with explicit key support.
|
||||
|
||||
With EE mode enabled (via conftest), fetch_versioned_implementation resolves
|
||||
to the EE implementations, so no patching of the MIT layer is needed.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from ee.onyx.utils.encryption import _decrypt_bytes
|
||||
from ee.onyx.utils.encryption import _encrypt_string
|
||||
from ee.onyx.utils.encryption import _get_trimmed_key
|
||||
from ee.onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from ee.onyx.utils.encryption import encrypt_string_to_bytes
|
||||
|
||||
EE_MODULE = "ee.onyx.utils.encryption"
|
||||
|
||||
# Keys must be exactly 16, 24, or 32 bytes for AES
|
||||
KEY_16 = "a" * 16
|
||||
KEY_16_ALT = "b" * 16
|
||||
KEY_24 = "d" * 24
|
||||
KEY_32 = "c" * 32
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _clear_key_cache() -> None:
|
||||
_get_trimmed_key.cache_clear()
|
||||
|
||||
|
||||
class TestEncryptDecryptRoundTrip:
|
||||
def test_roundtrip_with_env_key(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("hello world")
|
||||
assert encrypted != b"hello world"
|
||||
assert _decrypt_bytes(encrypted) == "hello world"
|
||||
|
||||
def test_roundtrip_with_explicit_key(self) -> None:
|
||||
encrypted = _encrypt_string("secret data", key=KEY_32)
|
||||
assert encrypted != b"secret data"
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "secret data"
|
||||
|
||||
def test_roundtrip_no_key(self) -> None:
|
||||
"""Without any key, data is raw-encoded (no encryption)."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
encrypted = _encrypt_string("plain text")
|
||||
assert encrypted == b"plain text"
|
||||
assert _decrypt_bytes(encrypted) == "plain text"
|
||||
|
||||
def test_explicit_key_overrides_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
encrypted = _encrypt_string("data", key=KEY_16_ALT)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16_ALT) == "data"
|
||||
|
||||
def test_different_encryptions_produce_different_bytes(self) -> None:
|
||||
"""Each encryption uses a random IV, so results differ."""
|
||||
a = _encrypt_string("same", key=KEY_16)
|
||||
b = _encrypt_string("same", key=KEY_16)
|
||||
assert a != b
|
||||
|
||||
def test_roundtrip_empty_string(self) -> None:
|
||||
encrypted = _encrypt_string("", key=KEY_16)
|
||||
assert encrypted != b""
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == ""
|
||||
|
||||
def test_roundtrip_unicode(self) -> None:
|
||||
text = "日本語テスト 🔐 émojis"
|
||||
encrypted = _encrypt_string(text, key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == text
|
||||
|
||||
|
||||
class TestDecryptFallbackBehavior:
|
||||
def test_wrong_env_key_falls_back_to_raw_decode(self) -> None:
|
||||
"""Default key path: AES fails on non-AES data → fallback to raw decode."""
|
||||
raw = "readable text".encode()
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_16):
|
||||
assert _decrypt_bytes(raw) == "readable text"
|
||||
|
||||
def test_explicit_wrong_key_raises(self) -> None:
|
||||
"""Explicit key path: AES fails → raises, no fallback."""
|
||||
encrypted = _encrypt_string("secret", key=KEY_16)
|
||||
with pytest.raises(ValueError):
|
||||
_decrypt_bytes(encrypted, key=KEY_16_ALT)
|
||||
|
||||
def test_explicit_none_key_with_no_env(self) -> None:
|
||||
"""key=None with empty env → raw decode."""
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", ""):
|
||||
assert _decrypt_bytes(b"hello", key=None) == "hello"
|
||||
|
||||
def test_explicit_empty_string_key(self) -> None:
|
||||
"""key='' means no encryption."""
|
||||
encrypted = _encrypt_string("test", key="")
|
||||
assert encrypted == b"test"
|
||||
assert _decrypt_bytes(encrypted, key="") == "test"
|
||||
|
||||
|
||||
class TestKeyValidation:
|
||||
def test_key_too_short_raises(self) -> None:
|
||||
with pytest.raises(RuntimeError, match="too short"):
|
||||
_encrypt_string("data", key="short")
|
||||
|
||||
def test_16_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_16)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_16) == "data"
|
||||
|
||||
def test_24_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_24)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_24) == "data"
|
||||
|
||||
def test_32_byte_key(self) -> None:
|
||||
encrypted = _encrypt_string("data", key=KEY_32)
|
||||
assert _decrypt_bytes(encrypted, key=KEY_32) == "data"
|
||||
|
||||
def test_long_key_truncated_to_32(self) -> None:
|
||||
"""Keys longer than 32 bytes are truncated to 32."""
|
||||
long_key = "e" * 64
|
||||
encrypted = _encrypt_string("data", key=long_key)
|
||||
assert _decrypt_bytes(encrypted, key=long_key) == "data"
|
||||
|
||||
def test_20_byte_key_trimmed_to_16(self) -> None:
|
||||
"""A 20-byte key is trimmed to the largest valid AES size that fits (16)."""
|
||||
key_20 = "f" * 20
|
||||
encrypted = _encrypt_string("data", key=key_20)
|
||||
assert _decrypt_bytes(encrypted, key=key_20) == "data"
|
||||
|
||||
# Verify it was trimmed to 16 by checking that the first 16 bytes
|
||||
# of the key can also decrypt it
|
||||
key_16_same_prefix = "f" * 16
|
||||
assert _decrypt_bytes(encrypted, key=key_16_same_prefix) == "data"
|
||||
|
||||
def test_25_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 25-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_25 = "g" * 25
|
||||
encrypted = _encrypt_string("data", key=key_25)
|
||||
assert _decrypt_bytes(encrypted, key=key_25) == "data"
|
||||
|
||||
key_24_same_prefix = "g" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
def test_30_byte_key_trimmed_to_24(self) -> None:
|
||||
"""A 30-byte key is trimmed to the largest valid AES size that fits (24)."""
|
||||
key_30 = "h" * 30
|
||||
encrypted = _encrypt_string("data", key=key_30)
|
||||
assert _decrypt_bytes(encrypted, key=key_30) == "data"
|
||||
|
||||
key_24_same_prefix = "h" * 24
|
||||
assert _decrypt_bytes(encrypted, key=key_24_same_prefix) == "data"
|
||||
|
||||
|
||||
class TestWrapperFunctions:
|
||||
"""Test encrypt_string_to_bytes / decrypt_bytes_to_string pass key through.
|
||||
|
||||
With EE mode enabled, the wrappers resolve to EE implementations automatically.
|
||||
"""
|
||||
|
||||
def test_wrapper_passes_key(self) -> None:
|
||||
encrypted = encrypt_string_to_bytes("payload", key=KEY_16)
|
||||
assert decrypt_bytes_to_string(encrypted, key=KEY_16) == "payload"
|
||||
|
||||
def test_wrapper_no_key_uses_env(self) -> None:
|
||||
with patch(f"{EE_MODULE}.ENCRYPTION_KEY_SECRET", KEY_32):
|
||||
encrypted = encrypt_string_to_bytes("payload")
|
||||
assert decrypt_bytes_to_string(encrypted) == "payload"
|
||||
630
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
630
backend/tests/unit/onyx/chat/test_argument_delta_streaming.py
Normal file
@@ -0,0 +1,630 @@
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.chat.tool_call_args_streaming import maybe_emit_argument_delta
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import ToolCallArgumentDelta
|
||||
from onyx.utils.jsonriver import Parser
|
||||
|
||||
|
||||
def _make_tool_call_delta(
|
||||
index: int = 0,
|
||||
name: str | None = None,
|
||||
arguments: str | None = None,
|
||||
function_is_none: bool = False,
|
||||
) -> MagicMock:
|
||||
"""Create a mock tool_call_delta matching the LiteLLM streaming shape."""
|
||||
delta = MagicMock()
|
||||
delta.index = index
|
||||
if function_is_none:
|
||||
delta.function = None
|
||||
else:
|
||||
delta.function = MagicMock()
|
||||
delta.function.name = name
|
||||
delta.function.arguments = arguments
|
||||
return delta
|
||||
|
||||
|
||||
def _make_placement() -> Placement:
|
||||
return Placement(turn_index=0, tab_index=0)
|
||||
|
||||
|
||||
def _mock_tool_class(emit: bool = True) -> MagicMock:
|
||||
cls = MagicMock()
|
||||
cls.should_emit_argument_deltas.return_value = emit
|
||||
return cls
|
||||
|
||||
|
||||
def _collect(
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
delta: MagicMock,
|
||||
placement: Placement | None = None,
|
||||
parsers: dict[int, Parser] | None = None,
|
||||
) -> list[Any]:
|
||||
"""Run maybe_emit_argument_delta and return the yielded packets."""
|
||||
return list(
|
||||
maybe_emit_argument_delta(
|
||||
tc_map,
|
||||
delta,
|
||||
placement or _make_placement(),
|
||||
parsers if parsers is not None else {},
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def _stream_fragments(
|
||||
fragments: list[str],
|
||||
tc_map: dict[int, dict[str, Any]],
|
||||
placement: Placement | None = None,
|
||||
) -> list[str]:
|
||||
"""Feed fragments into maybe_emit_argument_delta one by one, returning
|
||||
all emitted content values concatenated per-key as a flat list."""
|
||||
pl = placement or _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
emitted: list[str] = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
delta = _make_tool_call_delta(arguments=frag)
|
||||
for packet in maybe_emit_argument_delta(tc_map, delta, pl, parsers=parsers):
|
||||
obj = packet.obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
for value in obj.argument_deltas.values():
|
||||
emitted.append(value)
|
||||
return emitted
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaGuards:
|
||||
"""Tests for conditions that cause no packet to be emitted."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_does_not_opt_in(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Tools that return False from should_emit_argument_deltas emit nothing."""
|
||||
mock_get_tool.return_value = _mock_tool_class(emit=False)
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_tool_class_unknown(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = None
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "unknown", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="x")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_no_argument_fragment(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=None)) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_key_value_not_started(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""Key exists in JSON but its string value hasn't begun yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code":'}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments=":")) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_before_any_key(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Only the opening brace has arrived — no key to stream yet."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": "{"}
|
||||
}
|
||||
assert _collect(tc_map, _make_tool_call_delta(arguments="{")) == []
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaBasic:
|
||||
"""Tests for correct packet content and incremental emission."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_packet_with_correct_fields(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "', "print(1)", '"}']
|
||||
|
||||
pl = _make_placement()
|
||||
parsers: dict[int, Parser] = {}
|
||||
all_packets = []
|
||||
for frag in fragments:
|
||||
tc_map[0]["arguments"] += frag
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=frag), pl, parsers
|
||||
)
|
||||
all_packets.extend(packets)
|
||||
|
||||
assert len(all_packets) >= 1
|
||||
# Verify packet structure
|
||||
obj = all_packets[0].obj
|
||||
assert isinstance(obj, ToolCallArgumentDelta)
|
||||
assert obj.tool_type == "python"
|
||||
# All emitted content should reconstruct the value
|
||||
full_code = ""
|
||||
for p in all_packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
if "code" in p.obj.argument_deltas:
|
||||
full_code += p.obj.argument_deltas["code"]
|
||||
assert full_code == "print(1)"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_emits_only_new_content_on_subsequent_call(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""After a first emission, subsequent calls emit only the diff."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# First fragment opens the string
|
||||
tc_map[0]["arguments"] = '{"code": "abc'
|
||||
packets_1 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments='{"code": "abc'), pl, parsers
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "abc"
|
||||
|
||||
# Second fragment appends more
|
||||
tc_map[0]["arguments"] = '{"code": "abcdef'
|
||||
packets_2 = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments="def"), pl, parsers
|
||||
)
|
||||
code_2 = ""
|
||||
for p in packets_2:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_2 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_2 == "def"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_handles_multiple_keys_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""When a second key starts, emissions switch to that key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'", "output": "hello',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "x" in full
|
||||
assert "hello" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_spans_key_boundary(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains the end of one value and the start of the next key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "x',
|
||||
'y", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "xy" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_empty_value_emits_nothing(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An empty string value has nothing to emit."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# Opening quote just arrived, value is empty
|
||||
tc_map[0]["arguments"] = '{"code": "'
|
||||
packets = _collect(tc_map, _make_tool_call_delta(arguments='{"code": "'))
|
||||
# No string content yet, so either no packet or empty deltas
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
assert p.obj.argument_deltas.get("code", "") == ""
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaDecoding:
|
||||
"""Tests verifying that JSON escape sequences are properly decoded."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "line1\\nline2"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "line1\nline2"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_tabs(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\tindented"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "\tindented"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_quotes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "say \\"hi\\""}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == 'say "hi"'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_escaped_backslashes(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "path\\\\dir"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "path\\dir"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_decodes_unicode_escape(self, mock_get_tool: MagicMock) -> None:
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "\\u0041"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "A"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_escape_at_end_decoded_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A trailing backslash (incomplete escape) is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\', 'n"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "hello\n"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_incomplete_unicode_escape_completed_on_next_chunk(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""A partial \\uXX sequence is completed in the next chunk."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"code": "hello\\u00', '41"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
assert "".join(emitted) == "helloA"
|
||||
|
||||
|
||||
class TestArgumentDeltaStreamingE2E:
|
||||
"""Simulates realistic sequences of LLM argument deltas to verify
|
||||
the full pipeline produces correct decoded output."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_realistic_python_code_streaming(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams: {"code": "print('hello')\\nprint('world')"}"""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"',
|
||||
"code",
|
||||
'": "',
|
||||
"print(",
|
||||
"'hello')",
|
||||
"\\n",
|
||||
"print(",
|
||||
"'world')",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "print('hello')\nprint('world')"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_streaming_with_tabs_and_newlines(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code with tabs and newlines."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"if True:",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"pass",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "if True:\n\tpass"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_split_escape_sequence(self, mock_get_tool: MagicMock) -> None:
|
||||
"""An escape sequence split across two fragments (backslash in one,
|
||||
'n' in the next) should still decode correctly."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "hello',
|
||||
"\\",
|
||||
"n",
|
||||
'world"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "hello\nworld"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_newlines_and_indentation(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams a multi-line function with multiple escape sequences."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"def foo():",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"x = 1",
|
||||
"\\n",
|
||||
"\\t",
|
||||
"return x",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == "def foo():\n\tx = 1\n\treturn x"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_two_keys_streamed_sequentially(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Streams code first, then a second key (language) — both decoded."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = 1",
|
||||
'", "language": "',
|
||||
"python",
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
# Should have emissions for both keys
|
||||
full = "".join(emitted)
|
||||
assert "x = 1" in full
|
||||
assert "python" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_containing_dict_literal(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Python code like `x = {"key": "val"}` contains JSON-like patterns.
|
||||
The escaped quotes inside the *outer* JSON value should prevent the
|
||||
inner `"key":` from being mistaken for a top-level JSON key."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
# The LLM sends: {"code": "x = {\"key\": \"val\"}"}
|
||||
# The inner quotes are escaped as \" in the JSON value.
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"x = {",
|
||||
'\\"key\\"',
|
||||
": ",
|
||||
'\\"val\\"',
|
||||
"}",
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'x = {"key": "val"}'
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_code_with_colon_in_value(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Colons inside the string value should not confuse key detection."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = [
|
||||
'{"code": "',
|
||||
"url = ",
|
||||
'\\"https://example.com\\"',
|
||||
'"}',
|
||||
]
|
||||
|
||||
full = "".join(_stream_fragments(fragments, tc_map))
|
||||
assert full == 'url = "https://example.com"'
|
||||
|
||||
|
||||
class TestMaybeEmitArgumentDeltaEdgeCases:
|
||||
"""Edge cases not covered by the standard test classes."""
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_no_emission_when_function_is_none(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Some delta chunks have function=None (e.g. role-only deltas)."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": '{"code": "x'}
|
||||
}
|
||||
delta = _make_tool_call_delta(arguments=None, function_is_none=True)
|
||||
assert _collect(tc_map, delta) == []
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_multiple_concurrent_tool_calls(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Two tool calls streaming at different indices in parallel."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""},
|
||||
1: {"id": "tc_2", "name": "python", "arguments": ""},
|
||||
}
|
||||
|
||||
parsers: dict[int, Parser] = {}
|
||||
pl = _make_placement()
|
||||
|
||||
# Feed full JSON to index 0
|
||||
tc_map[0]["arguments"] = '{"code": "aaa"}'
|
||||
packets_0 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=0, arguments='{"code": "aaa"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_0 = ""
|
||||
for p in packets_0:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_0 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_0 == "aaa"
|
||||
|
||||
# Feed full JSON to index 1
|
||||
tc_map[1]["arguments"] = '{"code": "bbb"}'
|
||||
packets_1 = _collect(
|
||||
tc_map,
|
||||
_make_tool_call_delta(index=1, arguments='{"code": "bbb"}'),
|
||||
pl,
|
||||
parsers,
|
||||
)
|
||||
code_1 = ""
|
||||
for p in packets_1:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
code_1 += p.obj.argument_deltas.get("code", "")
|
||||
assert code_1 == "bbb"
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_with_four_arguments(self, mock_get_tool: MagicMock) -> None:
|
||||
"""A single delta contains four complete key-value pairs."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
full = '{"a": "one", "b": "two", "c": "three", "d": "four"}'
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
tc_map[0]["arguments"] = full
|
||||
parsers: dict[int, Parser] = {}
|
||||
packets = _collect(
|
||||
tc_map, _make_tool_call_delta(arguments=full), parsers=parsers
|
||||
)
|
||||
|
||||
# Collect all argument deltas across packets
|
||||
all_deltas: dict[str, str] = {}
|
||||
for p in packets:
|
||||
assert isinstance(p.obj, ToolCallArgumentDelta)
|
||||
for k, v in p.obj.argument_deltas.items():
|
||||
all_deltas[k] = all_deltas.get(k, "") + v
|
||||
|
||||
assert all_deltas == {
|
||||
"a": "one",
|
||||
"b": "two",
|
||||
"c": "three",
|
||||
"d": "four",
|
||||
}
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_delta_on_second_arg_after_first_complete(
|
||||
self, mock_get_tool: MagicMock
|
||||
) -> None:
|
||||
"""First argument is fully complete; delta only adds to the second."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
|
||||
fragments = [
|
||||
'{"code": "print(1)", "lang": "py',
|
||||
'"}',
|
||||
]
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert "print(1)" in full
|
||||
assert "py" in full
|
||||
|
||||
@patch("onyx.chat.tool_call_args_streaming._get_tool_class")
|
||||
def test_non_string_values_skipped(self, mock_get_tool: MagicMock) -> None:
|
||||
"""Non-string values (numbers, booleans, null) are skipped — they are
|
||||
available in the final tool-call kickoff packet. String arguments
|
||||
following them are still emitted."""
|
||||
mock_get_tool.return_value = _mock_tool_class()
|
||||
|
||||
tc_map: dict[int, dict[str, Any]] = {
|
||||
0: {"id": "tc_1", "name": "python", "arguments": ""}
|
||||
}
|
||||
fragments = ['{"timeout": 30, "code": "hello"}']
|
||||
|
||||
emitted = _stream_fragments(fragments, tc_map)
|
||||
full = "".join(emitted)
|
||||
assert full == "hello"
|
||||
188
backend/tests/unit/onyx/server/test_projects_file_utils.py
Normal file
188
backend/tests/unit/onyx/server/test_projects_file_utils.py
Normal file
@@ -0,0 +1,188 @@
|
||||
from io import BytesIO
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
|
||||
from onyx.server.features.projects import projects_file_utils as utils
|
||||
|
||||
|
||||
class _Tokenizer:
|
||||
def encode(self, text: str) -> list[int]:
|
||||
return [1] * len(text)
|
||||
|
||||
|
||||
class _NonSeekableFile(BytesIO):
|
||||
def tell(self) -> int:
|
||||
raise OSError("tell not supported")
|
||||
|
||||
def seek(self, *_args: object, **_kwargs: object) -> int:
|
||||
raise OSError("seek not supported")
|
||||
|
||||
|
||||
def _make_upload(filename: str, size: int, content: bytes | None = None) -> UploadFile:
|
||||
payload = content if content is not None else (b"x" * size)
|
||||
return UploadFile(filename=filename, file=BytesIO(payload), size=size)
|
||||
|
||||
|
||||
def _make_upload_no_size(filename: str, content: bytes) -> UploadFile:
|
||||
return UploadFile(filename=filename, file=BytesIO(content), size=None)
|
||||
|
||||
|
||||
def _patch_common_dependencies(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
monkeypatch.setattr(utils, "fetch_default_llm_model", lambda _db: None)
|
||||
monkeypatch.setattr(utils, "get_tokenizer", lambda **_kwargs: _Tokenizer())
|
||||
monkeypatch.setattr(utils, "is_file_password_protected", lambda **_kwargs: False)
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_falls_back_to_stream_size() -> None:
|
||||
upload = UploadFile(filename="example.txt", file=BytesIO(b"abcdef"), size=None)
|
||||
upload.file.seek(2)
|
||||
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size == 6
|
||||
assert upload.file.tell() == 2
|
||||
|
||||
|
||||
def test_get_upload_size_bytes_logs_warning_when_stream_size_unavailable(
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = UploadFile(filename="non_seekable.txt", file=_NonSeekableFile(), size=None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
size = utils.get_upload_size_bytes(upload)
|
||||
|
||||
assert size is None
|
||||
assert "Could not determine upload size via stream seek" in caplog.text
|
||||
assert "non_seekable.txt" in caplog.text
|
||||
|
||||
|
||||
def test_is_upload_too_large_logs_warning_when_size_unknown(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
caplog: pytest.LogCaptureFixture,
|
||||
) -> None:
|
||||
upload = _make_upload("size_unknown.txt", size=1)
|
||||
monkeypatch.setattr(utils, "get_upload_size_bytes", lambda _upload: None)
|
||||
|
||||
caplog.set_level("WARNING")
|
||||
is_too_large = utils.is_upload_too_large(upload, max_bytes=100)
|
||||
|
||||
assert is_too_large is False
|
||||
assert "Could not determine upload size; skipping size-limit check" in caplog.text
|
||||
assert "size_unknown.txt" in caplog.text
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_under_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("small.png", size=99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_uses_seek_fallback_when_upload_size_missing(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload_no_size("small.png", content=b"x" * 99)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_accepts_size_at_limit(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("edge.png", size=100)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 1
|
||||
assert len(result.rejected) == 0
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_rejects_size_over_limit_with_reason(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
upload = _make_upload("large.png", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_mixed_batch_keeps_valid_and_rejects_oversized(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
monkeypatch.setattr(utils, "estimate_image_tokens_for_upload", lambda _upload: 10)
|
||||
|
||||
small = _make_upload("small.png", size=50)
|
||||
large = _make_upload("large.png", size=101)
|
||||
|
||||
result = utils.categorize_uploaded_files([small, large], MagicMock())
|
||||
|
||||
assert [file.filename for file in result.acceptable] == ["small.png"]
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].filename == "large.png"
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_enforces_size_limit_even_when_threshold_is_skipped(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "SKIP_USERFILE_THRESHOLD", True)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
upload = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([upload], MagicMock())
|
||||
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
|
||||
|
||||
def test_categorize_uploaded_files_checks_size_before_text_extraction(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
_patch_common_dependencies(monkeypatch)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_BYTES", 100)
|
||||
monkeypatch.setattr(utils, "USER_FILE_MAX_UPLOAD_SIZE_MB", 1)
|
||||
|
||||
extract_mock = MagicMock(return_value="this should not run")
|
||||
monkeypatch.setattr(utils, "extract_file_text", extract_mock)
|
||||
|
||||
oversized_doc = _make_upload("oversized.pdf", size=101)
|
||||
result = utils.categorize_uploaded_files([oversized_doc], MagicMock())
|
||||
|
||||
extract_mock.assert_not_called()
|
||||
assert len(result.acceptable) == 0
|
||||
assert len(result.rejected) == 1
|
||||
assert result.rejected[0].reason == "Exceeds 1 MB file size limit"
|
||||
32
backend/tests/unit/onyx/server/test_settings_store.py
Normal file
32
backend/tests/unit/onyx/server/test_settings_store.py
Normal file
@@ -0,0 +1,32 @@
|
||||
import pytest
|
||||
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.settings import store as settings_store
|
||||
|
||||
|
||||
class _FakeKvStore:
|
||||
def load(self, _key: str) -> dict:
|
||||
raise KvKeyNotFoundError()
|
||||
|
||||
|
||||
class _FakeCache:
|
||||
def __init__(self) -> None:
|
||||
self._vals: dict[str, bytes] = {}
|
||||
|
||||
def get(self, key: str) -> bytes | None:
|
||||
return self._vals.get(key)
|
||||
|
||||
def set(self, key: str, value: str, ex: int | None = None) -> None: # noqa: ARG002
|
||||
self._vals[key] = value.encode("utf-8")
|
||||
|
||||
|
||||
def test_load_settings_includes_user_file_max_upload_size_mb(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
) -> None:
|
||||
monkeypatch.setattr(settings_store, "get_kv_store", lambda: _FakeKvStore())
|
||||
monkeypatch.setattr(settings_store, "get_cache_backend", lambda: _FakeCache())
|
||||
monkeypatch.setattr(settings_store, "USER_FILE_MAX_UPLOAD_SIZE_MB", 77)
|
||||
|
||||
settings = settings_store.load_settings()
|
||||
|
||||
assert settings.user_file_max_upload_size_mb == 77
|
||||
22
cli/Dockerfile
Normal file
22
cli/Dockerfile
Normal file
@@ -0,0 +1,22 @@
|
||||
FROM golang:1.26-alpine@sha256:2389ebfa5b7f43eeafbd6be0c3700cc46690ef842ad962f6c5bd6be49ed82039 AS builder
|
||||
|
||||
WORKDIR /app
|
||||
COPY ./ .
|
||||
|
||||
ARG TARGETARCH
|
||||
RUN CGO_ENABLED=0 GOOS=linux GOARCH=${TARGETARCH} go build -ldflags="-s -w" -o onyx-cli .
|
||||
RUN mkdir -p /home/onyx/.config
|
||||
|
||||
FROM scratch
|
||||
|
||||
COPY --from=builder /etc/ssl/certs/ca-certificates.crt /etc/ssl/certs/
|
||||
COPY --from=builder --chown=65534:65534 /home/onyx /home/onyx
|
||||
|
||||
COPY --from=builder /app/onyx-cli /onyx-cli
|
||||
|
||||
ENV HOME=/home/onyx
|
||||
ENV XDG_CONFIG_HOME=/home/onyx/.config
|
||||
|
||||
USER 65534:65534
|
||||
|
||||
ENTRYPOINT ["/onyx-cli"]
|
||||
@@ -1261,3 +1261,5 @@ configMap:
|
||||
SKIP_USERFILE_THRESHOLD: ""
|
||||
# For multi-tenant: comma-separated list of tenant IDs to skip threshold
|
||||
SKIP_USERFILE_THRESHOLD_TENANT_IDS: ""
|
||||
# Maximum user upload file size in MB for chat/projects uploads
|
||||
USER_FILE_MAX_UPLOAD_SIZE_MB: ""
|
||||
|
||||
@@ -18,6 +18,10 @@ variable "INTEGRATION_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-integration"
|
||||
}
|
||||
|
||||
variable "CLI_REPOSITORY" {
|
||||
default = "onyxdotapp/onyx-cli"
|
||||
}
|
||||
|
||||
variable "TAG" {
|
||||
default = "latest"
|
||||
}
|
||||
@@ -64,3 +68,13 @@ target "integration" {
|
||||
|
||||
tags = ["${INTEGRATION_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
target "cli" {
|
||||
context = "cli"
|
||||
dockerfile = "Dockerfile"
|
||||
|
||||
cache-from = ["type=registry,ref=${CLI_REPOSITORY}:latest"]
|
||||
cache-to = ["type=inline"]
|
||||
|
||||
tags = ["${CLI_REPOSITORY}:${TAG}"]
|
||||
}
|
||||
|
||||
@@ -156,6 +156,7 @@ module.exports = {
|
||||
"**/src/app/**/*.test.tsx",
|
||||
"**/src/components/**/*.test.tsx",
|
||||
"**/src/lib/**/*.test.tsx",
|
||||
"**/src/providers/**/*.test.tsx",
|
||||
"**/src/refresh-components/**/*.test.tsx",
|
||||
"**/src/hooks/**/*.test.tsx",
|
||||
"**/src/sections/**/*.test.tsx",
|
||||
|
||||
@@ -8,7 +8,8 @@ const cspHeader = `
|
||||
base-uri 'self';
|
||||
form-action 'self';
|
||||
${
|
||||
process.env.NEXT_PUBLIC_CLOUD_ENABLED === "true"
|
||||
process.env.NEXT_PUBLIC_CLOUD_ENABLED === "true" &&
|
||||
process.env.NODE_ENV !== "development"
|
||||
? "upgrade-insecure-requests;"
|
||||
: ""
|
||||
}
|
||||
|
||||
@@ -7,7 +7,7 @@ import { SlackTokensForm } from "./SlackTokensForm";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
|
||||
export const NewSlackBotForm = () => {
|
||||
export function NewSlackBotForm() {
|
||||
const [formValues] = useState({
|
||||
name: "",
|
||||
enabled: true,
|
||||
@@ -19,7 +19,12 @@ export const NewSlackBotForm = () => {
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={SvgSlack} title="New Slack Bot" separator />
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSlack}
|
||||
title="New Slack Bot"
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<CardSection>
|
||||
<div className="p-4">
|
||||
@@ -33,4 +38,4 @@ export const NewSlackBotForm = () => {
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
};
|
||||
}
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
"use client";
|
||||
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { SlackBot, ValidSources } from "@/lib/types";
|
||||
import { SlackBot } from "@/lib/types";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { updateSlackBotField } from "@/lib/updateSlackBotField";
|
||||
import { SlackTokensForm } from "./SlackTokensForm";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
|
||||
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
|
||||
import { deleteSlackBot } from "./new/lib";
|
||||
import GenericConfirmModal from "@/components/modals/GenericConfirmModal";
|
||||
@@ -90,10 +90,7 @@ export const ExistingSlackBotForm = ({
|
||||
<div>
|
||||
<div className="flex items-center justify-between h-14">
|
||||
<div className="flex items-center gap-2">
|
||||
<div className="my-auto">
|
||||
<SourceIcon iconSize={32} sourceType={ValidSources.Slack} />
|
||||
</div>
|
||||
<div className="ml-1">
|
||||
<div>
|
||||
<EditableStringFieldDisplay
|
||||
value={formValues.name}
|
||||
isEditable={true}
|
||||
|
||||
@@ -1,100 +1,122 @@
|
||||
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import { SlackChannelConfigCreationForm } from "@/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { DocumentSetSummary, SlackChannelConfig } from "@/lib/types";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
import { FetchAgentsResponse, fetchAgentsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { useSlackChannelConfigs } from "@/app/admin/bots/[bot-id]/hooks";
|
||||
import { useDocumentSets } from "@/app/admin/documents/sets/hooks";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { useStandardAnswerCategories } from "@/app/ee/admin/standard-answer/hooks";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import type { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
|
||||
async function EditslackChannelConfigPage(props: {
|
||||
params: Promise<{ id: number }>;
|
||||
}) {
|
||||
const params = await props.params;
|
||||
const tasks = [
|
||||
fetchSS("/manage/admin/slack-app/channel"),
|
||||
fetchSS("/manage/document-set"),
|
||||
fetchAgentsSS(),
|
||||
];
|
||||
function EditSlackChannelConfigContent({ id }: { id: string }) {
|
||||
const isPaidEnterprise = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
const [
|
||||
slackChannelsResponse,
|
||||
documentSetsResponse,
|
||||
[assistants, agentsFetchError],
|
||||
] = (await Promise.all(tasks)) as [Response, Response, FetchAgentsResponse];
|
||||
const {
|
||||
data: slackChannelConfigs,
|
||||
isLoading: isChannelsLoading,
|
||||
error: channelsError,
|
||||
} = useSlackChannelConfigs();
|
||||
|
||||
const eeStandardAnswerCategoryResponse =
|
||||
await getStandardAnswerCategoriesIfEE();
|
||||
const {
|
||||
data: documentSets,
|
||||
isLoading: isDocSetsLoading,
|
||||
error: docSetsError,
|
||||
} = useDocumentSets();
|
||||
|
||||
if (!slackChannelsResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch Slack Channels - ${await slackChannelsResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
const allslackChannelConfigs =
|
||||
(await slackChannelsResponse.json()) as SlackChannelConfig[];
|
||||
const {
|
||||
agents,
|
||||
isLoading: isAgentsLoading,
|
||||
error: agentsError,
|
||||
} = useAgents();
|
||||
|
||||
const slackChannelConfig = allslackChannelConfigs.find(
|
||||
(config) => config.id === Number(params.id)
|
||||
const {
|
||||
data: standardAnswerCategories,
|
||||
isLoading: isStdAnswerLoading,
|
||||
error: stdAnswerError,
|
||||
} = useStandardAnswerCategories();
|
||||
|
||||
const isLoading =
|
||||
isChannelsLoading ||
|
||||
isDocSetsLoading ||
|
||||
isAgentsLoading ||
|
||||
(isPaidEnterprise && isStdAnswerLoading);
|
||||
|
||||
const slackChannelConfig = slackChannelConfigs?.find(
|
||||
(config) => config.id === Number(id)
|
||||
);
|
||||
|
||||
if (!slackChannelConfig) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Did not find Slack Channel config with ID: ${params.id}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (!documentSetsResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
const response = await documentSetsResponse.json();
|
||||
const documentSets = response as DocumentSetSummary[];
|
||||
|
||||
if (agentsFetchError) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch personas - ${agentsFetchError}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
const title = slackChannelConfig?.is_default
|
||||
? "Edit Default Slack Config"
|
||||
: "Edit Slack Channel Config";
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<InstantSSRAutoRefresh />
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSlack}
|
||||
title={
|
||||
slackChannelConfig.is_default
|
||||
? "Edit Default Slack Config"
|
||||
: "Edit Slack Channel Config"
|
||||
}
|
||||
title={title}
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slackChannelConfig.slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={assistants}
|
||||
standardAnswerCategoryResponse={eeStandardAnswerCategoryResponse}
|
||||
existingSlackChannelConfig={slackChannelConfig}
|
||||
/>
|
||||
{isLoading ? (
|
||||
<SimpleLoader />
|
||||
) : channelsError || !slackChannelConfigs ? (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch Slack Channels - ${
|
||||
channelsError?.message ?? "unknown error"
|
||||
}`}
|
||||
/>
|
||||
) : !slackChannelConfig ? (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Did not find Slack Channel config with ID: ${id}`}
|
||||
/>
|
||||
) : docSetsError || !documentSets ? (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${
|
||||
docSetsError?.message ?? "unknown error"
|
||||
}`}
|
||||
/>
|
||||
) : agentsError ? (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch agents - ${
|
||||
agentsError?.message ?? "unknown error"
|
||||
}`}
|
||||
/>
|
||||
) : (
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slackChannelConfig.slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={agents}
|
||||
standardAnswerCategoryResponse={
|
||||
isPaidEnterprise
|
||||
? {
|
||||
paidEnterpriseFeaturesEnabled: true,
|
||||
categories: standardAnswerCategories ?? [],
|
||||
...(stdAnswerError
|
||||
? { error: { message: String(stdAnswerError) } }
|
||||
: {}),
|
||||
}
|
||||
: { paidEnterpriseFeaturesEnabled: false }
|
||||
}
|
||||
existingSlackChannelConfig={slackChannelConfig}
|
||||
/>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
export default EditslackChannelConfigPage;
|
||||
export default function Page(props: { params: Promise<{ id: string }> }) {
|
||||
const params = use(props.params);
|
||||
|
||||
return <EditSlackChannelConfigContent id={params.id} />;
|
||||
}
|
||||
|
||||
@@ -1,53 +1,109 @@
|
||||
import { SlackChannelConfigCreationForm } from "../SlackChannelConfigCreationForm";
|
||||
import { fetchSS } from "@/lib/utilsSS";
|
||||
"use client";
|
||||
|
||||
import { use, useEffect } from "react";
|
||||
import { SlackChannelConfigCreationForm } from "@/app/admin/bots/[bot-id]/channels/SlackChannelConfigCreationForm";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { DocumentSetSummary } from "@/lib/types";
|
||||
import { fetchAgentsSS } from "@/lib/agentsSS";
|
||||
import { getStandardAnswerCategoriesIfEE } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { redirect } from "next/navigation";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
import { useDocumentSets } from "@/app/admin/documents/sets/hooks";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import { useStandardAnswerCategories } from "@/app/ee/admin/standard-answer/hooks";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import type { StandardAnswerCategoryResponse } from "@/components/standardAnswers/getStandardAnswerCategoriesIfEE";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
function NewChannelConfigContent({ slackBotId }: { slackBotId: number }) {
|
||||
const isPaidEnterprise = usePaidEnterpriseFeaturesEnabled();
|
||||
|
||||
const {
|
||||
data: documentSets,
|
||||
isLoading: isDocSetsLoading,
|
||||
error: docSetsError,
|
||||
} = useDocumentSets();
|
||||
|
||||
const {
|
||||
agents,
|
||||
isLoading: isAgentsLoading,
|
||||
error: agentsError,
|
||||
} = useAgents();
|
||||
|
||||
const {
|
||||
data: standardAnswerCategories,
|
||||
isLoading: isStdAnswerLoading,
|
||||
error: stdAnswerError,
|
||||
} = useStandardAnswerCategories();
|
||||
|
||||
if (
|
||||
isDocSetsLoading ||
|
||||
isAgentsLoading ||
|
||||
(isPaidEnterprise && isStdAnswerLoading)
|
||||
) {
|
||||
return <SimpleLoader />;
|
||||
}
|
||||
|
||||
if (docSetsError || !documentSets) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${
|
||||
docSetsError?.message ?? "unknown error"
|
||||
}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (agentsError) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch agents - ${
|
||||
agentsError?.message ?? "unknown error"
|
||||
}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
const standardAnswerCategoryResponse: StandardAnswerCategoryResponse =
|
||||
isPaidEnterprise
|
||||
? {
|
||||
paidEnterpriseFeaturesEnabled: true,
|
||||
categories: standardAnswerCategories ?? [],
|
||||
...(stdAnswerError
|
||||
? { error: { message: String(stdAnswerError) } }
|
||||
: {}),
|
||||
}
|
||||
: { paidEnterpriseFeaturesEnabled: false };
|
||||
|
||||
return (
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slackBotId}
|
||||
documentSets={documentSets}
|
||||
personas={agents}
|
||||
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Page(props: { params: Promise<{ "bot-id": string }> }) {
|
||||
const unwrappedParams = use(props.params);
|
||||
const router = useRouter();
|
||||
|
||||
async function NewChannelConfigPage(props: {
|
||||
params: Promise<{ "bot-id": string }>;
|
||||
}) {
|
||||
const unwrappedParams = await props.params;
|
||||
const slack_bot_id_raw = unwrappedParams?.["bot-id"] || null;
|
||||
const slack_bot_id = slack_bot_id_raw
|
||||
? parseInt(slack_bot_id_raw as string, 10)
|
||||
: null;
|
||||
|
||||
useEffect(() => {
|
||||
if (!slack_bot_id || isNaN(slack_bot_id)) {
|
||||
router.replace("/admin/bots");
|
||||
}
|
||||
}, [slack_bot_id, router]);
|
||||
|
||||
if (!slack_bot_id || isNaN(slack_bot_id)) {
|
||||
redirect("/admin/bots");
|
||||
return null;
|
||||
}
|
||||
|
||||
const [documentSetsResponse, agentsResponse, standardAnswerCategoryResponse] =
|
||||
await Promise.all([
|
||||
fetchSS("/manage/document-set") as Promise<Response>,
|
||||
fetchAgentsSS(),
|
||||
getStandardAnswerCategoriesIfEE(),
|
||||
]);
|
||||
|
||||
if (!documentSetsResponse.ok) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch document sets - ${await documentSetsResponse.text()}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
const documentSets =
|
||||
(await documentSetsResponse.json()) as DocumentSetSummary[];
|
||||
|
||||
if (agentsResponse[1]) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch agents - ${agentsResponse[1]}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
@@ -57,15 +113,8 @@ async function NewChannelConfigPage(props: {
|
||||
backButton
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<SlackChannelConfigCreationForm
|
||||
slack_bot_id={slack_bot_id}
|
||||
documentSets={documentSets}
|
||||
personas={agentsResponse[0]}
|
||||
standardAnswerCategoryResponse={standardAnswerCategoryResponse}
|
||||
/>
|
||||
<NewChannelConfigContent slackBotId={slack_bot_id} />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
export default NewChannelConfigPage;
|
||||
|
||||
@@ -1,82 +1,62 @@
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import SlackChannelConfigsTable from "./SlackChannelConfigsTable";
|
||||
import { useSlackBot, useSlackChannelConfigsByBot } from "./hooks";
|
||||
import { ExistingSlackBotForm } from "../SlackBotUpdateForm";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
|
||||
function SlackBotEditPage({
|
||||
params,
|
||||
}: {
|
||||
params: Promise<{ "bot-id": string }>;
|
||||
}) {
|
||||
// Unwrap the params promise
|
||||
const unwrappedParams = use(params);
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { SvgSlack } from "@opal/icons";
|
||||
import { getErrorMsg } from "@/lib/error";
|
||||
|
||||
function SlackBotEditContent({ botId }: { botId: string }) {
|
||||
const {
|
||||
data: slackBot,
|
||||
isLoading: isSlackBotLoading,
|
||||
error: slackBotError,
|
||||
refreshSlackBot,
|
||||
} = useSlackBot(Number(unwrappedParams["bot-id"]));
|
||||
} = useSlackBot(Number(botId));
|
||||
|
||||
const {
|
||||
data: slackChannelConfigs,
|
||||
isLoading: isSlackChannelConfigsLoading,
|
||||
error: slackChannelConfigsError,
|
||||
refreshSlackChannelConfigs,
|
||||
} = useSlackChannelConfigsByBot(Number(unwrappedParams["bot-id"]));
|
||||
} = useSlackChannelConfigsByBot(Number(botId));
|
||||
|
||||
if (isSlackBotLoading || isSlackChannelConfigsLoading) {
|
||||
return (
|
||||
<div className="flex justify-center items-center h-screen">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
);
|
||||
return <SimpleLoader />;
|
||||
}
|
||||
|
||||
if (slackBotError || !slackBot) {
|
||||
const errorMsg =
|
||||
slackBotError?.info?.message ||
|
||||
slackBotError?.info?.detail ||
|
||||
"An unknown error occurred";
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch Slack Bot ${unwrappedParams["bot-id"]}: ${errorMsg}`}
|
||||
errorMsg={`Failed to fetch Slack Bot ${botId}: ${getErrorMsg(
|
||||
slackBotError
|
||||
)}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (slackChannelConfigsError || !slackChannelConfigs) {
|
||||
const errorMsg =
|
||||
slackChannelConfigsError?.info?.message ||
|
||||
slackChannelConfigsError?.info?.detail ||
|
||||
"An unknown error occurred";
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Something went wrong :("
|
||||
errorMsg={`Failed to fetch Slack Bot ${unwrappedParams["bot-id"]}: ${errorMsg}`}
|
||||
errorMsg={`Failed to fetch Slack Bot ${botId}: ${getErrorMsg(
|
||||
slackChannelConfigsError
|
||||
)}`}
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<InstantSSRAutoRefresh />
|
||||
|
||||
<BackButton routerOverride="/admin/bots" />
|
||||
|
||||
<ExistingSlackBotForm
|
||||
existingSlackBot={slackBot}
|
||||
refreshSlackBot={refreshSlackBot}
|
||||
/>
|
||||
<Separator />
|
||||
|
||||
<div className="mt-8">
|
||||
<SlackChannelConfigsTable
|
||||
@@ -94,9 +74,19 @@ export default function Page({
|
||||
}: {
|
||||
params: Promise<{ "bot-id": string }>;
|
||||
}) {
|
||||
const unwrappedParams = use(params);
|
||||
|
||||
return (
|
||||
<>
|
||||
<SlackBotEditPage params={params} />
|
||||
</>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgSlack}
|
||||
title="Edit Slack Bot"
|
||||
backButton
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<SlackBotEditContent botId={unwrappedParams["bot-id"]} />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,12 +1,7 @@
|
||||
import BackButton from "@/refresh-components/buttons/BackButton";
|
||||
"use client";
|
||||
|
||||
import { NewSlackBotForm } from "../SlackBotCreationForm";
|
||||
|
||||
export default async function NewSlackBotPage() {
|
||||
return (
|
||||
<>
|
||||
<BackButton routerOverride="/admin/bots" />
|
||||
|
||||
<NewSlackBotForm />
|
||||
</>
|
||||
);
|
||||
export default function Page() {
|
||||
return <NewSlackBotForm />;
|
||||
}
|
||||
|
||||
@@ -1,11 +1,14 @@
|
||||
import React, { JSX, memo } from "react";
|
||||
import {
|
||||
ChatPacket,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
ImageGenerationToolPacket,
|
||||
Packet,
|
||||
PacketType,
|
||||
ReasoningPacket,
|
||||
SearchToolStart,
|
||||
StopReason,
|
||||
ToolCallArgumentDelta,
|
||||
} from "../../services/streamingModels";
|
||||
import {
|
||||
FullChatState,
|
||||
@@ -26,7 +29,6 @@ import { DeepResearchPlanRenderer } from "./timeline/renderers/deepresearch/Deep
|
||||
import { ResearchAgentRenderer } from "./timeline/renderers/deepresearch/ResearchAgentRenderer";
|
||||
import { WebSearchToolRenderer } from "./timeline/renderers/search/WebSearchToolRenderer";
|
||||
import { InternalSearchToolRenderer } from "./timeline/renderers/search/InternalSearchToolRenderer";
|
||||
import { SearchToolStart } from "../../services/streamingModels";
|
||||
|
||||
// Different types of chat packets using discriminated unions
|
||||
interface GroupedPackets {
|
||||
@@ -56,7 +58,12 @@ function isImageToolPacket(packet: Packet) {
|
||||
}
|
||||
|
||||
function isPythonToolPacket(packet: Packet) {
|
||||
return packet.obj.type === PacketType.PYTHON_TOOL_START;
|
||||
return (
|
||||
packet.obj.type === PacketType.PYTHON_TOOL_START ||
|
||||
(packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
}
|
||||
|
||||
function isCustomToolPacket(packet: Packet) {
|
||||
|
||||
@@ -10,6 +10,8 @@ import {
|
||||
Stop,
|
||||
ImageGenerationToolDelta,
|
||||
MessageStart,
|
||||
ToolCallArgumentDelta,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
import { CitationMap } from "@/app/app/interfaces";
|
||||
import { OnyxDocument } from "@/lib/search/interfaces";
|
||||
@@ -138,6 +140,7 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.IMAGE_GENERATION_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.TOOL_CALL_ARGUMENT_DELTA,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.FILE_READER_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
@@ -149,9 +152,16 @@ const CONTENT_PACKET_TYPES_SET = new Set<PacketType>([
|
||||
]);
|
||||
|
||||
function hasContentPackets(packets: Packet[]): boolean {
|
||||
return packets.some((packet) =>
|
||||
CONTENT_PACKET_TYPES_SET.has(packet.obj.type as PacketType)
|
||||
);
|
||||
return packets.some((packet) => {
|
||||
const type = packet.obj.type as PacketType;
|
||||
if (type === PacketType.TOOL_CALL_ARGUMENT_DELTA) {
|
||||
return (
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
);
|
||||
}
|
||||
return CONTENT_PACKET_TYPES_SET.has(type);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -1,6 +1,13 @@
|
||||
import { Packet, PacketType } from "@/app/app/services/streamingModels";
|
||||
import {
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
Packet,
|
||||
PacketType,
|
||||
ToolCallArgumentDelta,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
|
||||
// Packet types with renderers supporting collapsed streaming mode
|
||||
// Packet types with renderers supporting collapsed streaming mode.
|
||||
// TOOL_CALL_ARGUMENT_DELTA is intentionally excluded here because it requires
|
||||
// a tool_type check — it's handled separately in stepSupportsCollapsedStreaming.
|
||||
export const COLLAPSED_STREAMING_PACKET_TYPES = new Set<PacketType>([
|
||||
PacketType.SEARCH_TOOL_START,
|
||||
PacketType.FETCH_TOOL_START,
|
||||
@@ -21,7 +28,13 @@ export const isSearchToolPackets = (packets: Packet[]): boolean =>
|
||||
|
||||
// Check if packets belong to a python tool
|
||||
export const isPythonToolPackets = (packets: Packet[]): boolean =>
|
||||
packets.some((p) => p.obj.type === PacketType.PYTHON_TOOL_START);
|
||||
packets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.PYTHON_TOOL_START ||
|
||||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
|
||||
// Check if packets belong to reasoning
|
||||
export const isReasoningPackets = (packets: Packet[]): boolean =>
|
||||
@@ -29,8 +42,12 @@ export const isReasoningPackets = (packets: Packet[]): boolean =>
|
||||
|
||||
// Check if step supports collapsed streaming rendering mode
|
||||
export const stepSupportsCollapsedStreaming = (packets: Packet[]): boolean =>
|
||||
packets.some((p) =>
|
||||
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType)
|
||||
packets.some(
|
||||
(p) =>
|
||||
COLLAPSED_STREAMING_PACKET_TYPES.has(p.obj.type as PacketType) ||
|
||||
(p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON)
|
||||
);
|
||||
|
||||
// Check if packets have content worth rendering in collapsed streaming mode.
|
||||
@@ -67,7 +84,13 @@ export const stepHasCollapsedStreamingContent = (
|
||||
// Python tool renders code/output from the start packet onward
|
||||
if (
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_START) ||
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_DELTA)
|
||||
packetTypes.has(PacketType.PYTHON_TOOL_DELTA) ||
|
||||
packets.some(
|
||||
(p) =>
|
||||
p.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(p.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
)
|
||||
) {
|
||||
return true;
|
||||
}
|
||||
|
||||
@@ -4,7 +4,9 @@ import {
|
||||
PythonToolPacket,
|
||||
PythonToolStart,
|
||||
PythonToolDelta,
|
||||
ToolCallArgumentDelta,
|
||||
SectionEnd,
|
||||
CODE_INTERPRETER_TOOL_TYPES,
|
||||
} from "@/app/app/services/streamingModels";
|
||||
import {
|
||||
MessageRenderer,
|
||||
@@ -39,6 +41,18 @@ function HighlightedPythonCode({ code }: { code: string }) {
|
||||
|
||||
// Helper function to construct current Python execution state
|
||||
function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
// Accumulate streaming code from argument deltas (arrives before PythonToolStart)
|
||||
const streamingCode = packets
|
||||
.filter(
|
||||
(packet) =>
|
||||
packet.obj.type === PacketType.TOOL_CALL_ARGUMENT_DELTA &&
|
||||
(packet.obj as ToolCallArgumentDelta).tool_type ===
|
||||
CODE_INTERPRETER_TOOL_TYPES.PYTHON
|
||||
)
|
||||
.map((packet) =>
|
||||
String((packet.obj as ToolCallArgumentDelta).argument_deltas.code ?? "")
|
||||
)
|
||||
.join("");
|
||||
const pythonStart = packets.find(
|
||||
(packet) => packet.obj.type === PacketType.PYTHON_TOOL_START
|
||||
)?.obj as PythonToolStart | null;
|
||||
@@ -51,7 +65,8 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
packet.obj.type === PacketType.ERROR
|
||||
)?.obj as SectionEnd | null;
|
||||
|
||||
const code = pythonStart?.code || "";
|
||||
// Use complete code from PythonToolStart if available, else use streamed code.
|
||||
const code = pythonStart?.code || streamingCode;
|
||||
const stdout = pythonDeltas
|
||||
.map((delta) => delta?.stdout || "")
|
||||
.filter((s) => s)
|
||||
@@ -61,6 +76,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
.filter((s) => s)
|
||||
.join("");
|
||||
const fileIds = pythonDeltas.flatMap((delta) => delta?.file_ids || []);
|
||||
const isStreaming = !pythonStart && streamingCode.length > 0;
|
||||
const isExecuting = pythonStart && !pythonEnd;
|
||||
const isComplete = pythonStart && pythonEnd;
|
||||
const hasError = stderr.length > 0;
|
||||
@@ -70,6 +86,7 @@ function constructCurrentPythonState(packets: PythonToolPacket[]) {
|
||||
stdout,
|
||||
stderr,
|
||||
fileIds,
|
||||
isStreaming,
|
||||
isExecuting,
|
||||
isComplete,
|
||||
hasError,
|
||||
@@ -82,8 +99,16 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
renderType,
|
||||
children,
|
||||
}) => {
|
||||
const { code, stdout, stderr, fileIds, isExecuting, isComplete, hasError } =
|
||||
constructCurrentPythonState(packets);
|
||||
const {
|
||||
code,
|
||||
stdout,
|
||||
stderr,
|
||||
fileIds,
|
||||
isStreaming,
|
||||
isExecuting,
|
||||
isComplete,
|
||||
hasError,
|
||||
} = constructCurrentPythonState(packets);
|
||||
|
||||
useEffect(() => {
|
||||
if (isComplete) {
|
||||
@@ -92,6 +117,9 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
}, [isComplete, onComplete]);
|
||||
|
||||
const status = useMemo(() => {
|
||||
if (isStreaming) {
|
||||
return "Writing code...";
|
||||
}
|
||||
if (isExecuting) {
|
||||
return "Executing Python code...";
|
||||
}
|
||||
@@ -102,13 +130,13 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
return "Python execution completed";
|
||||
}
|
||||
return "Python execution";
|
||||
}, [isComplete, isExecuting, hasError]);
|
||||
}, [isStreaming, isComplete, isExecuting, hasError]);
|
||||
|
||||
// Shared content for all states - used by both FULL and compact modes
|
||||
const content = (
|
||||
<div className="flex flex-col mb-1 space-y-2">
|
||||
{/* Loading indicator when executing */}
|
||||
{isExecuting && (
|
||||
{/* Loading indicator when streaming or executing */}
|
||||
{(isStreaming || isExecuting) && (
|
||||
<div className="flex items-center gap-2 text-sm text-muted-foreground">
|
||||
<div className="flex gap-0.5">
|
||||
<div className="w-1 h-1 bg-current rounded-full animate-pulse"></div>
|
||||
@@ -121,7 +149,7 @@ export const PythonToolRenderer: MessageRenderer<PythonToolPacket, {}> = ({
|
||||
style={{ animationDelay: "0.2s" }}
|
||||
></div>
|
||||
</div>
|
||||
<span>Running code...</span>
|
||||
<span>{isStreaming ? "Writing code..." : "Running code..."}</span>
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
@@ -16,6 +16,7 @@ export function isToolPacket(
|
||||
PacketType.SEARCH_TOOL_DOCUMENTS_DELTA,
|
||||
PacketType.PYTHON_TOOL_START,
|
||||
PacketType.PYTHON_TOOL_DELTA,
|
||||
PacketType.TOOL_CALL_ARGUMENT_DELTA,
|
||||
PacketType.CUSTOM_TOOL_START,
|
||||
PacketType.CUSTOM_TOOL_DELTA,
|
||||
PacketType.FILE_READER_START,
|
||||
|
||||
@@ -27,6 +27,9 @@ export enum PacketType {
|
||||
FETCH_TOOL_URLS = "open_url_urls",
|
||||
FETCH_TOOL_DOCUMENTS = "open_url_documents",
|
||||
|
||||
// Tool call argument delta (streams tool args before tool executes)
|
||||
TOOL_CALL_ARGUMENT_DELTA = "tool_call_argument_delta",
|
||||
|
||||
// Custom tool packets
|
||||
CUSTOM_TOOL_START = "custom_tool_start",
|
||||
CUSTOM_TOOL_DELTA = "custom_tool_delta",
|
||||
@@ -59,6 +62,10 @@ export enum PacketType {
|
||||
INTERMEDIATE_REPORT_CITED_DOCS = "intermediate_report_cited_docs",
|
||||
}
|
||||
|
||||
export const CODE_INTERPRETER_TOOL_TYPES = {
|
||||
PYTHON: "python",
|
||||
} as const;
|
||||
|
||||
// Basic Message Packets
|
||||
export interface MessageStart extends BaseObj {
|
||||
id: string;
|
||||
@@ -149,6 +156,13 @@ export interface PythonToolDelta extends BaseObj {
|
||||
file_ids: string[];
|
||||
}
|
||||
|
||||
export interface ToolCallArgumentDelta extends BaseObj {
|
||||
type: "tool_call_argument_delta";
|
||||
tool_type: string;
|
||||
tool_id: string;
|
||||
argument_deltas: Record<string, unknown>;
|
||||
}
|
||||
|
||||
export interface FetchToolStart extends BaseObj {
|
||||
type: "open_url_start";
|
||||
}
|
||||
@@ -294,6 +308,7 @@ export type ImageGenerationToolObj =
|
||||
export type PythonToolObj =
|
||||
| PythonToolStart
|
||||
| PythonToolDelta
|
||||
| ToolCallArgumentDelta
|
||||
| SectionEnd
|
||||
| PacketError;
|
||||
export type FetchToolObj =
|
||||
|
||||
@@ -36,6 +36,7 @@ export interface Settings {
|
||||
|
||||
// User Knowledge settings
|
||||
user_knowledge_enabled?: boolean;
|
||||
user_file_max_upload_size_mb?: number | null;
|
||||
|
||||
// Connector settings
|
||||
show_extra_connectors?: boolean;
|
||||
|
||||
10
web/src/lib/error.ts
Normal file
10
web/src/lib/error.ts
Normal file
@@ -0,0 +1,10 @@
|
||||
/**
|
||||
* Extract a human-readable error message from an SWR error object.
|
||||
* SWR errors from `errorHandlingFetcher` attach `info.message` or `info.detail`.
|
||||
*/
|
||||
export function getErrorMsg(
|
||||
error: { info?: { message?: string; detail?: string } } | null | undefined,
|
||||
fallback = "An unknown error occurred"
|
||||
): string {
|
||||
return error?.info?.message || error?.info?.detail || fallback;
|
||||
}
|
||||
@@ -43,6 +43,7 @@ import { useAppRouter } from "@/hooks/appNavigation";
|
||||
import { ChatFileType } from "@/app/app/interfaces";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useProjects } from "@/lib/hooks/useProjects";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
|
||||
export type { Project, ProjectFile } from "@/app/app/projects/projectsService";
|
||||
|
||||
@@ -84,6 +85,8 @@ function buildFileKey(file: File): string {
|
||||
return `${file.size}|${namePrefix}`;
|
||||
}
|
||||
|
||||
const DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB = 50;
|
||||
|
||||
interface ProjectsContextType {
|
||||
projects: Project[];
|
||||
recentFiles: ProjectFile[];
|
||||
@@ -157,6 +160,7 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
|
||||
new Map()
|
||||
);
|
||||
const route = useAppRouter();
|
||||
const settingsContext = useContext(SettingsContext);
|
||||
|
||||
// Use SWR's mutate to refresh projects - returns the new data
|
||||
const fetchProjects = useCallback(async (): Promise<Project[]> => {
|
||||
@@ -336,16 +340,40 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
|
||||
onSuccess?: (uploaded: CategorizedFiles) => void,
|
||||
onFailure?: (failedTempIds: string[]) => void
|
||||
): Promise<ProjectFile[]> => {
|
||||
const optimisticFiles = files.map((f) =>
|
||||
const rawMax = settingsContext?.settings?.user_file_max_upload_size_mb;
|
||||
const maxUploadSizeMb =
|
||||
rawMax && rawMax > 0 ? rawMax : DEFAULT_USER_FILE_MAX_UPLOAD_SIZE_MB;
|
||||
const maxUploadSizeBytes = maxUploadSizeMb * 1024 * 1024;
|
||||
|
||||
const oversizedFiles = files.filter(
|
||||
(file) => file.size > maxUploadSizeBytes
|
||||
);
|
||||
const validFiles = files.filter(
|
||||
(file) => file.size <= maxUploadSizeBytes
|
||||
);
|
||||
|
||||
if (oversizedFiles.length > 0) {
|
||||
const skippedNames = oversizedFiles.map((file) => file.name).join(", ");
|
||||
toast.warning(
|
||||
`Skipped ${oversizedFiles.length} oversized file(s) (>${maxUploadSizeMb} MB): ${skippedNames}`
|
||||
);
|
||||
}
|
||||
|
||||
if (validFiles.length === 0) {
|
||||
onFailure?.([]);
|
||||
return [];
|
||||
}
|
||||
|
||||
const optimisticFiles = validFiles.map((f) =>
|
||||
createOptimisticFile(f, projectId)
|
||||
);
|
||||
const tempIdMap = getTempIdMap(files, optimisticFiles);
|
||||
const tempIdMap = getTempIdMap(validFiles, optimisticFiles);
|
||||
setAllRecentFiles((prev) => [...optimisticFiles, ...prev]);
|
||||
if (projectId) {
|
||||
setAllCurrentProjectFiles((prev) => [...optimisticFiles, ...prev]);
|
||||
projectToUploadFilesMapRef.current.set(projectId, optimisticFiles);
|
||||
}
|
||||
svcUploadFiles(files, projectId, tempIdMap)
|
||||
svcUploadFiles(validFiles, projectId, tempIdMap)
|
||||
.then((uploaded) => {
|
||||
const uploadedFiles = uploaded.user_files || [];
|
||||
const tempIdToUploadedFileMap = new Map(
|
||||
@@ -445,6 +473,7 @@ export function ProjectsProvider({ children }: ProjectsProviderProps) {
|
||||
refreshCurrentProjectDetails,
|
||||
refreshRecentFiles,
|
||||
removeOptimisticFilesByTempIds,
|
||||
settingsContext,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
166
web/src/providers/__tests__/ProjectsContext.test.tsx
Normal file
166
web/src/providers/__tests__/ProjectsContext.test.tsx
Normal file
@@ -0,0 +1,166 @@
|
||||
import React, { PropsWithChildren } from "react";
|
||||
import { act, renderHook } from "@testing-library/react";
|
||||
import {
|
||||
ProjectsProvider,
|
||||
useProjectsContext,
|
||||
} from "@/providers/ProjectsContext";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { CombinedSettings } from "@/interfaces/settings";
|
||||
import type { ProjectFile } from "@/app/app/projects/projectsService";
|
||||
|
||||
const mockUploadFiles = jest.fn();
|
||||
const mockGetRecentFiles = jest.fn();
|
||||
const mockToastWarning = jest.fn();
|
||||
|
||||
jest.mock("next/navigation", () => ({
|
||||
useSearchParams: () => ({
|
||||
get: () => null,
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock("@/hooks/appNavigation", () => ({
|
||||
useAppRouter: () => jest.fn(),
|
||||
}));
|
||||
|
||||
jest.mock("@/lib/hooks/useProjects", () => ({
|
||||
useProjects: () => ({
|
||||
projects: [],
|
||||
refreshProjects: jest.fn().mockResolvedValue([]),
|
||||
}),
|
||||
}));
|
||||
|
||||
jest.mock("@/hooks/useToast", () => ({
|
||||
toast: {
|
||||
warning: (...args: unknown[]) => mockToastWarning(...args),
|
||||
error: jest.fn(),
|
||||
success: jest.fn(),
|
||||
},
|
||||
}));
|
||||
|
||||
jest.mock("@/app/app/projects/projectsService", () => {
|
||||
const actual = jest.requireActual("@/app/app/projects/projectsService");
|
||||
return {
|
||||
...actual,
|
||||
fetchProjects: jest.fn().mockResolvedValue([]),
|
||||
createProject: jest.fn(),
|
||||
uploadFiles: (...args: unknown[]) => mockUploadFiles(...args),
|
||||
getRecentFiles: (...args: unknown[]) => mockGetRecentFiles(...args),
|
||||
getFilesInProject: jest.fn().mockResolvedValue([]),
|
||||
getProject: jest.fn(),
|
||||
getProjectInstructions: jest.fn(),
|
||||
upsertProjectInstructions: jest.fn(),
|
||||
getProjectDetails: jest.fn(),
|
||||
renameProject: jest.fn(),
|
||||
deleteProject: jest.fn(),
|
||||
deleteUserFile: jest.fn(),
|
||||
getUserFileStatuses: jest.fn().mockResolvedValue([]),
|
||||
unlinkFileFromProject: jest.fn(),
|
||||
linkFileToProject: jest.fn(),
|
||||
};
|
||||
});
|
||||
|
||||
const settingsValue: CombinedSettings = {
|
||||
settings: {
|
||||
user_file_max_upload_size_mb: 1,
|
||||
} as CombinedSettings["settings"],
|
||||
enterpriseSettings: null,
|
||||
customAnalyticsScript: null,
|
||||
webVersion: null,
|
||||
webDomain: null,
|
||||
isSearchModeAvailable: true,
|
||||
};
|
||||
|
||||
const wrapper = ({ children }: PropsWithChildren) => (
|
||||
<SettingsContext.Provider value={settingsValue}>
|
||||
<ProjectsProvider>{children}</ProjectsProvider>
|
||||
</SettingsContext.Provider>
|
||||
);
|
||||
|
||||
describe("ProjectsContext beginUpload size precheck", () => {
|
||||
beforeEach(() => {
|
||||
mockUploadFiles.mockReset();
|
||||
mockGetRecentFiles.mockReset();
|
||||
mockToastWarning.mockReset();
|
||||
|
||||
mockUploadFiles.mockResolvedValue({
|
||||
user_files: [],
|
||||
rejected_files: [],
|
||||
});
|
||||
mockGetRecentFiles.mockResolvedValue([]);
|
||||
});
|
||||
|
||||
it("only sends valid files to the upload API when oversized files are present", async () => {
|
||||
const { result } = renderHook(() => useProjectsContext(), { wrapper });
|
||||
|
||||
const valid = new File(["small"], "small.txt", { type: "text/plain" });
|
||||
const oversized = new File([new Uint8Array(2 * 1024 * 1024)], "big.txt", {
|
||||
type: "text/plain",
|
||||
});
|
||||
|
||||
let optimisticFiles: ProjectFile[] = [];
|
||||
await act(async () => {
|
||||
optimisticFiles = await result.current.beginUpload(
|
||||
[valid, oversized],
|
||||
null
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockUploadFiles).toHaveBeenCalledTimes(1);
|
||||
const [uploadedFiles] = mockUploadFiles.mock.calls[0];
|
||||
expect((uploadedFiles as File[]).map((f) => f.name)).toEqual(["small.txt"]);
|
||||
expect(optimisticFiles.map((f) => f.name)).toEqual(["small.txt"]);
|
||||
expect(mockToastWarning).toHaveBeenCalledTimes(1);
|
||||
});
|
||||
|
||||
it("uploads all files when none are oversized", async () => {
|
||||
const { result } = renderHook(() => useProjectsContext(), { wrapper });
|
||||
|
||||
const first = new File(["small"], "first.txt", { type: "text/plain" });
|
||||
const second = new File(["small"], "second.txt", { type: "text/plain" });
|
||||
|
||||
let optimisticFiles: ProjectFile[] = [];
|
||||
await act(async () => {
|
||||
optimisticFiles = await result.current.beginUpload([first, second], null);
|
||||
});
|
||||
|
||||
expect(mockUploadFiles).toHaveBeenCalledTimes(1);
|
||||
const [uploadedFiles] = mockUploadFiles.mock.calls[0];
|
||||
expect((uploadedFiles as File[]).map((f) => f.name)).toEqual([
|
||||
"first.txt",
|
||||
"second.txt",
|
||||
]);
|
||||
expect(mockToastWarning).not.toHaveBeenCalled();
|
||||
expect(optimisticFiles.map((f) => f.name)).toEqual([
|
||||
"first.txt",
|
||||
"second.txt",
|
||||
]);
|
||||
});
|
||||
|
||||
it("does not call upload API when all files are oversized", async () => {
|
||||
const { result } = renderHook(() => useProjectsContext(), { wrapper });
|
||||
|
||||
const oversized = new File(
|
||||
[new Uint8Array(2 * 1024 * 1024)],
|
||||
"too-big.txt",
|
||||
{ type: "text/plain" }
|
||||
);
|
||||
const onSuccess = jest.fn();
|
||||
const onFailure = jest.fn();
|
||||
|
||||
let optimisticFiles: ProjectFile[] = [];
|
||||
await act(async () => {
|
||||
optimisticFiles = await result.current.beginUpload(
|
||||
[oversized],
|
||||
null,
|
||||
onSuccess,
|
||||
onFailure
|
||||
);
|
||||
});
|
||||
|
||||
expect(mockUploadFiles).not.toHaveBeenCalled();
|
||||
expect(optimisticFiles).toEqual([]);
|
||||
expect(mockToastWarning).toHaveBeenCalledTimes(1);
|
||||
expect(onSuccess).not.toHaveBeenCalled();
|
||||
expect(onFailure).toHaveBeenCalledWith([]);
|
||||
});
|
||||
});
|
||||
@@ -21,10 +21,47 @@ type DefaultModelInfo = {
|
||||
model_name: string;
|
||||
} | null;
|
||||
|
||||
type ProviderModelConfig = {
|
||||
name: string;
|
||||
is_visible: boolean;
|
||||
};
|
||||
|
||||
function uniqueName(prefix: string): string {
|
||||
return `${prefix}-${Date.now()}-${Math.random().toString(36).slice(2, 8)}`;
|
||||
}
|
||||
|
||||
function normalizeAlphaNum(input: string): string {
|
||||
return input.toLowerCase().replace(/[^a-z0-9]/g, "");
|
||||
}
|
||||
|
||||
function modelTokenVariants(modelName: string): string[][] {
|
||||
return modelName
|
||||
.toLowerCase()
|
||||
.split(/[^a-z0-9]+/)
|
||||
.filter((token) => token.length > 0)
|
||||
.map((token) => {
|
||||
// Display names may shorten long numeric segments to suffixes.
|
||||
if (/^\d+$/.test(token) && token.length > 5) {
|
||||
return [token, token.slice(-5)];
|
||||
}
|
||||
return [token];
|
||||
});
|
||||
}
|
||||
|
||||
function textMatchesModel(modelName: string, candidateText: string): boolean {
|
||||
const normalizedCandidate = normalizeAlphaNum(candidateText);
|
||||
if (!normalizedCandidate) {
|
||||
return false;
|
||||
}
|
||||
|
||||
const tokenVariants = modelTokenVariants(modelName);
|
||||
return tokenVariants.every((variants) =>
|
||||
variants.some((variant) =>
|
||||
normalizedCandidate.includes(normalizeAlphaNum(variant))
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
async function getAdminLLMProviderResponse(page: Page) {
|
||||
const response = await page.request.get(`${BASE_URL}/api/admin/llm/provider`);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
@@ -50,6 +87,18 @@ async function createPublicProvider(
|
||||
providerName: string,
|
||||
modelName: string = "gpt-4o"
|
||||
): Promise<number> {
|
||||
return createPublicProviderWithModels(page, providerName, [
|
||||
{ name: modelName, is_visible: true },
|
||||
]);
|
||||
}
|
||||
|
||||
async function createPublicProviderWithModels(
|
||||
page: Page,
|
||||
providerName: string,
|
||||
modelConfigurations: ProviderModelConfig[]
|
||||
): Promise<number> {
|
||||
expect(modelConfigurations.length).toBeGreaterThan(0);
|
||||
|
||||
const response = await page.request.put(
|
||||
`${BASE_URL}/api/admin/llm/provider?is_creation=true`,
|
||||
{
|
||||
@@ -60,7 +109,7 @@ async function createPublicProvider(
|
||||
is_public: true,
|
||||
groups: [],
|
||||
personas: [],
|
||||
model_configurations: [{ name: modelName, is_visible: true }],
|
||||
model_configurations: modelConfigurations,
|
||||
},
|
||||
}
|
||||
);
|
||||
@@ -69,6 +118,86 @@ async function createPublicProvider(
|
||||
return data.id;
|
||||
}
|
||||
|
||||
async function navigateToAdminLlmPageFromChat(page: Page): Promise<void> {
|
||||
await page.goto(LLM_SETUP_URL);
|
||||
await page.waitForURL("**/admin/configuration/llm**");
|
||||
await expect(page.getByLabel("admin-page-title")).toHaveText(
|
||||
/^Language Models/
|
||||
);
|
||||
}
|
||||
|
||||
async function exitAdminToChat(page: Page): Promise<void> {
|
||||
await page.goto("/app");
|
||||
await page.waitForURL("**/app**");
|
||||
await page
|
||||
.locator("#onyx-chat-input-textarea")
|
||||
.waitFor({ state: "visible", timeout: 15000 });
|
||||
}
|
||||
|
||||
async function isModelVisibleInChatProviders(
|
||||
page: Page,
|
||||
modelName: string
|
||||
): Promise<boolean> {
|
||||
const response = await page.request.get(`${BASE_URL}/api/llm/provider`);
|
||||
expect(response.ok()).toBeTruthy();
|
||||
|
||||
const data = (await response.json()) as {
|
||||
providers: {
|
||||
model_configurations: { name: string; is_visible: boolean }[];
|
||||
}[];
|
||||
};
|
||||
|
||||
return data.providers.some((provider) =>
|
||||
provider.model_configurations.some(
|
||||
(model) => model.name === modelName && model.is_visible
|
||||
)
|
||||
);
|
||||
}
|
||||
|
||||
async function expectModelVisibilityInChatProviders(
|
||||
page: Page,
|
||||
modelName: string,
|
||||
expectedVisible: boolean
|
||||
): Promise<void> {
|
||||
await expect
|
||||
.poll(() => isModelVisibleInChatProviders(page, modelName), {
|
||||
timeout: 30000,
|
||||
})
|
||||
.toBe(expectedVisible);
|
||||
}
|
||||
|
||||
async function getModelCountInChatSelector(
|
||||
page: Page,
|
||||
modelName: string
|
||||
): Promise<number> {
|
||||
const dialog = page.locator('[role="dialog"]').first();
|
||||
|
||||
// When used in expect.poll retries, a previous attempt may leave the
|
||||
// popover open. Ensure a clean state before toggling it.
|
||||
if (await dialog.isVisible()) {
|
||||
await page.keyboard.press("Escape");
|
||||
await dialog.waitFor({ state: "hidden", timeout: 5000 });
|
||||
}
|
||||
|
||||
await page.getByTestId("AppInputBar/llm-popover-trigger").click();
|
||||
await dialog.waitFor({ state: "visible", timeout: 10000 });
|
||||
|
||||
await dialog.getByPlaceholder("Search models...").fill(modelName);
|
||||
const optionButtons = dialog.getByRole("button");
|
||||
const optionTexts = await optionButtons.allTextContents();
|
||||
const uniqueOptionTexts = Array.from(
|
||||
new Set(optionTexts.map((text) => text.trim()))
|
||||
);
|
||||
const count = uniqueOptionTexts.filter((text) =>
|
||||
textMatchesModel(modelName, text)
|
||||
).length;
|
||||
|
||||
await page.keyboard.press("Escape");
|
||||
await dialog.waitFor({ state: "hidden", timeout: 10000 });
|
||||
|
||||
return count;
|
||||
}
|
||||
|
||||
async function getProviderByName(
|
||||
page: Page,
|
||||
providerName: string
|
||||
@@ -272,4 +401,132 @@ test.describe("LLM Provider Setup @exclusive", () => {
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
test("adding a hidden model on an existing provider shows it in chat after one save", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.route("**/api/admin/llm/test", async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ success: true }),
|
||||
});
|
||||
});
|
||||
|
||||
const providerName = uniqueName("PW Provider Add Model");
|
||||
const ts = Date.now();
|
||||
const alwaysVisibleModel = `pw-visible-${ts}-base`;
|
||||
const modelToEnable = `pw-hidden-${ts}-to-enable`;
|
||||
|
||||
const providerId = await createPublicProviderWithModels(
|
||||
page,
|
||||
providerName,
|
||||
[
|
||||
{ name: alwaysVisibleModel, is_visible: true },
|
||||
{ name: modelToEnable, is_visible: false },
|
||||
]
|
||||
);
|
||||
providersToCleanup.push(providerId);
|
||||
await expectModelVisibilityInChatProviders(page, modelToEnable, false);
|
||||
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await page
|
||||
.locator("#onyx-chat-input-textarea")
|
||||
.waitFor({ state: "visible", timeout: 15000 });
|
||||
|
||||
await expect
|
||||
.poll(() => getModelCountInChatSelector(page, modelToEnable), {
|
||||
timeout: 15000,
|
||||
})
|
||||
.toBe(0);
|
||||
|
||||
await navigateToAdminLlmPageFromChat(page);
|
||||
|
||||
const editModal = await openProviderEditModal(page, providerName);
|
||||
await editModal.getByText(modelToEnable, { exact: true }).click();
|
||||
|
||||
const updateButton = editModal.getByRole("button", { name: "Update" });
|
||||
const providerUpdateResponsePromise = page.waitForResponse(
|
||||
(response) =>
|
||||
response.url().includes("/api/admin/llm/provider") &&
|
||||
response.request().method() === "PUT"
|
||||
);
|
||||
await expect(updateButton).toBeEnabled({ timeout: 10000 });
|
||||
await updateButton.click();
|
||||
await providerUpdateResponsePromise;
|
||||
await expect(editModal).not.toBeVisible({ timeout: 30000 });
|
||||
await expectModelVisibilityInChatProviders(page, modelToEnable, true);
|
||||
|
||||
await exitAdminToChat(page);
|
||||
await expect
|
||||
.poll(() => getModelCountInChatSelector(page, modelToEnable), {
|
||||
timeout: 15000,
|
||||
})
|
||||
.toBe(1);
|
||||
});
|
||||
|
||||
test("removing a visible model on an existing provider hides it in chat after one save", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.route("**/api/admin/llm/test", async (route) => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ success: true }),
|
||||
});
|
||||
});
|
||||
|
||||
const providerName = uniqueName("PW Provider Remove Model");
|
||||
const ts = Date.now();
|
||||
const alwaysVisibleModel = `pw-visible-${ts}-base`;
|
||||
const modelToDisable = `pw-visible-${ts}-to-disable`;
|
||||
|
||||
const providerId = await createPublicProviderWithModels(
|
||||
page,
|
||||
providerName,
|
||||
[
|
||||
{ name: alwaysVisibleModel, is_visible: true },
|
||||
{ name: modelToDisable, is_visible: true },
|
||||
]
|
||||
);
|
||||
providersToCleanup.push(providerId);
|
||||
await expectModelVisibilityInChatProviders(page, modelToDisable, true);
|
||||
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await page
|
||||
.locator("#onyx-chat-input-textarea")
|
||||
.waitFor({ state: "visible", timeout: 15000 });
|
||||
|
||||
await expect
|
||||
.poll(() => getModelCountInChatSelector(page, modelToDisable), {
|
||||
timeout: 15000,
|
||||
})
|
||||
.toBe(1);
|
||||
|
||||
await navigateToAdminLlmPageFromChat(page);
|
||||
|
||||
const editModal = await openProviderEditModal(page, providerName);
|
||||
await editModal.getByText(modelToDisable, { exact: true }).click();
|
||||
|
||||
const updateButton = editModal.getByRole("button", { name: "Update" });
|
||||
const providerUpdateResponsePromise = page.waitForResponse(
|
||||
(response) =>
|
||||
response.url().includes("/api/admin/llm/provider") &&
|
||||
response.request().method() === "PUT"
|
||||
);
|
||||
await expect(updateButton).toBeEnabled({ timeout: 10000 });
|
||||
await updateButton.click();
|
||||
await providerUpdateResponsePromise;
|
||||
await expect(editModal).not.toBeVisible({ timeout: 30000 });
|
||||
await expectModelVisibilityInChatProviders(page, modelToDisable, false);
|
||||
|
||||
await exitAdminToChat(page);
|
||||
await expect
|
||||
.poll(() => getModelCountInChatSelector(page, modelToDisable), {
|
||||
timeout: 15000,
|
||||
})
|
||||
.toBe(0);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user