mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-02 22:25:47 +00:00
Compare commits
2 Commits
ci_unhealt
...
content-re
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8f5d7e271a | ||
|
|
bb6e20614d |
@@ -54,7 +54,6 @@ runs:
|
||||
shell: bash
|
||||
env:
|
||||
RUNS_ON_ECR_CACHE: ${{ inputs.runs-on-ecr-cache }}
|
||||
INTEGRATION_REPOSITORY: ${{ inputs.runs-on-ecr-cache }}
|
||||
TAG: nightly-llm-it-${{ inputs.run-id }}
|
||||
CACHE_SUFFIX: ${{ steps.format-branch.outputs.cache-suffix }}
|
||||
HEAD_SHA: ${{ inputs.github-sha }}
|
||||
|
||||
@@ -160,7 +160,7 @@ jobs:
|
||||
cd deployment/docker_compose
|
||||
|
||||
# Get list of running containers
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml ps -q)
|
||||
containers=$(docker compose -f docker-compose.yml -f docker-compose.dev.yml -f docker-compose.opensearch.yml ps -q)
|
||||
|
||||
# Collect logs from each container
|
||||
for container in $containers; do
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""Add INDEXING to UserFileStatus
|
||||
|
||||
Revision ID: 4a1e4b1c89d2
|
||||
Revises: 6b3b4083c5aa
|
||||
Create Date: 2026-02-28 00:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
|
||||
revision = "4a1e4b1c89d2"
|
||||
down_revision = "6b3b4083c5aa"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
TABLE = "user_file"
|
||||
COLUMN = "status"
|
||||
CONSTRAINT_NAME = "ck_user_file_status"
|
||||
|
||||
OLD_VALUES = ("PROCESSING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
NEW_VALUES = ("PROCESSING", "INDEXING", "COMPLETED", "FAILED", "CANCELED", "DELETING")
|
||||
|
||||
|
||||
def _drop_status_check_constraint() -> None:
|
||||
"""Drop the existing CHECK constraint on user_file.status.
|
||||
|
||||
The constraint name is auto-generated by SQLAlchemy and unknown,
|
||||
so we look it up via the inspector.
|
||||
"""
|
||||
inspector = sa.inspect(op.get_bind())
|
||||
for constraint in inspector.get_check_constraints(TABLE):
|
||||
if COLUMN in constraint.get("sqltext", ""):
|
||||
constraint_name = constraint["name"]
|
||||
if constraint_name is not None:
|
||||
op.drop_constraint(constraint_name, TABLE, type_="check")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
_drop_status_check_constraint()
|
||||
in_clause = ", ".join(f"'{v}'" for v in NEW_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.execute(
|
||||
f"UPDATE {TABLE} SET {COLUMN} = 'PROCESSING' WHERE {COLUMN} = 'INDEXING'"
|
||||
)
|
||||
op.drop_constraint(CONSTRAINT_NAME, TABLE, type_="check")
|
||||
in_clause = ", ".join(f"'{v}'" for v in OLD_VALUES)
|
||||
op.create_check_constraint(CONSTRAINT_NAME, TABLE, f"{COLUMN} IN ({in_clause})")
|
||||
@@ -1,112 +0,0 @@
|
||||
"""persona cleanup and featured
|
||||
|
||||
Revision ID: 6b3b4083c5aa
|
||||
Revises: 57122d037335
|
||||
Create Date: 2026-02-26 12:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "6b3b4083c5aa"
|
||||
down_revision = "57122d037335"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add featured column with nullable=True first
|
||||
op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))
|
||||
|
||||
# Migrate data from is_default_persona to featured
|
||||
op.execute("UPDATE persona SET featured = is_default_persona")
|
||||
|
||||
# Make featured non-nullable with default=False
|
||||
op.alter_column(
|
||||
"persona",
|
||||
"featured",
|
||||
existing_type=sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
)
|
||||
|
||||
# Drop is_default_persona column
|
||||
op.drop_column("persona", "is_default_persona")
|
||||
|
||||
# Drop unused columns
|
||||
op.drop_column("persona", "num_chunks")
|
||||
op.drop_column("persona", "chunks_above")
|
||||
op.drop_column("persona", "chunks_below")
|
||||
op.drop_column("persona", "llm_relevance_filter")
|
||||
op.drop_column("persona", "llm_filter_extraction")
|
||||
op.drop_column("persona", "recency_bias")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Add back recency_bias column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"recency_bias",
|
||||
sa.VARCHAR(),
|
||||
nullable=False,
|
||||
server_default="base_decay",
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_filter_extraction column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_filter_extraction",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back llm_relevance_filter column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"llm_relevance_filter",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Add back chunks_below column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back chunks_above column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
|
||||
)
|
||||
|
||||
# Add back num_chunks column
|
||||
op.add_column("persona", sa.Column("num_chunks", sa.Float(), nullable=True))
|
||||
|
||||
# Add back is_default_persona column
|
||||
op.add_column(
|
||||
"persona",
|
||||
sa.Column(
|
||||
"is_default_persona",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# Migrate data from featured to is_default_persona
|
||||
op.execute("UPDATE persona SET is_default_persona = featured")
|
||||
|
||||
# Drop featured column
|
||||
op.drop_column("persona", "featured")
|
||||
@@ -18,6 +18,7 @@ from ee.onyx.server.enterprise_settings.store import (
|
||||
store_settings as store_ee_settings,
|
||||
)
|
||||
from ee.onyx.server.enterprise_settings.store import upload_logo
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -160,6 +161,12 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
user=None, # Seeding is done as admin
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=(
|
||||
persona.num_chunks if persona.num_chunks is not None else 0.0
|
||||
),
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
@@ -171,7 +178,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
|
||||
system_prompt=persona.system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
datetime_aware=persona.datetime_aware,
|
||||
featured=persona.featured,
|
||||
commit=False,
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
result = await db_session.execute(
|
||||
select(Persona.id)
|
||||
.where(
|
||||
Persona.featured.is_(True),
|
||||
Persona.is_default_persona.is_(True),
|
||||
Persona.is_public.is_(True),
|
||||
Persona.is_visible.is_(True),
|
||||
Persona.deleted.is_(False),
|
||||
|
||||
@@ -241,7 +241,8 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
|
||||
"check-for-index-attempt-cleanup",
|
||||
"check-for-doc-permissions-sync",
|
||||
"check-for-external-group-sync",
|
||||
"migrate-chunks-from-vespa-to-opensearch",
|
||||
"check-for-documents-for-opensearch-migration",
|
||||
"migrate-documents-from-vespa-to-opensearch",
|
||||
}
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
|
||||
@@ -414,31 +414,34 @@ def _process_user_file_with_indexing(
|
||||
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
|
||||
|
||||
|
||||
def process_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for processing a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips all Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"process_user_file_impl - Starting id={user_file_id}")
|
||||
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
|
||||
start = time.monotonic()
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
# Clear the "queued" guard set by the beat generator so that the next beat
|
||||
# cycle can re-enqueue this file if it is still in PROCESSING state after
|
||||
# this task completes or fails.
|
||||
redis_client.delete(_user_file_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
documents: list[Document] = []
|
||||
try:
|
||||
@@ -446,18 +449,15 @@ def process_user_file_impl(
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not uf:
|
||||
task_logger.warning(
|
||||
f"process_user_file_impl - UserFile not found id={user_file_id}"
|
||||
f"process_single_user_file - UserFile not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
if uf.status not in (
|
||||
UserFileStatus.PROCESSING,
|
||||
UserFileStatus.INDEXING,
|
||||
):
|
||||
if uf.status != UserFileStatus.PROCESSING:
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
|
||||
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
connector = LocalFileConnector(
|
||||
file_locations=[uf.file_id],
|
||||
@@ -471,6 +471,7 @@ def process_user_file_impl(
|
||||
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
|
||||
)
|
||||
|
||||
# update the document id to userfile id in the documents
|
||||
for document in documents:
|
||||
document.id = str(user_file_id)
|
||||
document.source = DocumentSource.USER_FILE
|
||||
@@ -492,8 +493,9 @@ def process_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
# don't update the status if the user file is being deleted
|
||||
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if (
|
||||
current_user_file
|
||||
@@ -502,42 +504,33 @@ def process_user_file_impl(
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
return
|
||||
return None
|
||||
|
||||
elapsed = time.monotonic() - start
|
||||
task_logger.info(
|
||||
f"process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
|
||||
)
|
||||
return None
|
||||
except Exception as e:
|
||||
# Attempt to mark the file as failed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
uf = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if uf:
|
||||
# don't update the status if the user file is being deleted
|
||||
if uf.status != UserFileStatus.DELETING:
|
||||
uf.status = UserFileStatus.FAILED
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
|
||||
task_logger.exception(
|
||||
f"process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
|
||||
soft_time_limit=300,
|
||||
@@ -588,38 +581,36 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def delete_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for deleting a single user file.
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock (Celery path).
|
||||
When redis_locking=False, skips Redis operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"delete_user_file_impl - Starting id={user_file_id}")
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
"""Process a single user file delete."""
|
||||
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
_user_file_delete_lock_key(user_file_id),
|
||||
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
|
||||
)
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
|
||||
return None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"delete_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_delete - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# 1) Delete vector DB chunks (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -657,6 +648,7 @@ def delete_user_file_impl(
|
||||
chunk_count=chunk_count,
|
||||
)
|
||||
|
||||
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
|
||||
file_store = get_default_file_store()
|
||||
try:
|
||||
file_store.delete_file(user_file.file_id)
|
||||
@@ -664,33 +656,26 @@ def delete_user_file_impl(
|
||||
user_file_id_to_plaintext_file_name(user_file.id)
|
||||
)
|
||||
except Exception as e:
|
||||
# This block executed only if the file is not found in the filestore
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
|
||||
)
|
||||
|
||||
# 3) Finally, delete the UserFile row
|
||||
db_session.delete(user_file)
|
||||
db_session.commit()
|
||||
task_logger.info(f"delete_user_file_impl - Completed id={user_file_id}")
|
||||
task_logger.info(
|
||||
f"process_single_user_file_delete - Completed id={user_file_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_delete(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
delete_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@shared_task(
|
||||
@@ -762,30 +747,32 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
|
||||
return None
|
||||
|
||||
|
||||
def project_sync_user_file_impl(
|
||||
*, user_file_id: str, tenant_id: str, redis_locking: bool
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
"""Core implementation for syncing a user file's project/persona metadata.
|
||||
"""Process a single user file project sync."""
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Starting id={user_file_id}"
|
||||
)
|
||||
|
||||
When redis_locking=True, acquires a per-file Redis lock and clears the
|
||||
queued-key guard (Celery path). When redis_locking=False, skips Redis
|
||||
operations (BackgroundTask path).
|
||||
"""
|
||||
task_logger.info(f"project_sync_user_file_impl - Starting id={user_file_id}")
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
|
||||
file_lock: RedisLock | None = None
|
||||
if redis_locking:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
|
||||
file_lock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
file_lock: RedisLock = redis_client.lock(
|
||||
user_file_project_sync_lock_key(user_file_id),
|
||||
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
|
||||
)
|
||||
|
||||
if not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
if file_lock is not None and not file_lock.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -796,10 +783,11 @@ def project_sync_user_file_impl(
|
||||
).scalar_one_or_none()
|
||||
if not user_file:
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file not found id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Sync project metadata to vector DB (skip when disabled)
|
||||
if not DISABLE_VECTOR_DB:
|
||||
if MANAGED_VESPA:
|
||||
httpx_init_vespa_pool(
|
||||
@@ -834,7 +822,7 @@ def project_sync_user_file_impl(
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"project_sync_user_file_impl - User file id={user_file_id}"
|
||||
f"process_single_user_file_project_sync - User file id={user_file_id}"
|
||||
)
|
||||
|
||||
user_file.needs_project_sync = False
|
||||
@@ -847,21 +835,11 @@ def project_sync_user_file_impl(
|
||||
|
||||
except Exception as e:
|
||||
task_logger.exception(
|
||||
f"project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
|
||||
)
|
||||
return None
|
||||
finally:
|
||||
if file_lock is not None and file_lock.owned():
|
||||
if file_lock.owned():
|
||||
file_lock.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
|
||||
bind=True,
|
||||
ignore_result=True,
|
||||
)
|
||||
def process_single_user_file_project_sync(
|
||||
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
|
||||
) -> None:
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -1,284 +0,0 @@
|
||||
"""Periodic poller for NO_VECTOR_DB deployments.
|
||||
|
||||
Replaces Celery Beat and background workers with a lightweight daemon thread
|
||||
that runs from the API server process. Two responsibilities:
|
||||
|
||||
1. Recovery polling (every 30 s): re-processes user files stuck in
|
||||
PROCESSING / DELETING / needs_sync states via the drain loops defined
|
||||
in ``task_utils.py``.
|
||||
|
||||
2. Periodic task execution (configurable intervals): runs LLM model updates
|
||||
and scheduled evals at their configured cadences, with Postgres advisory
|
||||
lock deduplication across multiple API server instances.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
RECOVERY_INTERVAL_SECONDS = 30
|
||||
PERIODIC_TASK_LOCK_BASE = 20_000
|
||||
PERIODIC_TASK_KV_PREFIX = "periodic_poller:last_claimed:"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task definitions
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
@dataclass
|
||||
class _PeriodicTaskDef:
|
||||
name: str
|
||||
interval_seconds: float
|
||||
lock_id: int
|
||||
run_fn: Callable[[], None]
|
||||
last_run_at: float = field(default=0.0)
|
||||
|
||||
|
||||
def _run_auto_llm_update() -> None:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
|
||||
if not AUTO_LLM_CONFIG_URL:
|
||||
return
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.llm.well_known_providers.auto_update_service import (
|
||||
sync_llm_models_from_github,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
sync_llm_models_from_github(db_session)
|
||||
|
||||
|
||||
def _run_scheduled_eval() -> None:
|
||||
from onyx.configs.app_configs import BRAINTRUST_API_KEY
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PERMISSIONS_EMAIL
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_PROJECT
|
||||
|
||||
if not all(
|
||||
[
|
||||
BRAINTRUST_API_KEY,
|
||||
SCHEDULED_EVAL_PROJECT,
|
||||
SCHEDULED_EVAL_DATASET_NAMES,
|
||||
SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
]
|
||||
):
|
||||
return
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from onyx.evals.eval import run_eval
|
||||
from onyx.evals.models import EvalConfigurationOptions
|
||||
|
||||
run_timestamp = datetime.now(timezone.utc).strftime("%Y-%m-%d")
|
||||
for dataset_name in SCHEDULED_EVAL_DATASET_NAMES:
|
||||
try:
|
||||
run_eval(
|
||||
configuration=EvalConfigurationOptions(
|
||||
search_permissions_email=SCHEDULED_EVAL_PERMISSIONS_EMAIL,
|
||||
dataset_name=dataset_name,
|
||||
no_send_logs=False,
|
||||
braintrust_project=SCHEDULED_EVAL_PROJECT,
|
||||
experiment_name=f"{dataset_name} - {run_timestamp}",
|
||||
),
|
||||
remote_dataset_name=dataset_name,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Failed scheduled eval for dataset {dataset_name}"
|
||||
)
|
||||
|
||||
|
||||
def _build_periodic_tasks() -> list[_PeriodicTaskDef]:
|
||||
from onyx.configs.app_configs import AUTO_LLM_CONFIG_URL
|
||||
from onyx.configs.app_configs import AUTO_LLM_UPDATE_INTERVAL_SECONDS
|
||||
from onyx.configs.app_configs import SCHEDULED_EVAL_DATASET_NAMES
|
||||
|
||||
tasks: list[_PeriodicTaskDef] = []
|
||||
if AUTO_LLM_CONFIG_URL:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="auto-llm-update",
|
||||
interval_seconds=AUTO_LLM_UPDATE_INTERVAL_SECONDS,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE,
|
||||
run_fn=_run_auto_llm_update,
|
||||
)
|
||||
)
|
||||
if SCHEDULED_EVAL_DATASET_NAMES:
|
||||
tasks.append(
|
||||
_PeriodicTaskDef(
|
||||
name="scheduled-eval",
|
||||
interval_seconds=7 * 24 * 3600,
|
||||
lock_id=PERIODIC_TASK_LOCK_BASE + 1,
|
||||
run_fn=_run_scheduled_eval,
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Periodic task runner with advisory-lock-guarded claim
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _try_claim_task(task_def: _PeriodicTaskDef) -> bool:
|
||||
"""Atomically check whether *task_def* should run and record a claim.
|
||||
|
||||
Uses a transaction-scoped advisory lock for atomicity combined with a
|
||||
``KVStore`` timestamp for cross-instance dedup. The DB session is held
|
||||
only for this brief claim transaction, not during task execution.
|
||||
"""
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import KVStore
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_def.name
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
acquired = db_session.execute(
|
||||
text("SELECT pg_try_advisory_xact_lock(:id)"),
|
||||
{"id": task_def.lock_id},
|
||||
).scalar()
|
||||
if not acquired:
|
||||
return False
|
||||
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
if row and row.value is not None:
|
||||
last_claimed = datetime.fromisoformat(str(row.value))
|
||||
elapsed = (datetime.now(timezone.utc) - last_claimed).total_seconds()
|
||||
if elapsed < task_def.interval_seconds:
|
||||
return False
|
||||
|
||||
now_ts = datetime.now(timezone.utc).isoformat()
|
||||
if row:
|
||||
row.value = now_ts
|
||||
else:
|
||||
db_session.add(KVStore(key=kv_key, value=now_ts))
|
||||
db_session.commit()
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def _try_run_periodic_task(task_def: _PeriodicTaskDef) -> None:
|
||||
"""Run *task_def* if its interval has elapsed and no peer holds the lock."""
|
||||
now = time.monotonic()
|
||||
if now - task_def.last_run_at < task_def.interval_seconds:
|
||||
return
|
||||
|
||||
if not _try_claim_task(task_def):
|
||||
return
|
||||
|
||||
try:
|
||||
task_def.run_fn()
|
||||
task_def.last_run_at = now
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Error running periodic task {task_def.name}"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Recovery / drain loop runner
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def _run_drain_loops(tenant_id: str) -> None:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
drain_processing_loop(tenant_id)
|
||||
drain_delete_loop(tenant_id)
|
||||
drain_project_sync_loop(tenant_id)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Startup recovery (10g)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def recover_stuck_user_files(tenant_id: str) -> None:
|
||||
"""Run all drain loops once to re-process files left in intermediate states.
|
||||
|
||||
Called from ``lifespan()`` on startup when ``DISABLE_VECTOR_DB`` is set.
|
||||
"""
|
||||
logger.info("recover_stuck_user_files - Checking for stuck user files")
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("recover_stuck_user_files - Error during recovery")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Daemon thread (10f)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
_shutdown_event = threading.Event()
|
||||
_poller_thread: threading.Thread | None = None
|
||||
|
||||
|
||||
def _poller_loop(tenant_id: str) -> None:
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
periodic_tasks = _build_periodic_tasks()
|
||||
logger.info(
|
||||
f"Periodic poller started with {len(periodic_tasks)} periodic task(s): "
|
||||
f"{[t.name for t in periodic_tasks]}"
|
||||
)
|
||||
|
||||
while not _shutdown_event.is_set():
|
||||
try:
|
||||
_run_drain_loops(tenant_id)
|
||||
except Exception:
|
||||
logger.exception("Periodic poller - Error in recovery polling")
|
||||
|
||||
for task_def in periodic_tasks:
|
||||
try:
|
||||
_try_run_periodic_task(task_def)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Periodic poller - Unhandled error checking task {task_def.name}"
|
||||
)
|
||||
|
||||
_shutdown_event.wait(RECOVERY_INTERVAL_SECONDS)
|
||||
|
||||
|
||||
def start_periodic_poller(tenant_id: str) -> None:
|
||||
"""Start the periodic poller daemon thread."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
_shutdown_event.clear()
|
||||
_poller_thread = threading.Thread(
|
||||
target=_poller_loop,
|
||||
args=(tenant_id,),
|
||||
daemon=True,
|
||||
name="no-vectordb-periodic-poller",
|
||||
)
|
||||
_poller_thread.start()
|
||||
logger.info("Periodic poller thread started")
|
||||
|
||||
|
||||
def stop_periodic_poller() -> None:
|
||||
"""Signal the periodic poller to stop and wait for it to exit."""
|
||||
global _poller_thread # noqa: PLW0603
|
||||
if _poller_thread is None:
|
||||
return
|
||||
_shutdown_event.set()
|
||||
_poller_thread.join(timeout=10)
|
||||
if _poller_thread.is_alive():
|
||||
logger.warning("Periodic poller thread did not stop within timeout")
|
||||
_poller_thread = None
|
||||
logger.info("Periodic poller thread stopped")
|
||||
@@ -1,33 +1,3 @@
|
||||
"""Background task utilities.
|
||||
|
||||
Contains query-history report helpers (used by all deployment modes) and
|
||||
in-process background task execution helpers for NO_VECTOR_DB mode:
|
||||
|
||||
- Atomic claim-and-mark helpers that prevent duplicate processing
|
||||
- Drain loops that process all pending user file work
|
||||
|
||||
Each claim function runs a short-lived transaction: SELECT ... FOR UPDATE
|
||||
SKIP LOCKED, UPDATE the row to remove it from future queries, COMMIT.
|
||||
After the commit the row lock is released, but the row is no longer
|
||||
eligible for re-claiming. No long-lived sessions or advisory locks.
|
||||
"""
|
||||
|
||||
from uuid import UUID
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Query-history report helpers (pre-existing, used by all modes)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
QUERY_REPORT_NAME_PREFIX = "query-history"
|
||||
|
||||
|
||||
@@ -39,142 +9,3 @@ def construct_query_history_report_name(
|
||||
|
||||
def extract_task_id_from_query_history_report_name(name: str) -> str:
|
||||
return name.removeprefix(f"{QUERY_REPORT_NAME_PREFIX}-").removesuffix(".csv")
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Atomic claim-and-mark helpers
|
||||
# ------------------------------------------------------------------
|
||||
# Each function runs inside a single short-lived session/transaction:
|
||||
# 1. SELECT ... FOR UPDATE SKIP LOCKED (locks one eligible row)
|
||||
# 2. UPDATE the row so it is no longer eligible
|
||||
# 3. COMMIT (releases the row lock)
|
||||
# After the commit, no other drain loop can claim the same row.
|
||||
|
||||
|
||||
def _claim_next_processing_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next PROCESSING file by transitioning it to INDEXING.
|
||||
|
||||
Returns the file id, or None when no eligible files remain.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.PROCESSING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
if file_id is None:
|
||||
return None
|
||||
|
||||
db_session.execute(
|
||||
sa.update(UserFile)
|
||||
.where(UserFile.id == file_id)
|
||||
.values(status=UserFileStatus.INDEXING)
|
||||
)
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_deleting_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next DELETING file.
|
||||
|
||||
No status transition needed — the impl deletes the row on success.
|
||||
The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(UserFile.status == UserFileStatus.DELETING)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
# Commit to release the row lock promptly.
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
def _claim_next_sync_file(db_session: Session) -> UUID | None:
|
||||
"""Claim the next file needing project/persona sync.
|
||||
|
||||
No status transition needed — the impl clears the sync flags on
|
||||
success. The short-lived FOR UPDATE lock prevents concurrent claims.
|
||||
"""
|
||||
file_id = db_session.execute(
|
||||
select(UserFile.id)
|
||||
.where(
|
||||
sa.and_(
|
||||
sa.or_(
|
||||
UserFile.needs_project_sync.is_(True),
|
||||
UserFile.needs_persona_sync.is_(True),
|
||||
),
|
||||
UserFile.status == UserFileStatus.COMPLETED,
|
||||
)
|
||||
)
|
||||
.order_by(UserFile.created_at)
|
||||
.limit(1)
|
||||
.with_for_update(skip_locked=True)
|
||||
).scalar_one_or_none()
|
||||
db_session.commit()
|
||||
return file_id
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Drain loops — process *all* pending work of each type
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
def drain_processing_loop(tenant_id: str) -> None:
|
||||
"""Process all pending PROCESSING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_processing_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
process_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_delete_loop(tenant_id: str) -> None:
|
||||
"""Delete all pending DELETING user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_deleting_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
|
||||
def drain_project_sync_loop(tenant_id: str) -> None:
|
||||
"""Sync all pending project/persona metadata for user files."""
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
while True:
|
||||
with get_session_with_current_tenant() as session:
|
||||
file_id = _claim_next_sync_file(session)
|
||||
if file_id is None:
|
||||
break
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(file_id),
|
||||
tenant_id=tenant_id,
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -531,13 +530,11 @@ def _create_file_tool_metadata_message(
|
||||
"""
|
||||
lines = [
|
||||
"You have access to the following files. Use the read_file tool to "
|
||||
"read sections of any file. You MUST pass the file_id UUID (not the "
|
||||
"filename) to read_file:"
|
||||
"read sections of any file:"
|
||||
]
|
||||
for meta in file_metadata:
|
||||
lines.append(
|
||||
f'- file_id="{meta.file_id}" filename="{meta.filename}" '
|
||||
f"(~{meta.approx_char_count:,} chars)"
|
||||
f'- {meta.file_id}: "{meta.filename}" (~{meta.approx_char_count:,} chars)'
|
||||
)
|
||||
|
||||
message_content = "\n".join(lines)
|
||||
@@ -561,16 +558,12 @@ def _create_context_files_message(
|
||||
# Format as documents JSON as described in README
|
||||
documents_list = []
|
||||
for idx, file_text in enumerate(context_files.file_texts, start=1):
|
||||
title = (
|
||||
context_files.file_metadata[idx - 1].filename
|
||||
if idx - 1 < len(context_files.file_metadata)
|
||||
else None
|
||||
documents_list.append(
|
||||
{
|
||||
"document": idx,
|
||||
"contents": file_text,
|
||||
}
|
||||
)
|
||||
entry: dict[str, Any] = {"document": idx}
|
||||
if title:
|
||||
entry["title"] = title
|
||||
entry["contents"] = file_text
|
||||
documents_list.append(entry)
|
||||
|
||||
documents_json = json.dumps({"documents": documents_list}, indent=2)
|
||||
message_content = f"Here are some documents provided for context, they may not all be relevant:\n{documents_json}"
|
||||
|
||||
@@ -98,7 +98,6 @@ def get_chat_sessions_by_user(
|
||||
db_session: Session,
|
||||
include_onyxbot_flows: bool = False,
|
||||
limit: int = 50,
|
||||
before: datetime | None = None,
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = False,
|
||||
include_failed_chats: bool = False,
|
||||
@@ -113,9 +112,6 @@ def get_chat_sessions_by_user(
|
||||
if deleted is not None:
|
||||
stmt = stmt.where(ChatSession.deleted == deleted)
|
||||
|
||||
if before is not None:
|
||||
stmt = stmt.where(ChatSession.time_updated < before)
|
||||
|
||||
if limit:
|
||||
stmt = stmt.limit(limit)
|
||||
|
||||
|
||||
@@ -186,7 +186,6 @@ class EmbeddingPrecision(str, PyEnum):
|
||||
|
||||
class UserFileStatus(str, PyEnum):
|
||||
PROCESSING = "PROCESSING"
|
||||
INDEXING = "INDEXING"
|
||||
COMPLETED = "COMPLETED"
|
||||
FAILED = "FAILED"
|
||||
CANCELED = "CANCELED"
|
||||
|
||||
@@ -103,6 +103,7 @@ from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
|
||||
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
|
||||
# and update Mapped[User | None] relationships to Mapped[User] where needed.
|
||||
@@ -3264,6 +3265,19 @@ class Persona(Base):
|
||||
)
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
description: Mapped[str] = mapped_column(String)
|
||||
# Number of chunks to pass to the LLM for generation.
|
||||
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
|
||||
chunks_above: Mapped[int] = mapped_column(Integer)
|
||||
chunks_below: Mapped[int] = mapped_column(Integer)
|
||||
# Pass every chunk through LLM for evaluation, fairly expensive
|
||||
# Can be turned off globally by admin, in which case, this setting is ignored
|
||||
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
|
||||
# Enables using LLM to extract time and source type filters
|
||||
# Can also be admin disabled globally
|
||||
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
|
||||
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
|
||||
Enum(RecencyBiasSetting, native_enum=False)
|
||||
)
|
||||
|
||||
# Allows the persona to specify a specific default LLM model
|
||||
# NOTE: only is applied on the actual response generation - is not used for things like
|
||||
@@ -3290,8 +3304,11 @@ class Persona(Base):
|
||||
# Treated specially (cannot be user edited etc.)
|
||||
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
|
||||
# Featured personas are highlighted in the UI
|
||||
featured: Mapped[bool] = mapped_column(Boolean, default=False)
|
||||
# Default personas are personas created by admins and are automatically added
|
||||
# to all users' assistants list.
|
||||
is_default_persona: Mapped[bool] = mapped_column(
|
||||
Boolean, default=False, nullable=False
|
||||
)
|
||||
# controls whether the persona is available to be selected by users
|
||||
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
# controls the ordering of personas in the UI
|
||||
|
||||
@@ -18,8 +18,11 @@ from sqlalchemy.orm import Session
|
||||
from onyx.access.hierarchy_access import get_user_external_group_ids
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
|
||||
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.constants import DEFAULT_PERSONA_ID
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
@@ -251,15 +254,13 @@ def create_update_persona(
|
||||
# Permission to actually use these is checked later
|
||||
|
||||
try:
|
||||
# Featured persona validation
|
||||
if create_persona_request.featured:
|
||||
|
||||
# Curators can edit featured personas, but not make them
|
||||
# TODO this will be reworked soon with RBAC permissions feature
|
||||
# Default persona validation
|
||||
if create_persona_request.is_default_persona:
|
||||
# Curators can edit default personas, but not make them
|
||||
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
|
||||
pass
|
||||
elif user.role != UserRole.ADMIN:
|
||||
raise ValueError("Only admins can make a featured persona")
|
||||
raise ValueError("Only admins can make a default persona")
|
||||
|
||||
# Convert incoming string UUIDs to UUID objects for DB operations
|
||||
converted_user_file_ids = None
|
||||
@@ -280,6 +281,7 @@ def create_update_persona(
|
||||
document_set_ids=create_persona_request.document_set_ids,
|
||||
tool_ids=create_persona_request.tool_ids,
|
||||
is_public=create_persona_request.is_public,
|
||||
recency_bias=create_persona_request.recency_bias,
|
||||
llm_model_provider_override=create_persona_request.llm_model_provider_override,
|
||||
llm_model_version_override=create_persona_request.llm_model_version_override,
|
||||
starter_messages=create_persona_request.starter_messages,
|
||||
@@ -293,7 +295,10 @@ def create_update_persona(
|
||||
remove_image=create_persona_request.remove_image,
|
||||
search_start_date=create_persona_request.search_start_date,
|
||||
label_ids=create_persona_request.label_ids,
|
||||
featured=create_persona_request.featured,
|
||||
num_chunks=create_persona_request.num_chunks,
|
||||
llm_relevance_filter=create_persona_request.llm_relevance_filter,
|
||||
llm_filter_extraction=create_persona_request.llm_filter_extraction,
|
||||
is_default_persona=create_persona_request.is_default_persona,
|
||||
user_file_ids=converted_user_file_ids,
|
||||
commit=False,
|
||||
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
|
||||
@@ -869,6 +874,10 @@ def upsert_persona(
|
||||
user: User | None,
|
||||
name: str,
|
||||
description: str,
|
||||
num_chunks: float,
|
||||
llm_relevance_filter: bool,
|
||||
llm_filter_extraction: bool,
|
||||
recency_bias: RecencyBiasSetting,
|
||||
llm_model_provider_override: str | None,
|
||||
llm_model_version_override: str | None,
|
||||
starter_messages: list[StarterMessage] | None,
|
||||
@@ -889,11 +898,13 @@ def upsert_persona(
|
||||
remove_image: bool | None = None,
|
||||
search_start_date: datetime | None = None,
|
||||
builtin_persona: bool = False,
|
||||
featured: bool | None = None,
|
||||
is_default_persona: bool | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[UUID] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
document_ids: list[str] | None = None,
|
||||
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
|
||||
chunks_below: int = CONTEXT_CHUNKS_BELOW,
|
||||
replace_base_system_prompt: bool = False,
|
||||
) -> Persona:
|
||||
"""
|
||||
@@ -1004,6 +1015,12 @@ def upsert_persona(
|
||||
# `default` and `built-in` properties can only be set when creating a persona.
|
||||
existing_persona.name = name
|
||||
existing_persona.description = description
|
||||
existing_persona.num_chunks = num_chunks
|
||||
existing_persona.chunks_above = chunks_above
|
||||
existing_persona.chunks_below = chunks_below
|
||||
existing_persona.llm_relevance_filter = llm_relevance_filter
|
||||
existing_persona.llm_filter_extraction = llm_filter_extraction
|
||||
existing_persona.recency_bias = recency_bias
|
||||
existing_persona.llm_model_provider_override = llm_model_provider_override
|
||||
existing_persona.llm_model_version_override = llm_model_version_override
|
||||
existing_persona.starter_messages = starter_messages
|
||||
@@ -1017,8 +1034,10 @@ def upsert_persona(
|
||||
if label_ids is not None:
|
||||
existing_persona.labels.clear()
|
||||
existing_persona.labels = labels or []
|
||||
existing_persona.featured = (
|
||||
featured if featured is not None else existing_persona.featured
|
||||
existing_persona.is_default_persona = (
|
||||
is_default_persona
|
||||
if is_default_persona is not None
|
||||
else existing_persona.is_default_persona
|
||||
)
|
||||
# Update embedded prompt fields if provided
|
||||
if system_prompt is not None:
|
||||
@@ -1071,6 +1090,12 @@ def upsert_persona(
|
||||
is_public=is_public,
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
builtin_persona=builtin_persona,
|
||||
system_prompt=system_prompt or "",
|
||||
task_prompt=task_prompt or "",
|
||||
@@ -1086,7 +1111,9 @@ def upsert_persona(
|
||||
display_priority=display_priority,
|
||||
is_visible=is_visible,
|
||||
search_start_date=search_start_date,
|
||||
featured=(featured if featured is not None else False),
|
||||
is_default_persona=(
|
||||
is_default_persona if is_default_persona is not None else False
|
||||
),
|
||||
user_files=user_files or [],
|
||||
labels=labels or [],
|
||||
hierarchy_nodes=hierarchy_nodes or [],
|
||||
@@ -1131,9 +1158,9 @@ def delete_old_default_personas(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_persona_featured(
|
||||
def update_persona_is_default(
|
||||
persona_id: int,
|
||||
featured: bool,
|
||||
is_default: bool,
|
||||
db_session: Session,
|
||||
user: User,
|
||||
) -> None:
|
||||
@@ -1141,7 +1168,7 @@ def update_persona_featured(
|
||||
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
|
||||
)
|
||||
|
||||
persona.featured = featured
|
||||
persona.is_default_persona = is_default
|
||||
db_session.commit()
|
||||
|
||||
|
||||
|
||||
@@ -9,9 +9,8 @@ from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.background import BackgroundTasks
|
||||
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
@@ -106,8 +105,8 @@ def upload_files_to_user_files_with_indexing(
|
||||
user: User,
|
||||
temp_id_map: dict[str, str] | None,
|
||||
db_session: Session,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> CategorizedFilesResult:
|
||||
# Validate project ownership if a project_id is provided
|
||||
if project_id is not None and user is not None:
|
||||
if not check_project_ownership(project_id, user.id, db_session):
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
@@ -128,27 +127,16 @@ def upload_files_to_user_files_with_indexing(
|
||||
logger.warning(
|
||||
f"File {rejected_file.filename} rejected for {rejected_file.reason}"
|
||||
)
|
||||
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_processing_loop
|
||||
|
||||
background_tasks.add_task(drain_processing_loop, tenant_id)
|
||||
for user_file in user_files:
|
||||
logger.info(f"Queued in-process processing for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
for user_file in user_files:
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": user_file.id, "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered indexing for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
|
||||
return CategorizedFilesResult(
|
||||
user_files=user_files,
|
||||
|
||||
@@ -5,6 +5,8 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
|
||||
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
|
||||
from onyx.db.models import ChannelConfig
|
||||
@@ -43,6 +45,8 @@ def create_slack_channel_persona(
|
||||
channel_name: str | None,
|
||||
document_set_ids: list[int],
|
||||
existing_persona_id: int | None = None,
|
||||
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
|
||||
enable_auto_filters: bool = False,
|
||||
) -> Persona:
|
||||
"""NOTE: does not commit changes"""
|
||||
|
||||
@@ -69,13 +73,17 @@ def create_slack_channel_persona(
|
||||
system_prompt="",
|
||||
task_prompt="",
|
||||
datetime_aware=True,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=enable_auto_filters,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
tool_ids=[search_tool.id],
|
||||
document_set_ids=document_set_ids,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
is_public=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
db_session=db_session,
|
||||
commit=False,
|
||||
)
|
||||
|
||||
@@ -37,7 +37,6 @@ from onyx.configs.app_configs import APP_HOST
|
||||
from onyx.configs.app_configs import APP_PORT
|
||||
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
@@ -255,38 +254,8 @@ def include_auth_router_with_prefix(
|
||||
)
|
||||
|
||||
|
||||
def validate_no_vector_db_settings() -> None:
|
||||
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
|
||||
|
||||
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
|
||||
since these modes require infrastructure that is removed in no-vector-DB deployments.
|
||||
"""
|
||||
if not DISABLE_VECTOR_DB:
|
||||
return
|
||||
|
||||
if MULTI_TENANT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
|
||||
"Multi-tenant deployments require the vector database for "
|
||||
"per-tenant document indexing and search. Run in single-tenant "
|
||||
"mode when disabling the vector database."
|
||||
)
|
||||
|
||||
from onyx.server.features.build.configs import ENABLE_CRAFT
|
||||
|
||||
if ENABLE_CRAFT:
|
||||
raise RuntimeError(
|
||||
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
|
||||
"Onyx Craft requires background workers for sandbox lifecycle "
|
||||
"management, which are removed in no-vector-DB deployments. "
|
||||
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
|
||||
)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
# Set recursion limit
|
||||
if SYSTEM_RECURSION_LIMIT is not None:
|
||||
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)
|
||||
@@ -355,20 +324,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await setup_auth_limiter()
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.background.periodic_poller import start_periodic_poller
|
||||
|
||||
recover_stuck_user_files(POSTGRES_DEFAULT_SCHEMA)
|
||||
start_periodic_poller(POSTGRES_DEFAULT_SCHEMA)
|
||||
|
||||
yield
|
||||
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.periodic_poller import stop_periodic_poller
|
||||
|
||||
stop_periodic_poller()
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
|
||||
@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
|
||||
from onyx.db.persona import get_persona_snapshots_paginated
|
||||
from onyx.db.persona import mark_persona_as_deleted
|
||||
from onyx.db.persona import mark_persona_as_not_deleted
|
||||
from onyx.db.persona import update_persona_featured
|
||||
from onyx.db.persona import update_persona_is_default
|
||||
from onyx.db.persona import update_persona_label
|
||||
from onyx.db.persona import update_persona_public_status
|
||||
from onyx.db.persona import update_persona_shared
|
||||
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
|
||||
is_public: bool
|
||||
|
||||
|
||||
class IsFeaturedRequest(BaseModel):
|
||||
featured: bool
|
||||
class IsDefaultRequest(BaseModel):
|
||||
is_default_persona: bool
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/visible")
|
||||
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
@admin_router.patch("/{persona_id}/featured")
|
||||
def patch_persona_featured_status(
|
||||
@admin_router.patch("/{persona_id}/default")
|
||||
def patch_persona_default_status(
|
||||
persona_id: int,
|
||||
is_featured_request: IsFeaturedRequest,
|
||||
is_default_request: IsDefaultRequest,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
try:
|
||||
update_persona_featured(
|
||||
update_persona_is_default(
|
||||
persona_id=persona_id,
|
||||
featured=is_featured_request.featured,
|
||||
is_default=is_default_request.is_default_persona,
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
)
|
||||
except ValueError as e:
|
||||
logger.exception("Failed to update persona featured status")
|
||||
logger.exception("Failed to update persona default status")
|
||||
raise HTTPException(status_code=403, detail=str(e))
|
||||
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import HierarchyNodeType
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import HierarchyNode
|
||||
@@ -107,7 +108,11 @@ class PersonaUpsertRequest(BaseModel):
|
||||
name: str
|
||||
description: str
|
||||
document_set_ids: list[int]
|
||||
num_chunks: float
|
||||
is_public: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
llm_filter_extraction: bool
|
||||
llm_relevance_filter: bool
|
||||
llm_model_provider_override: str | None = None
|
||||
llm_model_version_override: str | None = None
|
||||
starter_messages: list[StarterMessage] | None = None
|
||||
@@ -123,7 +128,7 @@ class PersonaUpsertRequest(BaseModel):
|
||||
)
|
||||
search_start_date: datetime | None = None
|
||||
label_ids: list[int] | None = None
|
||||
featured: bool = False
|
||||
is_default_persona: bool = False
|
||||
display_priority: int | None = None
|
||||
# Accept string UUIDs from frontend
|
||||
user_file_ids: list[str] | None = None
|
||||
@@ -150,6 +155,9 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
tools: list[ToolSnapshot]
|
||||
starter_messages: list[StarterMessage] | None
|
||||
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
|
||||
# only show document sets in the UI that the assistant has access to
|
||||
document_sets: list[DocumentSetSummary]
|
||||
# Counts for knowledge sources (used to determine if search tool should be enabled)
|
||||
@@ -167,7 +175,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public: bool
|
||||
is_visible: bool
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
|
||||
# Used for filtering
|
||||
@@ -206,6 +214,8 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
if should_expose_tool_to_fe(tool)
|
||||
],
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
document_sets=[
|
||||
DocumentSetSummary.from_model(document_set)
|
||||
for document_set in persona.document_sets
|
||||
@@ -220,7 +230,7 @@ class MinimalPersonaSnapshot(BaseModel):
|
||||
is_public=persona.is_public,
|
||||
is_visible=persona.is_visible,
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
|
||||
owner=(
|
||||
@@ -242,9 +252,11 @@ class PersonaSnapshot(BaseModel):
|
||||
# Return string UUIDs to frontend for consistency
|
||||
user_file_ids: list[str]
|
||||
display_priority: int | None
|
||||
featured: bool
|
||||
is_default_persona: bool
|
||||
builtin_persona: bool
|
||||
starter_messages: list[StarterMessage] | None
|
||||
llm_relevance_filter: bool
|
||||
llm_filter_extraction: bool
|
||||
tools: list[ToolSnapshot]
|
||||
labels: list["PersonaLabelSnapshot"]
|
||||
owner: MinimalUserSnapshot | None
|
||||
@@ -253,6 +265,7 @@ class PersonaSnapshot(BaseModel):
|
||||
document_sets: list[DocumentSetSummary]
|
||||
llm_model_provider_override: str | None
|
||||
llm_model_version_override: str | None
|
||||
num_chunks: float | None
|
||||
# Hierarchy nodes attached for scoped search
|
||||
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
|
||||
# Individual documents attached for scoped search
|
||||
@@ -276,9 +289,11 @@ class PersonaSnapshot(BaseModel):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
tools=[
|
||||
ToolSnapshot.from_model(tool)
|
||||
for tool in persona.tools
|
||||
@@ -309,6 +324,7 @@ class PersonaSnapshot(BaseModel):
|
||||
],
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
num_chunks=persona.num_chunks,
|
||||
system_prompt=persona.system_prompt,
|
||||
replace_base_system_prompt=persona.replace_base_system_prompt,
|
||||
task_prompt=persona.task_prompt,
|
||||
@@ -316,10 +332,12 @@ class PersonaSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Model with full context on persona's internal settings
|
||||
# Model with full context on perona's internal settings
|
||||
# This is used for flows which need to know all settings
|
||||
class FullPersonaSnapshot(PersonaSnapshot):
|
||||
search_start_date: datetime | None = None
|
||||
llm_relevance_filter: bool = False
|
||||
llm_filter_extraction: bool = False
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -342,7 +360,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
icon_name=persona.icon_name,
|
||||
user_file_ids=[str(file.id) for file in persona.user_files],
|
||||
display_priority=persona.display_priority,
|
||||
featured=persona.featured,
|
||||
is_default_persona=persona.is_default_persona,
|
||||
builtin_persona=persona.builtin_persona,
|
||||
starter_messages=persona.starter_messages,
|
||||
users=[
|
||||
@@ -373,7 +391,10 @@ class FullPersonaSnapshot(PersonaSnapshot):
|
||||
DocumentSetSummary.from_model(document_set_model)
|
||||
for document_set_model in persona.document_sets
|
||||
],
|
||||
num_chunks=persona.num_chunks,
|
||||
search_start_date=persona.search_start_date,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
system_prompt=persona.system_prompt,
|
||||
|
||||
@@ -2,7 +2,6 @@ import json
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import BackgroundTasks
|
||||
from fastapi import Depends
|
||||
from fastapi import File
|
||||
from fastapi import Form
|
||||
@@ -13,7 +12,13 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
@@ -29,6 +34,7 @@ from onyx.db.models import UserProject
|
||||
from onyx.db.persona import get_personas_by_ids
|
||||
from onyx.db.projects import get_project_token_count
|
||||
from onyx.db.projects import upload_files_to_user_files_with_indexing
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.features.projects.models import CategorizedFilesSnapshot
|
||||
from onyx.server.features.projects.models import ChatSessionRequest
|
||||
from onyx.server.features.projects.models import TokenCountResponse
|
||||
@@ -49,27 +55,7 @@ class UserFileDeleteResult(BaseModel):
|
||||
assistant_names: list[str] = []
|
||||
|
||||
|
||||
def _trigger_user_file_project_sync(
|
||||
user_file_id: UUID,
|
||||
tenant_id: str,
|
||||
background_tasks: BackgroundTasks | None = None,
|
||||
) -> None:
|
||||
if DISABLE_VECTOR_DB and background_tasks is not None:
|
||||
from onyx.background.task_utils import drain_project_sync_loop
|
||||
|
||||
background_tasks.add_task(drain_project_sync_loop, tenant_id)
|
||||
logger.info(f"Queued in-process project sync for user_file_id={user_file_id}")
|
||||
return
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
enqueue_user_file_project_sync_task,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
get_user_file_project_sync_queue_depth,
|
||||
)
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
|
||||
def _trigger_user_file_project_sync(user_file_id: UUID, tenant_id: str) -> None:
|
||||
queue_depth = get_user_file_project_sync_queue_depth(client_app)
|
||||
if queue_depth > USER_FILE_PROJECT_SYNC_MAX_QUEUE_DEPTH:
|
||||
logger.warning(
|
||||
@@ -125,7 +111,6 @@ def create_project(
|
||||
|
||||
@router.post("/file/upload", tags=PUBLIC_API_TAGS)
|
||||
def upload_user_files(
|
||||
bg_tasks: BackgroundTasks,
|
||||
files: list[UploadFile] = File(...),
|
||||
project_id: int | None = Form(None),
|
||||
temp_id_map: str | None = Form(None), # JSON string mapping hashed key -> temp_id
|
||||
@@ -152,12 +137,12 @@ def upload_user_files(
|
||||
user=user,
|
||||
temp_id_map=parsed_temp_id_map,
|
||||
db_session=db_session,
|
||||
background_tasks=bg_tasks if DISABLE_VECTOR_DB else None,
|
||||
)
|
||||
|
||||
return CategorizedFilesSnapshot.from_result(categorized_files_result)
|
||||
|
||||
except Exception as e:
|
||||
# Log error with type, message, and stack for easier debugging
|
||||
logger.exception(f"Error uploading files - {type(e).__name__}: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
@@ -207,7 +192,6 @@ def get_files_in_project(
|
||||
def unlink_user_file_from_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> Response:
|
||||
@@ -224,6 +208,7 @@ def unlink_user_file_from_project(
|
||||
if project is None:
|
||||
raise HTTPException(status_code=404, detail="Project not found")
|
||||
|
||||
user_id = user.id
|
||||
user_file = (
|
||||
db_session.query(UserFile)
|
||||
.filter(UserFile.id == file_id, UserFile.user_id == user_id)
|
||||
@@ -239,7 +224,7 @@ def unlink_user_file_from_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -252,7 +237,6 @@ def unlink_user_file_from_project(
|
||||
def link_user_file_to_project(
|
||||
project_id: int,
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileSnapshot:
|
||||
@@ -284,7 +268,7 @@ def link_user_file_to_project(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id, bg_tasks)
|
||||
_trigger_user_file_project_sync(user_file.id, tenant_id)
|
||||
|
||||
return UserFileSnapshot.from_model(user_file)
|
||||
|
||||
@@ -351,7 +335,7 @@ def upsert_project_instructions(
|
||||
class ProjectPayload(BaseModel):
|
||||
project: UserProjectSnapshot
|
||||
files: list[UserFileSnapshot] | None = None
|
||||
persona_id_to_featured: dict[int, bool] | None = None
|
||||
persona_id_to_is_default: dict[int, bool] | None = None
|
||||
|
||||
|
||||
@router.get(
|
||||
@@ -370,11 +354,13 @@ def get_project_details(
|
||||
if session.persona_id is not None
|
||||
]
|
||||
personas = get_personas_by_ids(persona_ids, db_session)
|
||||
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
|
||||
persona_id_to_is_default = {
|
||||
persona.id: persona.is_default_persona for persona in personas
|
||||
}
|
||||
return ProjectPayload(
|
||||
project=project,
|
||||
files=files,
|
||||
persona_id_to_featured=persona_id_to_featured,
|
||||
persona_id_to_is_default=persona_id_to_is_default,
|
||||
)
|
||||
|
||||
|
||||
@@ -440,7 +426,6 @@ def delete_project(
|
||||
@router.delete("/file/{file_id}", tags=PUBLIC_API_TAGS)
|
||||
def delete_user_file(
|
||||
file_id: UUID,
|
||||
bg_tasks: BackgroundTasks,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserFileDeleteResult:
|
||||
@@ -473,25 +458,15 @@ def delete_user_file(
|
||||
db_session.commit()
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if DISABLE_VECTOR_DB:
|
||||
from onyx.background.task_utils import drain_delete_loop
|
||||
|
||||
bg_tasks.add_task(drain_delete_loop, tenant_id)
|
||||
logger.info(f"Queued in-process delete for user_file_id={user_file.id}")
|
||||
else:
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} "
|
||||
f"with task_id={task.id}"
|
||||
)
|
||||
|
||||
task = client_app.send_task(
|
||||
OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
|
||||
kwargs={"user_file_id": str(user_file.id), "tenant_id": tenant_id},
|
||||
queue=OnyxCeleryQueues.USER_FILE_DELETE,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
)
|
||||
logger.info(
|
||||
f"Triggered delete for user_file_id={user_file.id} with task_id={task.id}"
|
||||
)
|
||||
return UserFileDeleteResult(
|
||||
has_associations=False, project_names=[], assistant_names=[]
|
||||
)
|
||||
|
||||
@@ -5,6 +5,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
|
||||
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.entities import get_entity_stats_by_grounded_source_name
|
||||
from onyx.db.entity_type import get_configured_entity_types
|
||||
@@ -133,7 +134,11 @@ def enable_or_disable_kg(
|
||||
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
|
||||
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
|
||||
datetime_aware=False,
|
||||
num_chunks=25,
|
||||
llm_relevance_filter=False,
|
||||
is_public=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
document_set_ids=[],
|
||||
tool_ids=[search_tool.id, kg_tool.id],
|
||||
llm_model_provider_override=None,
|
||||
@@ -142,7 +147,7 @@ def enable_or_disable_kg(
|
||||
users=[user.id],
|
||||
groups=[],
|
||||
label_ids=[],
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
display_priority=0,
|
||||
user_file_ids=[],
|
||||
)
|
||||
|
||||
@@ -198,6 +198,7 @@ def patch_slack_channel_config(
|
||||
channel_name=channel_config["channel_name"],
|
||||
document_set_ids=slack_channel_config_creation_request.document_sets,
|
||||
existing_persona_id=existing_persona_id,
|
||||
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
|
||||
).id
|
||||
|
||||
slack_channel_config_model = update_slack_channel_config(
|
||||
|
||||
@@ -152,20 +152,10 @@ def get_user_chat_sessions(
|
||||
project_id: int | None = None,
|
||||
only_non_project_chats: bool = True,
|
||||
include_failed_chats: bool = False,
|
||||
page_size: int = Query(default=50, ge=1, le=100),
|
||||
before: str | None = Query(default=None),
|
||||
) -> ChatSessionsResponse:
|
||||
user_id = user.id
|
||||
|
||||
try:
|
||||
before_dt = (
|
||||
datetime.datetime.fromisoformat(before) if before is not None else None
|
||||
)
|
||||
except ValueError:
|
||||
raise HTTPException(status_code=422, detail="Invalid 'before' timestamp format")
|
||||
|
||||
try:
|
||||
# Fetch one extra to determine if there are more results
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id,
|
||||
deleted=False,
|
||||
@@ -173,16 +163,11 @@ def get_user_chat_sessions(
|
||||
project_id=project_id,
|
||||
only_non_project_chats=only_non_project_chats,
|
||||
include_failed_chats=include_failed_chats,
|
||||
limit=page_size + 1,
|
||||
before=before_dt,
|
||||
)
|
||||
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
|
||||
has_more = len(chat_sessions) > page_size
|
||||
chat_sessions = chat_sessions[:page_size]
|
||||
|
||||
return ChatSessionsResponse(
|
||||
sessions=[
|
||||
ChatSessionDetails(
|
||||
@@ -196,8 +181,7 @@ def get_user_chat_sessions(
|
||||
current_temperature_override=chat.temperature_override,
|
||||
)
|
||||
for chat in chat_sessions
|
||||
],
|
||||
has_more=has_more,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -192,7 +192,6 @@ class ChatSessionDetails(BaseModel):
|
||||
|
||||
class ChatSessionsResponse(BaseModel):
|
||||
sessions: list[ChatSessionDetails]
|
||||
has_more: bool = False
|
||||
|
||||
|
||||
class ChatMessageDetail(BaseModel):
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.models import AnswerStreamPart
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.db.chat import create_chat_session
|
||||
from onyx.db.models import RecencyBiasSetting
|
||||
from onyx.db.models import User
|
||||
from onyx.db.persona import upsert_persona
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
@@ -73,6 +74,10 @@ def test_stream_chat_message_objects_without_web_search(
|
||||
user=None, # System persona
|
||||
name=f"Test Persona {uuid.uuid4()}",
|
||||
description="Test persona with no tools for web search test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.BASE_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -1,257 +0,0 @@
|
||||
"""External dependency unit tests for periodic task claiming.
|
||||
|
||||
Tests ``_try_claim_task`` and ``_try_run_periodic_task`` against real
|
||||
PostgreSQL, verifying happy-path behavior and concurrent-access safety.
|
||||
|
||||
The claim mechanism uses a transaction-scoped advisory lock + a KVStore
|
||||
timestamp for cross-instance dedup. The DB session is released before
|
||||
the task runs, so long-running tasks don't hold connections.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.background.periodic_poller import _PeriodicTaskDef
|
||||
from onyx.background.periodic_poller import _try_claim_task
|
||||
from onyx.background.periodic_poller import _try_run_periodic_task
|
||||
from onyx.background.periodic_poller import PERIODIC_TASK_KV_PREFIX
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.models import KVStore
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
_TEST_LOCK_BASE = 90_000
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def _init_engine() -> None:
|
||||
SqlEngine.init_engine(pool_size=10, max_overflow=5)
|
||||
|
||||
|
||||
def _make_task(
|
||||
*,
|
||||
name: str | None = None,
|
||||
interval: float = 3600,
|
||||
lock_id: int | None = None,
|
||||
run_fn: MagicMock | None = None,
|
||||
) -> _PeriodicTaskDef:
|
||||
return _PeriodicTaskDef(
|
||||
name=name or f"test-{uuid4().hex[:8]}",
|
||||
interval_seconds=interval,
|
||||
lock_id=lock_id or _TEST_LOCK_BASE,
|
||||
run_fn=run_fn or MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _cleanup_kv(
|
||||
tenant_context: None, # noqa: ARG001
|
||||
) -> Generator[None, None, None]:
|
||||
yield
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_session.query(KVStore).filter(
|
||||
KVStore.key.like(f"{PERIODIC_TASK_KV_PREFIX}test-%")
|
||||
).delete(synchronize_session=False)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_claim_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimHappyPath:
|
||||
def test_first_claim_succeeds(self) -> None:
|
||||
assert _try_claim_task(_make_task()) is True
|
||||
|
||||
def test_first_claim_creates_kv_row(self) -> None:
|
||||
task = _make_task()
|
||||
_try_claim_task(task)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = (
|
||||
db_session.query(KVStore)
|
||||
.filter_by(key=PERIODIC_TASK_KV_PREFIX + task.name)
|
||||
.first()
|
||||
)
|
||||
assert row is not None
|
||||
assert row.value is not None
|
||||
|
||||
def test_second_claim_within_interval_fails(self) -> None:
|
||||
task = _make_task(interval=3600)
|
||||
assert _try_claim_task(task) is True
|
||||
assert _try_claim_task(task) is False
|
||||
|
||||
def test_claim_after_interval_succeeds(self) -> None:
|
||||
task = _make_task(interval=1)
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task.name
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
assert row is not None
|
||||
row.value = (datetime.now(timezone.utc) - timedelta(seconds=10)).isoformat()
|
||||
db_session.commit()
|
||||
|
||||
assert _try_claim_task(task) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Happy-path: _try_run_periodic_task
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunHappyPath:
|
||||
def test_runs_task_and_updates_last_run_at(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_called_once()
|
||||
assert task.last_run_at > 0
|
||||
|
||||
def test_skips_when_in_memory_interval_not_elapsed(self) -> None:
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(run_fn=mock_fn, interval=3600)
|
||||
task.last_run_at = time.monotonic()
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_skips_when_db_claim_blocked(self) -> None:
|
||||
name = f"test-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 10
|
||||
|
||||
_try_claim_task(_make_task(name=name, lock_id=lock_id, interval=3600))
|
||||
|
||||
mock_fn = MagicMock()
|
||||
task = _make_task(name=name, lock_id=lock_id, interval=3600, run_fn=mock_fn)
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
mock_fn.assert_not_called()
|
||||
|
||||
def test_task_exception_does_not_propagate(self) -> None:
|
||||
task = _make_task(run_fn=MagicMock(side_effect=RuntimeError("boom")))
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
def test_claim_committed_before_task_runs(self) -> None:
|
||||
"""The KV claim must be visible in the DB when run_fn executes."""
|
||||
task_name = f"test-order-{uuid4().hex[:8]}"
|
||||
kv_key = PERIODIC_TASK_KV_PREFIX + task_name
|
||||
claim_visible: list[bool] = []
|
||||
|
||||
def check_claim() -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
row = db_session.query(KVStore).filter_by(key=kv_key).first()
|
||||
claim_visible.append(row is not None and row.value is not None)
|
||||
|
||||
task = _PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=_TEST_LOCK_BASE + 11,
|
||||
run_fn=check_claim,
|
||||
)
|
||||
|
||||
_try_run_periodic_task(task)
|
||||
|
||||
assert claim_visible == [True]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Concurrency: only one claimer should win
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestClaimConcurrency:
|
||||
def test_concurrent_claims_single_winner(self) -> None:
|
||||
"""Many threads claim the same task — exactly one should succeed."""
|
||||
num_threads = 20
|
||||
task_name = f"test-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 20
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
results: list[bool] = []
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
results.append(future.result())
|
||||
|
||||
winners = sum(1 for r in results if r)
|
||||
assert winners == 1, f"Expected 1 winner, got {winners}"
|
||||
|
||||
def test_concurrent_run_single_execution(self) -> None:
|
||||
"""Many threads run the same task — run_fn fires exactly once."""
|
||||
num_threads = 20
|
||||
task_name = f"test-run-race-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 21
|
||||
counter = MagicMock()
|
||||
|
||||
def run() -> None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
_try_run_periodic_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=counter,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(run) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
future.result()
|
||||
|
||||
assert (
|
||||
counter.call_count == 1
|
||||
), f"Expected run_fn called once, got {counter.call_count}"
|
||||
|
||||
def test_no_errors_under_contention(self) -> None:
|
||||
"""All threads complete without exceptions under high contention."""
|
||||
num_threads = 30
|
||||
task_name = f"test-err-{uuid4().hex[:8]}"
|
||||
lock_id = _TEST_LOCK_BASE + 22
|
||||
errors: list[Exception] = []
|
||||
|
||||
def claim() -> bool:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(TEST_TENANT_ID)
|
||||
return _try_claim_task(
|
||||
_PeriodicTaskDef(
|
||||
name=task_name,
|
||||
interval_seconds=3600,
|
||||
lock_id=lock_id,
|
||||
run_fn=lambda: None,
|
||||
)
|
||||
)
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_threads) as executor:
|
||||
futures = [executor.submit(claim) for _ in range(num_threads)]
|
||||
for future in as_completed(futures):
|
||||
try:
|
||||
future.result()
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
assert errors == [], f"Got {len(errors)} errors: {errors}"
|
||||
@@ -1,219 +0,0 @@
|
||||
"""External dependency unit tests for startup recovery (Step 10g).
|
||||
|
||||
Seeds ``UserFile`` records in stuck states (PROCESSING, DELETING,
|
||||
needs_project_sync) then calls ``recover_stuck_user_files`` and verifies
|
||||
the drain loops pick them up via ``FOR UPDATE SKIP LOCKED``.
|
||||
|
||||
Uses real PostgreSQL (via ``db_session`` / ``tenant_context`` fixtures).
|
||||
The per-file ``*_impl`` functions are mocked so no real file store or
|
||||
connector is needed — we only verify that recovery finds and dispatches
|
||||
the correct files.
|
||||
"""
|
||||
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.periodic_poller import recover_stuck_user_files
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import UserFile
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
_IMPL_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _create_user_file(
|
||||
db_session: Session,
|
||||
user_id: object,
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
needs_project_sync: bool = False,
|
||||
needs_persona_sync: bool = False,
|
||||
) -> UserFile:
|
||||
uf = UserFile(
|
||||
id=uuid4(),
|
||||
user_id=user_id,
|
||||
file_id=f"test_file_{uuid4().hex[:8]}",
|
||||
name=f"test_{uuid4().hex[:8]}.txt",
|
||||
file_type="text/plain",
|
||||
status=status,
|
||||
needs_project_sync=needs_project_sync,
|
||||
needs_persona_sync=needs_persona_sync,
|
||||
)
|
||||
db_session.add(uf)
|
||||
db_session.commit()
|
||||
db_session.refresh(uf)
|
||||
return uf
|
||||
|
||||
|
||||
@pytest.fixture()
|
||||
def _cleanup_user_files(db_session: Session) -> Generator[list[UserFile], None, None]:
|
||||
"""Track created UserFile rows and delete them after each test."""
|
||||
created: list[UserFile] = []
|
||||
yield created
|
||||
for uf in created:
|
||||
existing = db_session.get(UserFile, uf.id)
|
||||
if existing:
|
||||
db_session.delete(existing)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRecoverProcessingFiles:
|
||||
"""Files in PROCESSING status are re-processed via the processing drain loop."""
|
||||
|
||||
def test_processing_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_proc")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.PROCESSING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered but got: {called_ids}"
|
||||
|
||||
def test_completed_files_not_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_comp")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.COMPLETED)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) not in called_ids
|
||||
), f"COMPLETED file {uf.id} should not have been recovered"
|
||||
|
||||
|
||||
class TestRecoverDeletingFiles:
|
||||
"""Files in DELETING status are recovered via the delete drain loop."""
|
||||
|
||||
def test_deleting_files_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_del")
|
||||
uf = _create_user_file(db_session, user.id, status=UserFileStatus.DELETING)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.delete_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for deletion but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoverSyncFiles:
|
||||
"""Files needing project/persona sync are recovered via the sync drain loop."""
|
||||
|
||||
def test_needs_project_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_sync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_project_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for sync but got: {called_ids}"
|
||||
|
||||
def test_needs_persona_sync_recovered(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_psync")
|
||||
uf = _create_user_file(
|
||||
db_session,
|
||||
user.id,
|
||||
status=UserFileStatus.COMPLETED,
|
||||
needs_persona_sync=True,
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.project_sync_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = [call.kwargs["user_file_id"] for call in mock_impl.call_args_list]
|
||||
assert (
|
||||
str(uf.id) in called_ids
|
||||
), f"Expected file {uf.id} to be recovered for persona sync but got: {called_ids}"
|
||||
|
||||
|
||||
class TestRecoveryMultipleFiles:
|
||||
"""Recovery processes all stuck files in one pass, not just the first."""
|
||||
|
||||
def test_multiple_processing_files(
|
||||
self,
|
||||
db_session: Session,
|
||||
tenant_context: None, # noqa: ARG002
|
||||
_cleanup_user_files: list[UserFile],
|
||||
) -> None:
|
||||
user = create_test_user(db_session, "recovery_multi")
|
||||
files = []
|
||||
for _ in range(3):
|
||||
uf = _create_user_file(
|
||||
db_session, user.id, status=UserFileStatus.PROCESSING
|
||||
)
|
||||
_cleanup_user_files.append(uf)
|
||||
files.append(uf)
|
||||
|
||||
mock_impl = MagicMock()
|
||||
with patch(f"{_IMPL_MODULE}.process_user_file_impl", mock_impl):
|
||||
recover_stuck_user_files(TEST_TENANT_ID)
|
||||
|
||||
called_ids = {call.kwargs["user_file_id"] for call in mock_impl.call_args_list}
|
||||
expected_ids = {str(uf.id) for uf in files}
|
||||
assert expected_ids.issubset(called_ids), (
|
||||
f"Expected all {len(files)} files to be recovered. "
|
||||
f"Missing: {expected_ids - called_ids}"
|
||||
)
|
||||
@@ -36,6 +36,7 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
user_file_project_sync_lock_key,
|
||||
)
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -85,6 +86,12 @@ def _create_test_persona(
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a test assistant",
|
||||
task_prompt="Answer the question",
|
||||
tools=[],
|
||||
@@ -403,6 +410,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -431,6 +442,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -446,11 +461,16 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf_old.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
# Now update the persona to swap files
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -481,6 +501,10 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
user=user,
|
||||
name=f"persona-{uuid4().hex[:8]}",
|
||||
description="test",
|
||||
num_chunks=10.0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
@@ -495,10 +519,15 @@ class TestUpsertPersonaMarksSyncFlag:
|
||||
uf.needs_persona_sync = False
|
||||
db_session.commit()
|
||||
|
||||
assert persona.num_chunks is not None
|
||||
upsert_persona(
|
||||
user=user,
|
||||
name=persona.name,
|
||||
description=persona.description,
|
||||
num_chunks=persona.num_chunks,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_model_provider_override=None,
|
||||
llm_model_version_override=None,
|
||||
starter_messages=None,
|
||||
|
||||
@@ -18,6 +18,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserFile
|
||||
@@ -57,6 +58,12 @@ def _create_persona(db_session: Session, user: User) -> Persona:
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="test",
|
||||
task_prompt="test",
|
||||
tools=[],
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -54,6 +55,11 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
|
||||
persona = Persona(
|
||||
name=f"test_slack_persona_{unique_id}",
|
||||
description="Test persona for Slack federated search",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question based on the provided context.",
|
||||
)
|
||||
@@ -812,6 +818,11 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
|
||||
persona = Persona(
|
||||
name=f"test_eager_load_persona_{unique_id}",
|
||||
description="Test persona for eager loading test",
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
system_prompt="You are a helpful assistant.",
|
||||
task_prompt="Answer the user's question.",
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.enums import MCPTransport
|
||||
@@ -46,6 +47,12 @@ def _create_test_persona_with_mcp_tool(
|
||||
persona = Persona(
|
||||
name=f"Test MCP Persona {uuid4().hex[:8]}",
|
||||
description="Test persona with MCP tools",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -17,6 +17,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.chat.emitter import get_default_emitter
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import OAuthConfig
|
||||
from onyx.db.models import Persona
|
||||
@@ -56,6 +57,12 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
|
||||
persona = Persona(
|
||||
name=f"Test Persona {uuid4().hex[:8]}",
|
||||
description="Test persona",
|
||||
num_chunks=10.0,
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.NO_DECAY,
|
||||
system_prompt="You are a helpful assistant",
|
||||
task_prompt="Answer the user's question",
|
||||
tools=tools,
|
||||
|
||||
@@ -933,7 +933,6 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from fastapi import UploadFile
|
||||
from fastapi.background import BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from starlette.datastructures import Headers
|
||||
|
||||
@@ -1140,7 +1139,6 @@ def test_code_interpreter_receives_chat_files(
|
||||
# Upload a test CSV
|
||||
csv_content = b"name,age,city\nAlice,30,NYC\nBob,25,SF\n"
|
||||
result = upload_user_files(
|
||||
bg_tasks=BackgroundTasks(),
|
||||
files=[
|
||||
UploadFile(
|
||||
file=io.BytesIO(csv_content),
|
||||
|
||||
@@ -3,6 +3,7 @@ from uuid import uuid4
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.server.features.persona.models import FullPersonaSnapshot
|
||||
from onyx.server.features.persona.models import PersonaUpsertRequest
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
@@ -19,7 +20,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float = 5,
|
||||
llm_relevance_filter: bool = True,
|
||||
is_public: bool = True,
|
||||
llm_filter_extraction: bool = True,
|
||||
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -30,7 +35,6 @@ class PersonaManager:
|
||||
label_ids: list[int] | None = None,
|
||||
user_file_ids: list[str] | None = None,
|
||||
display_priority: int | None = None,
|
||||
featured: bool = False,
|
||||
) -> DATestPersona:
|
||||
name = name or f"test-persona-{uuid4()}"
|
||||
description = description or f"Description for {name}"
|
||||
@@ -43,7 +47,11 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
document_set_ids=document_set_ids or [],
|
||||
tool_ids=tool_ids or [],
|
||||
llm_model_provider_override=llm_model_provider_override,
|
||||
@@ -53,7 +61,6 @@ class PersonaManager:
|
||||
label_ids=label_ids or [],
|
||||
user_file_ids=user_file_ids or [],
|
||||
display_priority=display_priority,
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
response = requests.post(
|
||||
@@ -68,7 +75,11 @@ class PersonaManager:
|
||||
id=persona_data["id"],
|
||||
name=name,
|
||||
description=description,
|
||||
num_chunks=num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter,
|
||||
is_public=is_public,
|
||||
llm_filter_extraction=llm_filter_extraction,
|
||||
recency_bias=recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
@@ -79,7 +90,6 @@ class PersonaManager:
|
||||
users=users or [],
|
||||
groups=groups or [],
|
||||
label_ids=label_ids or [],
|
||||
featured=featured,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -90,7 +100,11 @@ class PersonaManager:
|
||||
description: str | None = None,
|
||||
system_prompt: str | None = None,
|
||||
task_prompt: str | None = None,
|
||||
num_chunks: float | None = None,
|
||||
llm_relevance_filter: bool | None = None,
|
||||
is_public: bool | None = None,
|
||||
llm_filter_extraction: bool | None = None,
|
||||
recency_bias: RecencyBiasSetting | None = None,
|
||||
datetime_aware: bool = False,
|
||||
document_set_ids: list[int] | None = None,
|
||||
tool_ids: list[int] | None = None,
|
||||
@@ -99,7 +113,6 @@ class PersonaManager:
|
||||
users: list[str] | None = None,
|
||||
groups: list[int] | None = None,
|
||||
label_ids: list[int] | None = None,
|
||||
featured: bool | None = None,
|
||||
) -> DATestPersona:
|
||||
system_prompt = system_prompt or f"System prompt for {persona.name}"
|
||||
task_prompt = task_prompt or f"Task prompt for {persona.name}"
|
||||
@@ -110,7 +123,13 @@ class PersonaManager:
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
num_chunks=num_chunks or persona.num_chunks,
|
||||
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
|
||||
is_public=persona.is_public if is_public is None else is_public,
|
||||
llm_filter_extraction=(
|
||||
llm_filter_extraction or persona.llm_filter_extraction
|
||||
),
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
document_set_ids=document_set_ids or persona.document_set_ids,
|
||||
tool_ids=tool_ids or persona.tool_ids,
|
||||
llm_model_provider_override=(
|
||||
@@ -122,7 +141,6 @@ class PersonaManager:
|
||||
users=[UUID(user) for user in (users or persona.users)],
|
||||
groups=groups or persona.groups,
|
||||
label_ids=label_ids or persona.label_ids,
|
||||
featured=featured if featured is not None else persona.featured,
|
||||
)
|
||||
|
||||
response = requests.patch(
|
||||
@@ -137,12 +155,16 @@ class PersonaManager:
|
||||
id=updated_persona_data["id"],
|
||||
name=updated_persona_data["name"],
|
||||
description=updated_persona_data["description"],
|
||||
num_chunks=updated_persona_data["num_chunks"],
|
||||
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
|
||||
is_public=updated_persona_data["is_public"],
|
||||
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
|
||||
recency_bias=recency_bias or persona.recency_bias,
|
||||
system_prompt=system_prompt,
|
||||
task_prompt=task_prompt,
|
||||
datetime_aware=datetime_aware,
|
||||
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
|
||||
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
|
||||
document_set_ids=updated_persona_data["document_sets"],
|
||||
tool_ids=updated_persona_data["tools"],
|
||||
llm_model_provider_override=updated_persona_data[
|
||||
"llm_model_provider_override"
|
||||
],
|
||||
@@ -151,8 +173,7 @@ class PersonaManager:
|
||||
],
|
||||
users=[user["email"] for user in updated_persona_data["users"]],
|
||||
groups=updated_persona_data["groups"],
|
||||
label_ids=[label["id"] for label in updated_persona_data["labels"]],
|
||||
featured=updated_persona_data["featured"],
|
||||
label_ids=updated_persona_data["labels"],
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -201,13 +222,32 @@ class PersonaManager:
|
||||
fetched_persona.description,
|
||||
)
|
||||
)
|
||||
if fetched_persona.num_chunks != persona.num_chunks:
|
||||
mismatches.append(
|
||||
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
|
||||
)
|
||||
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
|
||||
mismatches.append(
|
||||
(
|
||||
"llm_relevance_filter",
|
||||
persona.llm_relevance_filter,
|
||||
fetched_persona.llm_relevance_filter,
|
||||
)
|
||||
)
|
||||
if fetched_persona.is_public != persona.is_public:
|
||||
mismatches.append(
|
||||
("is_public", persona.is_public, fetched_persona.is_public)
|
||||
)
|
||||
if fetched_persona.featured != persona.featured:
|
||||
if (
|
||||
fetched_persona.llm_filter_extraction
|
||||
!= persona.llm_filter_extraction
|
||||
):
|
||||
mismatches.append(
|
||||
("featured", persona.featured, fetched_persona.featured)
|
||||
(
|
||||
"llm_filter_extraction",
|
||||
persona.llm_filter_extraction,
|
||||
fetched_persona.llm_filter_extraction,
|
||||
)
|
||||
)
|
||||
if (
|
||||
fetched_persona.llm_model_provider_override
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import Field
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.context.search.models import SavedSearchDoc
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -161,7 +162,11 @@ class DATestPersona(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
num_chunks: float
|
||||
llm_relevance_filter: bool
|
||||
is_public: bool
|
||||
llm_filter_extraction: bool
|
||||
recency_bias: RecencyBiasSetting
|
||||
document_set_ids: list[int]
|
||||
tool_ids: list[int]
|
||||
llm_model_provider_override: str | None
|
||||
@@ -169,7 +174,6 @@ class DATestPersona(BaseModel):
|
||||
users: list[str]
|
||||
groups: list[int]
|
||||
label_ids: list[int]
|
||||
featured: bool = False
|
||||
|
||||
# Embedded prompt fields (no longer separate prompt_ids)
|
||||
system_prompt: str | None = None
|
||||
|
||||
@@ -8,6 +8,7 @@ from collections.abc import Generator
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.discord_bot import bulk_create_channel_configs
|
||||
from onyx.db.discord_bot import create_discord_bot_config
|
||||
from onyx.db.discord_bot import create_guild_config
|
||||
@@ -35,8 +36,14 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
|
||||
id=persona_id,
|
||||
name=name,
|
||||
description="Test persona for Discord bot tests",
|
||||
num_chunks=5.0,
|
||||
chunks_above=1,
|
||||
chunks_below=1,
|
||||
llm_relevance_filter=False,
|
||||
llm_filter_extraction=False,
|
||||
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
|
||||
is_visible=True,
|
||||
featured=False,
|
||||
is_default_persona=False,
|
||||
deleted=False,
|
||||
builtin_persona=False,
|
||||
)
|
||||
|
||||
@@ -414,24 +414,6 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert finished_index_attempt.status == IndexingStatus.FAILED
|
||||
|
||||
# Pause the connector immediately to prevent check_for_indexing from
|
||||
# creating automatic retry attempts while we reset the mock server.
|
||||
# Without this, the INITIAL_INDEXING status causes immediate retries
|
||||
# that would consume (or fail against) the mock server before we can
|
||||
# set up the recovery behavior.
|
||||
CCPairManager.pause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
|
||||
# Collect all index attempt IDs created so far (the initial one plus
|
||||
# any automatic retries that may have started before the pause took effect).
|
||||
all_prior_attempt_ids: list[int] = []
|
||||
index_attempts_page = IndexAttemptManager.get_index_attempt_page(
|
||||
cc_pair_id=cc_pair.id,
|
||||
page=0,
|
||||
page_size=100,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
all_prior_attempt_ids = [ia.id for ia in index_attempts_page.items]
|
||||
|
||||
# Verify initial state: both docs should be indexed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
documents = DocumentManager.fetch_documents_for_cc_pair(
|
||||
@@ -483,14 +465,17 @@ def test_mock_connector_checkpoint_recovery(
|
||||
)
|
||||
assert response.status_code == 200
|
||||
|
||||
# Set the manual indexing trigger, then unpause to allow the recovery run.
|
||||
# After the failure, the connector is in repeated error state and paused.
|
||||
# Set the manual indexing trigger first (while paused), then unpause.
|
||||
# This ensures the trigger is set before CHECK_FOR_INDEXING runs, which will
|
||||
# prevent the connector from being re-paused when repeated error state is detected.
|
||||
CCPairManager.run_once(
|
||||
cc_pair, from_beginning=False, user_performing_action=admin_user
|
||||
)
|
||||
CCPairManager.unpause_cc_pair(cc_pair, user_performing_action=admin_user)
|
||||
recovery_index_attempt = IndexAttemptManager.wait_for_index_attempt_start(
|
||||
cc_pair_id=cc_pair.id,
|
||||
index_attempts_to_ignore=all_prior_attempt_ids,
|
||||
index_attempts_to_ignore=[initial_index_attempt.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
IndexAttemptManager.wait_for_index_attempt_completion(
|
||||
|
||||
@@ -130,8 +130,8 @@ def test_repeated_error_state_detection_and_recovery(
|
||||
# )
|
||||
break
|
||||
|
||||
if time.monotonic() - start_time > 90:
|
||||
assert False, "CC pair did not enter repeated error state within 90 seconds"
|
||||
if time.monotonic() - start_time > 30:
|
||||
assert False, "CC pair did not enter repeated error state within 30 seconds"
|
||||
|
||||
time.sleep(2)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
import requests
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import can_user_access_llm_provider
|
||||
from onyx.db.llm import fetch_user_group_ids
|
||||
@@ -77,6 +78,12 @@ def _create_persona(
|
||||
persona = Persona(
|
||||
name=name,
|
||||
description=f"{name} description",
|
||||
num_chunks=5,
|
||||
chunks_above=2,
|
||||
chunks_below=2,
|
||||
llm_relevance_filter=True,
|
||||
llm_filter_extraction=True,
|
||||
recency_bias=RecencyBiasSetting.AUTO,
|
||||
llm_model_provider_override=provider_name,
|
||||
llm_model_version_override="gpt-4o-mini",
|
||||
system_prompt="System prompt",
|
||||
|
||||
@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
result = db_session.execute(
|
||||
text(
|
||||
"""
|
||||
SELECT id, name, builtin_persona, featured, deleted
|
||||
SELECT id, name, builtin_persona, is_default_persona, deleted
|
||||
FROM persona
|
||||
WHERE builtin_persona = true
|
||||
ORDER BY id
|
||||
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
|
||||
assert default[0] == 0, "Default assistant should have ID 0"
|
||||
assert default[1] == "Assistant", "Should be named 'Assistant'"
|
||||
assert default[2] is True, "Should be builtin"
|
||||
assert default[3] is True, "Should be featured"
|
||||
assert default[3] is True, "Should be default"
|
||||
assert default[4] is False, "Should not be deleted"
|
||||
|
||||
# Check tools are properly associated
|
||||
|
||||
@@ -195,7 +195,11 @@ def _base_persona_body(**overrides: object) -> dict:
|
||||
"description": "test",
|
||||
"system_prompt": "test",
|
||||
"task_prompt": "",
|
||||
"num_chunks": 5,
|
||||
"is_public": True,
|
||||
"recency_bias": "auto",
|
||||
"llm_filter_extraction": False,
|
||||
"llm_relevance_filter": False,
|
||||
"datetime_aware": False,
|
||||
"document_set_ids": [],
|
||||
"tool_ids": [],
|
||||
|
||||
@@ -40,6 +40,7 @@ def test_persona_create_update_share_delete(
|
||||
expected_persona,
|
||||
name=f"updated-{expected_persona.name}",
|
||||
description=f"updated-{expected_persona.description}",
|
||||
num_chunks=expected_persona.num_chunks + 1,
|
||||
is_public=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,11 @@ def test_update_persona_with_null_label_ids_preserves_labels(
|
||||
task_prompt=persona.task_prompt or "",
|
||||
datetime_aware=persona.datetime_aware,
|
||||
document_set_ids=persona.document_set_ids,
|
||||
num_chunks=persona.num_chunks,
|
||||
is_public=persona.is_public,
|
||||
recency_bias=persona.recency_bias,
|
||||
llm_filter_extraction=persona.llm_filter_extraction,
|
||||
llm_relevance_filter=persona.llm_relevance_filter,
|
||||
llm_model_provider_override=persona.llm_model_provider_override,
|
||||
llm_model_version_override=persona.llm_model_version_override,
|
||||
tool_ids=persona.tool_ids,
|
||||
|
||||
@@ -31,8 +31,9 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None: # noqa
|
||||
"search, web browsing, and image generation"
|
||||
in unified_assistant.description.lower()
|
||||
)
|
||||
assert unified_assistant.featured is True
|
||||
assert unified_assistant.is_default_persona is True
|
||||
assert unified_assistant.is_visible is True
|
||||
assert unified_assistant.num_chunks == 25
|
||||
|
||||
# Verify tools
|
||||
tools = unified_assistant.tools
|
||||
|
||||
@@ -1,291 +0,0 @@
|
||||
"""Tests for the _impl functions' redis_locking parameter.
|
||||
|
||||
Verifies that:
|
||||
- redis_locking=True acquires/releases Redis locks and clears queued keys
|
||||
- redis_locking=False skips all Redis operations entirely
|
||||
- Both paths execute the same business logic (DB lookup, status check)
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
|
||||
|
||||
def _mock_session_returning_none() -> MagicMock:
|
||||
"""Return a mock session whose .get() returns None (file not found)."""
|
||||
session = MagicMock()
|
||||
session.get.return_value = None
|
||||
session.execute.return_value.scalar_one_or_none.return_value = None
|
||||
return session
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
user_file_id = str(uuid4())
|
||||
process_user_file_impl(
|
||||
user_file_id=user_file_id,
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once_with(tenant_id="test-tenant")
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_both_paths_call_db_get(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
"""Both redis_locking=True and False should call db_session.get(UserFile, ...)."""
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
uid = str(uuid4())
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=True)
|
||||
call_count_true = session.get.call_count
|
||||
|
||||
session.reset_mock()
|
||||
mock_get_session.reset_mock()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=False)
|
||||
call_count_false = session.get.call_count
|
||||
|
||||
assert call_count_true == call_count_false == 1
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncUserFileImpl:
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_acquires_and_releases_lock(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.owned.return_value = True
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_called_once()
|
||||
redis_client.delete.assert_called_once()
|
||||
lock.acquire.assert_called_once_with(blocking=False)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_true_skips_when_lock_held(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
redis_client = MagicMock()
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
redis_client.lock.return_value = lock
|
||||
mock_get_redis.return_value = redis_client
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=True,
|
||||
)
|
||||
|
||||
lock.acquire.assert_called_once()
|
||||
mock_get_session.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TASKS_MODULE}.get_redis_client")
|
||||
def test_redis_locking_false_skips_redis_entirely(
|
||||
self,
|
||||
mock_get_redis: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
session = _mock_session_returning_none()
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uuid4()),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_redis.assert_not_called()
|
||||
mock_get_session.assert_called_once()
|
||||
@@ -1,421 +0,0 @@
|
||||
"""Tests for no-vector-DB user file processing paths.
|
||||
|
||||
Verifies that when DISABLE_VECTOR_DB is True:
|
||||
- process_user_file_impl calls _process_user_file_without_vector_db (not indexing)
|
||||
- _process_user_file_without_vector_db extracts text, counts tokens, stores plaintext,
|
||||
sets status=COMPLETED and chunk_count=0
|
||||
- delete_user_file_impl skips vector DB chunk deletion
|
||||
- project_sync_user_file_impl skips vector DB metadata update
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
_process_user_file_without_vector_db,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
delete_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
process_user_file_impl,
|
||||
)
|
||||
from onyx.background.celery.tasks.user_file_processing.tasks import (
|
||||
project_sync_user_file_impl,
|
||||
)
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.enums import UserFileStatus
|
||||
|
||||
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
|
||||
LLM_FACTORY_MODULE = "onyx.llm.factory"
|
||||
|
||||
|
||||
def _make_documents(texts: list[str]) -> list[Document]:
|
||||
"""Build a list of Document objects with the given section texts."""
|
||||
return [
|
||||
Document(
|
||||
id=str(uuid4()),
|
||||
source=DocumentSource.USER_FILE,
|
||||
sections=[TextSection(text=t)],
|
||||
semantic_identifier=f"test-doc-{i}",
|
||||
metadata={},
|
||||
)
|
||||
for i, t in enumerate(texts)
|
||||
]
|
||||
|
||||
|
||||
def _make_user_file(
|
||||
*,
|
||||
status: UserFileStatus = UserFileStatus.PROCESSING,
|
||||
file_id: str = "test-file-id",
|
||||
name: str = "test.txt",
|
||||
) -> MagicMock:
|
||||
"""Return a MagicMock mimicking a UserFile ORM instance."""
|
||||
uf = MagicMock()
|
||||
uf.id = uuid4()
|
||||
uf.file_id = file_id
|
||||
uf.name = name
|
||||
uf.status = status
|
||||
uf.token_count = None
|
||||
uf.chunk_count = None
|
||||
uf.last_project_sync_at = None
|
||||
uf.projects = []
|
||||
uf.assistants = []
|
||||
uf.needs_project_sync = True
|
||||
uf.needs_persona_sync = True
|
||||
return uf
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# _process_user_file_without_vector_db — direct tests
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessUserFileWithoutVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_extracts_and_combines_text(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=[1, 2, 3, 4, 5])
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["hello world", "foo bar"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
stored_text = mock_store_plaintext.call_args.kwargs["plaintext_content"]
|
||||
assert "hello world" in stored_text
|
||||
assert "foo bar" in stored_text
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_computes_token_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_encode = MagicMock(return_value=list(range(42)))
|
||||
mock_get_encode.return_value = mock_encode
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["some text content"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count == 42
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_token_count_falls_back_to_none_on_error(
|
||||
self,
|
||||
mock_get_llm: MagicMock,
|
||||
mock_get_encode: MagicMock, # noqa: ARG002
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_llm.side_effect = RuntimeError("No LLM configured")
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.token_count is None
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_stores_plaintext(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock,
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["content to store"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
mock_store_plaintext.assert_called_once_with(
|
||||
user_file_id=uf.id,
|
||||
plaintext_content="content to store",
|
||||
)
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_sets_completed_status_and_zero_chunk_count(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file()
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.COMPLETED
|
||||
assert uf.chunk_count == 0
|
||||
assert uf.last_project_sync_at is not None
|
||||
db_session.add.assert_called_once_with(uf)
|
||||
db_session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
|
||||
@patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
|
||||
def test_preserves_deleting_status(
|
||||
self,
|
||||
mock_get_llm: MagicMock, # noqa: ARG002
|
||||
mock_get_encode: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
mock_get_encode.return_value = MagicMock(return_value=[1])
|
||||
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
docs = _make_documents(["text"])
|
||||
db_session = MagicMock()
|
||||
|
||||
_process_user_file_without_vector_db(uf, docs, db_session)
|
||||
|
||||
assert uf.status == UserFileStatus.DELETING
|
||||
assert uf.chunk_count == 0
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# process_user_file_impl — branching on DISABLE_VECTOR_DB
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProcessImplBranching:
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_without_vector_db_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_without_vdb.assert_called_once()
|
||||
mock_with_indexing.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
|
||||
@patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_calls_with_indexing_when_vector_db_enabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_with_indexing: MagicMock,
|
||||
mock_without_vdb: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
|
||||
|
||||
with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_with_indexing.assert_called_once()
|
||||
mock_without_vdb.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.run_indexing_pipeline")
|
||||
@patch(f"{TASKS_MODULE}.store_user_file_plaintext")
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_indexing_pipeline_not_called_when_disabled(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_store_plaintext: MagicMock, # noqa: ARG002
|
||||
mock_run_pipeline: MagicMock,
|
||||
) -> None:
|
||||
"""End-to-end: verify run_indexing_pipeline is never invoked."""
|
||||
uf = _make_user_file()
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
connector_mock = MagicMock()
|
||||
connector_mock.load_from_state.return_value = [_make_documents(["content"])]
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock),
|
||||
patch(f"{LLM_FACTORY_MODULE}.get_default_llm"),
|
||||
patch(
|
||||
f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func",
|
||||
return_value=MagicMock(return_value=[1, 2, 3]),
|
||||
),
|
||||
):
|
||||
process_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_run_pipeline.assert_not_called()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# delete_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeleteImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_deletion(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
mock_get_file_store.return_value = MagicMock()
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_default_file_store")
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_deletes_file_store_and_db_record(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_file_store: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.DELETING)
|
||||
session = MagicMock()
|
||||
session.get.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
delete_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert file_store.delete_file.call_count == 2
|
||||
session.delete.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# project_sync_user_file_impl — vector DB skip
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestProjectSyncImplNoVectorDb:
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_skips_vector_db_update(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
with (
|
||||
patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
|
||||
patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
|
||||
patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
|
||||
):
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
mock_get_indices.assert_not_called()
|
||||
mock_get_ss.assert_not_called()
|
||||
mock_vespa_pool.assert_not_called()
|
||||
|
||||
@patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
@patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
|
||||
def test_still_clears_sync_flags(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uf = _make_user_file(status=UserFileStatus.COMPLETED)
|
||||
session = MagicMock()
|
||||
session.execute.return_value.scalar_one_or_none.return_value = uf
|
||||
mock_get_session.return_value.__enter__.return_value = session
|
||||
|
||||
project_sync_user_file_impl(
|
||||
user_file_id=str(uf.id),
|
||||
tenant_id="test-tenant",
|
||||
redis_locking=False,
|
||||
)
|
||||
|
||||
assert uf.needs_project_sync is False
|
||||
assert uf.needs_persona_sync is False
|
||||
assert uf.last_project_sync_at is not None
|
||||
session.add.assert_called_once_with(uf)
|
||||
session.commit.assert_called_once()
|
||||
@@ -1,52 +0,0 @@
|
||||
"""Tests for startup validation in no-vector-DB mode.
|
||||
|
||||
Verifies that DISABLE_VECTOR_DB raises RuntimeError when combined with
|
||||
incompatible settings (MULTI_TENANT, ENABLE_CRAFT).
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestValidateNoVectorDbSettings:
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", False)
|
||||
def test_no_error_when_vector_db_enabled(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", False)
|
||||
def test_no_error_when_no_conflicts(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
def test_raises_on_multi_tenant(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", False)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_raises_on_enable_craft(self) -> None:
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="ENABLE_CRAFT"):
|
||||
validate_no_vector_db_settings()
|
||||
|
||||
@patch("onyx.main.DISABLE_VECTOR_DB", True)
|
||||
@patch("onyx.main.MULTI_TENANT", True)
|
||||
@patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
|
||||
def test_multi_tenant_checked_before_craft(self) -> None:
|
||||
"""MULTI_TENANT is checked first, so it should be the error raised."""
|
||||
from onyx.main import validate_no_vector_db_settings
|
||||
|
||||
with pytest.raises(RuntimeError, match="MULTI_TENANT"):
|
||||
validate_no_vector_db_settings()
|
||||
@@ -1,196 +0,0 @@
|
||||
"""Tests for tool construction when DISABLE_VECTOR_DB is True.
|
||||
|
||||
Verifies that:
|
||||
- SearchTool.is_available() returns False when vector DB is disabled
|
||||
- OpenURLTool.is_available() returns False when vector DB is disabled
|
||||
- The force-add SearchTool block is suppressed when DISABLE_VECTOR_DB
|
||||
- FileReaderTool.is_available() returns True when vector DB is disabled
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
|
||||
APP_CONFIGS_MODULE = "onyx.configs.app_configs"
|
||||
FILE_READER_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# SearchTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSearchToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch("onyx.db.connector.check_user_files_exist", return_value=True)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
|
||||
return_value=False,
|
||||
)
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled_and_files_exist(
|
||||
self,
|
||||
mock_connectors: MagicMock, # noqa: ARG002
|
||||
mock_federated: MagicMock, # noqa: ARG002
|
||||
mock_user_files: MagicMock, # noqa: ARG002
|
||||
) -> None:
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
assert SearchTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# OpenURLTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestOpenURLToolAvailability:
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_unavailable_when_vector_db_disabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is False
|
||||
|
||||
@patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_available_when_vector_db_enabled(self) -> None:
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
|
||||
assert OpenURLTool.is_available(MagicMock()) is True
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# FileReaderTool.is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileReaderToolAvailability:
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Force-add SearchTool suppression
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestForceAddSearchToolGuard:
|
||||
def test_force_add_block_checks_disable_vector_db(self) -> None:
|
||||
"""The force-add SearchTool block in construct_tools should include
|
||||
`not DISABLE_VECTOR_DB` so that forced search is also suppressed
|
||||
without a vector DB."""
|
||||
import inspect
|
||||
|
||||
from onyx.tools.tool_constructor import construct_tools
|
||||
|
||||
source = inspect.getsource(construct_tools)
|
||||
assert "DISABLE_VECTOR_DB" in source, (
|
||||
"construct_tools should reference DISABLE_VECTOR_DB "
|
||||
"to suppress force-adding SearchTool"
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Persona API — _validate_vector_db_knowledge
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestValidateVectorDbKnowledge:
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_set_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1]
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "document sets" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_hierarchy_node_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = [1]
|
||||
request.document_ids = []
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "hierarchy nodes" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_rejects_document_ids(self) -> None:
|
||||
from fastapi import HTTPException
|
||||
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = ["doc-abc"]
|
||||
|
||||
with __import__("pytest").raises(HTTPException) as exc_info:
|
||||
_validate_vector_db_knowledge(request)
|
||||
assert exc_info.value.status_code == 400
|
||||
assert "documents" in exc_info.value.detail
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
True,
|
||||
)
|
||||
def test_allows_user_files_only(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = []
|
||||
request.hierarchy_node_ids = []
|
||||
request.document_ids = []
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
|
||||
@patch(
|
||||
"onyx.server.features.persona.api.DISABLE_VECTOR_DB",
|
||||
False,
|
||||
)
|
||||
def test_allows_everything_when_vector_db_enabled(self) -> None:
|
||||
from onyx.server.features.persona.api import _validate_vector_db_knowledge
|
||||
|
||||
request = MagicMock()
|
||||
request.document_set_ids = [1, 2]
|
||||
request.hierarchy_node_ids = [3]
|
||||
request.document_ids = ["doc-x"]
|
||||
|
||||
_validate_vector_db_knowledge(request)
|
||||
@@ -1,237 +0,0 @@
|
||||
"""Tests for the FileReaderTool.
|
||||
|
||||
Verifies:
|
||||
- Tool definition schema is well-formed
|
||||
- File ID validation (allowlist, UUID format)
|
||||
- Character range extraction and clamping
|
||||
- Error handling for missing parameters and non-text files
|
||||
- is_available() reflects DISABLE_VECTOR_DB
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FILE_ID_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import MAX_NUM_CHARS
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import NUM_CHARS_FIELD
|
||||
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
|
||||
START_CHAR_FIELD,
|
||||
)
|
||||
|
||||
TOOL_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
|
||||
_PLACEMENT = Placement(turn_index=0)
|
||||
|
||||
|
||||
def _make_tool(
|
||||
user_file_ids: list | None = None,
|
||||
chat_file_ids: list | None = None,
|
||||
) -> FileReaderTool:
|
||||
emitter = MagicMock()
|
||||
return FileReaderTool(
|
||||
tool_id=99,
|
||||
emitter=emitter,
|
||||
user_file_ids=user_file_ids or [],
|
||||
chat_file_ids=chat_file_ids or [],
|
||||
)
|
||||
|
||||
|
||||
def _text_file(content: str, filename: str = "test.txt") -> InMemoryChatFile:
|
||||
return InMemoryChatFile(
|
||||
file_id="some-file-id",
|
||||
content=content.encode("utf-8"),
|
||||
file_type=ChatFileType.PLAIN_TEXT,
|
||||
filename=filename,
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tool metadata
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestToolMetadata:
|
||||
def test_tool_name(self) -> None:
|
||||
tool = _make_tool()
|
||||
assert tool.name == "read_file"
|
||||
|
||||
def test_tool_definition_schema(self) -> None:
|
||||
tool = _make_tool()
|
||||
defn = tool.tool_definition()
|
||||
assert defn["type"] == "function"
|
||||
func = defn["function"]
|
||||
assert func["name"] == "read_file"
|
||||
props = func["parameters"]["properties"]
|
||||
assert FILE_ID_FIELD in props
|
||||
assert START_CHAR_FIELD in props
|
||||
assert NUM_CHARS_FIELD in props
|
||||
assert func["parameters"]["required"] == [FILE_ID_FIELD]
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# File ID validation
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFileIdValidation:
|
||||
def test_rejects_invalid_uuid(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Invalid file_id"):
|
||||
tool._validate_file_id("not-a-uuid")
|
||||
|
||||
def test_rejects_file_not_in_allowlist(self) -> None:
|
||||
tool = _make_tool(user_file_ids=[uuid4()])
|
||||
other_id = uuid4()
|
||||
with pytest.raises(ToolCallException, match="not in available files"):
|
||||
tool._validate_file_id(str(other_id))
|
||||
|
||||
def test_accepts_user_file_id(self) -> None:
|
||||
uid = uuid4()
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
assert tool._validate_file_id(str(uid)) == uid
|
||||
|
||||
def test_accepts_chat_file_id(self) -> None:
|
||||
cid = uuid4()
|
||||
tool = _make_tool(chat_file_ids=[cid])
|
||||
assert tool._validate_file_id(str(cid)) == cid
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# run() — character range extraction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRun:
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_returns_full_content_by_default(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "Hello, world!"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
assert content in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_respects_start_char_and_num_chars(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "abcdefghijklmnop"
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), START_CHAR_FIELD: 4, NUM_CHARS_FIELD: 6},
|
||||
)
|
||||
assert "efghij" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_clamps_num_chars_to_max(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * (MAX_NUM_CHARS + 500)
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: MAX_NUM_CHARS + 9999},
|
||||
)
|
||||
assert f"Characters 0-{MAX_NUM_CHARS}" in resp.llm_facing_response
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_includes_continuation_hint(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
content = "x" * 100
|
||||
mock_load_user_file.return_value = _text_file(content)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
resp = tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: 10},
|
||||
)
|
||||
assert "use start_char=10 to continue reading" in resp.llm_facing_response
|
||||
|
||||
def test_raises_on_missing_file_id(self) -> None:
|
||||
tool = _make_tool()
|
||||
with pytest.raises(ToolCallException, match="Missing required"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
)
|
||||
|
||||
@patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
|
||||
@patch(f"{TOOL_MODULE}.load_user_file")
|
||||
def test_raises_on_non_text_file(
|
||||
self,
|
||||
mock_load_user_file: MagicMock,
|
||||
mock_get_session: MagicMock,
|
||||
) -> None:
|
||||
uid = uuid4()
|
||||
mock_load_user_file.return_value = InMemoryChatFile(
|
||||
file_id="img",
|
||||
content=b"\x89PNG",
|
||||
file_type=ChatFileType.IMAGE,
|
||||
filename="photo.png",
|
||||
)
|
||||
mock_get_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
tool = _make_tool(user_file_ids=[uid])
|
||||
with pytest.raises(ToolCallException, match="not a text file"):
|
||||
tool.run(
|
||||
placement=_PLACEMENT,
|
||||
override_kwargs=MagicMock(),
|
||||
**{FILE_ID_FIELD: str(uid)},
|
||||
)
|
||||
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# is_available()
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestIsAvailable:
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", True)
|
||||
def test_available_when_vector_db_disabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is True
|
||||
|
||||
@patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", False)
|
||||
def test_unavailable_when_vector_db_enabled(self) -> None:
|
||||
assert FileReaderTool.is_available(MagicMock()) is False
|
||||
@@ -40,8 +40,6 @@ const TRAY_MENU_OPEN_APP_ID: &str = "tray_open_app";
|
||||
const TRAY_MENU_OPEN_CHAT_ID: &str = "tray_open_chat";
|
||||
const TRAY_MENU_SHOW_IN_BAR_ID: &str = "tray_show_in_menu_bar";
|
||||
const TRAY_MENU_QUIT_ID: &str = "tray_quit";
|
||||
const MENU_SHOW_MENU_BAR_ID: &str = "show_menu_bar";
|
||||
const MENU_HIDE_DECORATIONS_ID: &str = "hide_window_decorations";
|
||||
const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
|
||||
(() => {
|
||||
if (window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__) {
|
||||
@@ -173,92 +171,25 @@ const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
|
||||
})();
|
||||
"##;
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
const MENU_KEY_HANDLER_SCRIPT: &str = r#"
|
||||
(() => {
|
||||
if (window.__ONYX_MENU_KEY_HANDLER__) return;
|
||||
window.__ONYX_MENU_KEY_HANDLER__ = true;
|
||||
|
||||
let altHeld = false;
|
||||
|
||||
function invoke(cmd) {
|
||||
const fn_ =
|
||||
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
|
||||
if (typeof fn_ === 'function') fn_(cmd);
|
||||
}
|
||||
|
||||
function releaseAltAndHideMenu() {
|
||||
if (!altHeld) {
|
||||
return;
|
||||
}
|
||||
altHeld = false;
|
||||
invoke('hide_menu_bar_temporary');
|
||||
}
|
||||
|
||||
document.addEventListener('keydown', (e) => {
|
||||
if (e.key === 'Alt') {
|
||||
if (!altHeld) {
|
||||
altHeld = true;
|
||||
invoke('show_menu_bar_temporarily');
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (e.altKey && e.key === 'F1') {
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
altHeld = false;
|
||||
invoke('toggle_menu_bar');
|
||||
return;
|
||||
}
|
||||
}, true);
|
||||
|
||||
document.addEventListener('keyup', (e) => {
|
||||
if (e.key === 'Alt' && altHeld) {
|
||||
releaseAltAndHideMenu();
|
||||
}
|
||||
}, true);
|
||||
|
||||
window.addEventListener('blur', () => {
|
||||
releaseAltAndHideMenu();
|
||||
});
|
||||
|
||||
document.addEventListener('visibilitychange', () => {
|
||||
if (document.hidden) {
|
||||
releaseAltAndHideMenu();
|
||||
}
|
||||
});
|
||||
})();
|
||||
"#;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AppConfig {
|
||||
/// The Onyx server URL (default: https://cloud.onyx.app)
|
||||
pub server_url: String,
|
||||
|
||||
/// Optional: Custom window title
|
||||
#[serde(default = "default_window_title")]
|
||||
pub window_title: String,
|
||||
|
||||
#[serde(default = "default_show_menu_bar")]
|
||||
pub show_menu_bar: bool,
|
||||
|
||||
#[serde(default)]
|
||||
pub hide_window_decorations: bool,
|
||||
}
|
||||
|
||||
fn default_window_title() -> String {
|
||||
"Onyx".to_string()
|
||||
}
|
||||
|
||||
fn default_show_menu_bar() -> bool {
|
||||
true
|
||||
}
|
||||
|
||||
impl Default for AppConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
server_url: DEFAULT_SERVER_URL.to_string(),
|
||||
window_title: default_window_title(),
|
||||
show_menu_bar: true,
|
||||
hide_window_decorations: false,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -316,7 +247,6 @@ struct ConfigState {
|
||||
config: RwLock<AppConfig>,
|
||||
config_initialized: RwLock<bool>,
|
||||
app_base_url: RwLock<Option<Url>>,
|
||||
menu_temporarily_visible: RwLock<bool>,
|
||||
}
|
||||
|
||||
fn focus_main_window(app: &AppHandle) {
|
||||
@@ -371,7 +301,6 @@ fn trigger_new_window(app: &AppHandle) {
|
||||
inject_titlebar(window.clone());
|
||||
}
|
||||
|
||||
apply_settings_to_window(&handle, &window);
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
});
|
||||
@@ -648,15 +577,18 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
|
||||
#[cfg(target_os = "linux")]
|
||||
let builder = builder.background_color(tauri::window::Color(0x1a, 0x1a, 0x2e, 0xff));
|
||||
|
||||
let window = builder.build().map_err(|e| e.to_string())?;
|
||||
|
||||
#[cfg(target_os = "macos")]
|
||||
{
|
||||
let window = builder.build().map_err(|e| e.to_string())?;
|
||||
// Apply vibrancy effect and inject titlebar
|
||||
let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None);
|
||||
inject_titlebar(window.clone());
|
||||
}
|
||||
|
||||
apply_settings_to_window(&app, &window);
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _window = builder.build().map_err(|e| e.to_string())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -692,142 +624,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
|
||||
window.start_dragging().map_err(|e| e.to_string())
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Window Settings
|
||||
// ============================================================================
|
||||
|
||||
fn find_check_menu_item(
|
||||
app: &AppHandle,
|
||||
id: &str,
|
||||
) -> Option<CheckMenuItem<tauri::Wry>> {
|
||||
let menu = app.menu()?;
|
||||
for item in menu.items().ok()? {
|
||||
if let Some(submenu) = item.as_submenu() {
|
||||
for sub_item in submenu.items().ok()? {
|
||||
if let Some(check) = sub_item.as_check_menuitem() {
|
||||
if check.id().as_ref() == id {
|
||||
return Some(check.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn apply_settings_to_window(app: &AppHandle, window: &tauri::WebviewWindow) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let config = state.config.read().unwrap();
|
||||
let temp_visible = *state.menu_temporarily_visible.read().unwrap();
|
||||
if !config.show_menu_bar && !temp_visible {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
if config.hide_window_decorations {
|
||||
let _ = window.set_decorations(false);
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_menu_bar_toggle(app: &AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let show = {
|
||||
let mut config = state.config.write().unwrap();
|
||||
config.show_menu_bar = !config.show_menu_bar;
|
||||
let _ = save_config(&config);
|
||||
config.show_menu_bar
|
||||
};
|
||||
|
||||
*state.menu_temporarily_visible.write().unwrap() = false;
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
if show {
|
||||
let _ = window.show_menu();
|
||||
} else {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_decorations_toggle(app: &AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let hide = {
|
||||
let mut config = state.config.write().unwrap();
|
||||
config.hide_window_decorations = !config.hide_window_decorations;
|
||||
let _ = save_config(&config);
|
||||
config.hide_window_decorations
|
||||
};
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.set_decorations(!hide);
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn toggle_menu_bar(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
handle_menu_bar_toggle(&app);
|
||||
|
||||
let state = app.state::<ConfigState>();
|
||||
let checked = state.config.read().unwrap().show_menu_bar;
|
||||
if let Some(check) = find_check_menu_item(&app, MENU_SHOW_MENU_BAR_ID) {
|
||||
let _ = check.set_checked(checked);
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn show_menu_bar_temporarily(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
if state.config.read().unwrap().show_menu_bar {
|
||||
return;
|
||||
}
|
||||
|
||||
let mut temp = state.menu_temporarily_visible.write().unwrap();
|
||||
if *temp {
|
||||
return;
|
||||
}
|
||||
*temp = true;
|
||||
drop(temp);
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.show_menu();
|
||||
}
|
||||
}
|
||||
|
||||
#[tauri::command]
|
||||
fn hide_menu_bar_temporary(app: AppHandle) {
|
||||
if cfg!(target_os = "macos") {
|
||||
return;
|
||||
}
|
||||
let state = app.state::<ConfigState>();
|
||||
let mut temp = state.menu_temporarily_visible.write().unwrap();
|
||||
if !*temp {
|
||||
return;
|
||||
}
|
||||
*temp = false;
|
||||
drop(temp);
|
||||
|
||||
if state.config.read().unwrap().show_menu_bar {
|
||||
return;
|
||||
}
|
||||
|
||||
for (_, window) in app.webview_windows() {
|
||||
let _ = window.hide_menu();
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Menu Setup
|
||||
// ============================================================================
|
||||
@@ -871,59 +667,6 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
|
||||
menu.prepend(&file_menu)?;
|
||||
}
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let config = app.state::<ConfigState>();
|
||||
let config_guard = config.config.read().unwrap();
|
||||
|
||||
let show_menu_bar_item = CheckMenuItem::with_id(
|
||||
app,
|
||||
MENU_SHOW_MENU_BAR_ID,
|
||||
"Show Menu Bar",
|
||||
true,
|
||||
config_guard.show_menu_bar,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
let hide_decorations_item = CheckMenuItem::with_id(
|
||||
app,
|
||||
MENU_HIDE_DECORATIONS_ID,
|
||||
"Hide Window Decorations",
|
||||
true,
|
||||
config_guard.hide_window_decorations,
|
||||
None::<&str>,
|
||||
)?;
|
||||
|
||||
drop(config_guard);
|
||||
|
||||
if let Some(window_menu) = menu
|
||||
.items()?
|
||||
.into_iter()
|
||||
.filter_map(|item| item.as_submenu().cloned())
|
||||
.find(|submenu| submenu.text().ok().as_deref() == Some("Window"))
|
||||
{
|
||||
window_menu.append(&show_menu_bar_item)?;
|
||||
window_menu.append(&hide_decorations_item)?;
|
||||
} else {
|
||||
let window_menu = SubmenuBuilder::new(app, "Window")
|
||||
.item(&show_menu_bar_item)
|
||||
.item(&hide_decorations_item)
|
||||
.build()?;
|
||||
|
||||
let items = menu.items()?;
|
||||
let help_idx = items
|
||||
.iter()
|
||||
.position(|item| {
|
||||
item.as_submenu()
|
||||
.and_then(|s| s.text().ok())
|
||||
.as_deref()
|
||||
== Some("Help")
|
||||
})
|
||||
.unwrap_or(items.len());
|
||||
menu.insert(&window_menu, help_idx)?;
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(help_menu) = menu
|
||||
.get(HELP_SUBMENU_ID)
|
||||
.and_then(|item| item.as_submenu().cloned())
|
||||
@@ -1058,7 +801,6 @@ fn main() {
|
||||
config: RwLock::new(config),
|
||||
config_initialized: RwLock::new(config_initialized),
|
||||
app_base_url: RwLock::new(None),
|
||||
menu_temporarily_visible: RwLock::new(false),
|
||||
})
|
||||
.invoke_handler(tauri::generate_handler![
|
||||
get_server_url,
|
||||
@@ -1074,18 +816,13 @@ fn main() {
|
||||
go_forward,
|
||||
new_window,
|
||||
reset_config,
|
||||
start_drag_window,
|
||||
toggle_menu_bar,
|
||||
show_menu_bar_temporarily,
|
||||
hide_menu_bar_temporary
|
||||
start_drag_window
|
||||
])
|
||||
.on_menu_event(|app, event| match event.id().as_ref() {
|
||||
"open_docs" => open_docs(),
|
||||
"new_chat" => trigger_new_chat(app),
|
||||
"new_window" => trigger_new_window(app),
|
||||
"open_settings" => open_settings(app),
|
||||
"show_menu_bar" => handle_menu_bar_toggle(app),
|
||||
"hide_window_decorations" => handle_decorations_toggle(app),
|
||||
_ => {}
|
||||
})
|
||||
.setup(move |app| {
|
||||
@@ -1118,8 +855,6 @@ fn main() {
|
||||
#[cfg(target_os = "macos")]
|
||||
inject_titlebar(window.clone());
|
||||
|
||||
apply_settings_to_window(&app_handle, &window);
|
||||
|
||||
let _ = window.set_focus();
|
||||
}
|
||||
|
||||
@@ -1128,27 +863,7 @@ fn main() {
|
||||
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
|
||||
inject_chat_link_intercept(webview);
|
||||
|
||||
#[cfg(not(target_os = "macos"))]
|
||||
{
|
||||
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
|
||||
|
||||
let app = webview.app_handle();
|
||||
let state = app.state::<ConfigState>();
|
||||
let config = state.config.read().unwrap();
|
||||
let temp_visible = *state.menu_temporarily_visible.read().unwrap();
|
||||
let label = webview.label().to_string();
|
||||
if !config.show_menu_bar && !temp_visible {
|
||||
if let Some(win) = app.get_webview_window(&label) {
|
||||
let _ = win.hide_menu();
|
||||
}
|
||||
}
|
||||
if config.hide_window_decorations {
|
||||
if let Some(win) = app.get_webview_window(&label) {
|
||||
let _ = win.set_decorations(false);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Re-inject titlebar after every navigation/page load (macOS only)
|
||||
#[cfg(target_os = "macos")]
|
||||
let _ = webview.eval(TITLEBAR_SCRIPT);
|
||||
})
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { SizeVariant } from "@opal/shared";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -89,6 +89,7 @@ function ContentLg({
|
||||
}: ContentLgProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_LG_PRESETS[sizePreset];
|
||||
|
||||
@@ -130,6 +131,7 @@ function ContentLg({
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-lg-input",
|
||||
config.titleFont,
|
||||
|
||||
@@ -5,7 +5,7 @@ import type { SizeVariant } from "@opal/shared";
|
||||
import SvgEdit from "@opal/icons/edit";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useState } from "react";
|
||||
import { useRef, useState } from "react";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -109,6 +109,7 @@ function ContentXl({
|
||||
}: ContentXlProps) {
|
||||
const [editing, setEditing] = useState(false);
|
||||
const [editValue, setEditValue] = useState(title);
|
||||
const inputRef = useRef<HTMLInputElement>(null);
|
||||
|
||||
const config = CONTENT_XL_PRESETS[sizePreset];
|
||||
|
||||
@@ -188,6 +189,7 @@ function ContentXl({
|
||||
{editValue || "\u00A0"}
|
||||
</span>
|
||||
<input
|
||||
ref={inputRef}
|
||||
className={cn(
|
||||
"opal-content-xl-input",
|
||||
config.titleFont,
|
||||
|
||||
@@ -148,8 +148,7 @@ function Content(props: ContentProps) {
|
||||
}
|
||||
}
|
||||
|
||||
// ContentMd: main-content/main-ui/secondary with section/heading variant
|
||||
// (variant defaults to "heading" when omitted on MdContentProps, so both arms are needed)
|
||||
// ContentMd: main-content/main-ui/secondary with section variant
|
||||
else if (variant === "section" || variant === "heading") {
|
||||
layout = (
|
||||
<ContentMd
|
||||
|
||||
@@ -8,7 +8,7 @@ Layout primitives for composing icon + title + description rows. These component
|
||||
|
||||
| Component | Description | Docs |
|
||||
|---|---|---|
|
||||
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`ContentXl`, `ContentLg`, `ContentMd`, or `ContentSm`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
|
||||
| [`Content`](./Content/README.md) | Icon + title + description row. Routes to an internal layout (`ContentLg`, `ContentMd`, or `ContentSm`) based on `sizePreset` and `variant`. | [Content README](./Content/README.md) |
|
||||
| [`ContentAction`](./ContentAction/README.md) | Wraps `Content` in a flex-row with an optional `rightChildren` slot for action buttons. Adds padding alignment via the shared `SizeVariant` scale. | [ContentAction README](./ContentAction/README.md) |
|
||||
|
||||
## Quick Start
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
// - Interactive.Container (height + min-width + padding)
|
||||
// - Button (icon sizing)
|
||||
// - ContentAction (padding only)
|
||||
// - Content (ContentXl / ContentLg / ContentMd) (edit-button size)
|
||||
// - Content (ContentLg / ContentMd) (edit-button size)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
|
||||
@@ -11,7 +11,7 @@ import { DraggableTable } from "@/components/table/DraggableTable";
|
||||
import {
|
||||
deletePersona,
|
||||
personaComparator,
|
||||
togglePersonaFeatured,
|
||||
togglePersonaDefault,
|
||||
togglePersonaVisibility,
|
||||
} from "./lib";
|
||||
import { FiEdit2 } from "react-icons/fi";
|
||||
@@ -27,8 +27,8 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) {
|
||||
return <Text as="p">Built-In</Text>;
|
||||
}
|
||||
|
||||
if (persona.featured) {
|
||||
return <Text as="p">Featured</Text>;
|
||||
if (persona.is_default_persona) {
|
||||
return <Text as="p">Default</Text>;
|
||||
}
|
||||
|
||||
if (persona.is_public) {
|
||||
@@ -152,9 +152,9 @@ export function PersonasTable({
|
||||
|
||||
const handleToggleDefault = async () => {
|
||||
if (personaToToggleDefault) {
|
||||
const response = await togglePersonaFeatured(
|
||||
const response = await togglePersonaDefault(
|
||||
personaToToggleDefault.id,
|
||||
personaToToggleDefault.featured
|
||||
personaToToggleDefault.is_default_persona
|
||||
);
|
||||
if (response.ok) {
|
||||
refreshPersonas();
|
||||
@@ -180,7 +180,7 @@ export function PersonasTable({
|
||||
{defaultModalOpen &&
|
||||
personaToToggleDefault &&
|
||||
(() => {
|
||||
const isDefault = personaToToggleDefault.featured;
|
||||
const isDefault = personaToToggleDefault.is_default_persona;
|
||||
|
||||
const title = isDefault
|
||||
? "Remove Featured Agent"
|
||||
@@ -252,7 +252,7 @@ export function PersonasTable({
|
||||
</p>,
|
||||
<PersonaTypeDisplay key={persona.id} persona={persona} />,
|
||||
<div
|
||||
key="featured"
|
||||
key="is_default_persona"
|
||||
onClick={() => {
|
||||
openDefaultModal(persona);
|
||||
}}
|
||||
@@ -261,13 +261,13 @@ export function PersonasTable({
|
||||
`}
|
||||
>
|
||||
<div className="my-auto flex-none w-22">
|
||||
{!persona.featured ? (
|
||||
{!persona.is_default_persona ? (
|
||||
<div className="text-error">Not Featured</div>
|
||||
) : (
|
||||
"Featured"
|
||||
)}
|
||||
</div>
|
||||
<Checkbox checked={persona.featured} />
|
||||
<Checkbox checked={persona.is_default_persona} />
|
||||
</div>,
|
||||
<div
|
||||
key="is_visible"
|
||||
|
||||
@@ -53,7 +53,7 @@ export interface MinimalPersonaSnapshot {
|
||||
is_public: boolean;
|
||||
is_visible: boolean;
|
||||
display_priority: number | null;
|
||||
featured: boolean;
|
||||
is_default_persona: boolean;
|
||||
builtin_persona: boolean;
|
||||
|
||||
labels?: PersonaLabel[];
|
||||
@@ -64,6 +64,7 @@ export interface Persona extends MinimalPersonaSnapshot {
|
||||
user_file_ids: string[];
|
||||
users: MinimalUserSnapshot[];
|
||||
groups: number[];
|
||||
num_chunks?: number;
|
||||
// Hierarchy nodes (folders, spaces, channels) attached for scoped search
|
||||
hierarchy_nodes?: HierarchyNodeSnapshot[];
|
||||
// Individual documents attached for scoped search
|
||||
@@ -78,6 +79,8 @@ export interface Persona extends MinimalPersonaSnapshot {
|
||||
|
||||
export interface FullPersona extends Persona {
|
||||
search_start_date: string | null;
|
||||
llm_relevance_filter?: boolean;
|
||||
llm_filter_extraction?: boolean;
|
||||
}
|
||||
|
||||
export interface PersonaLabel {
|
||||
|
||||
@@ -11,7 +11,11 @@ interface PersonaUpsertRequest {
|
||||
task_prompt: string;
|
||||
datetime_aware: boolean;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
is_public: boolean;
|
||||
recency_bias: string;
|
||||
llm_filter_extraction: boolean;
|
||||
llm_relevance_filter: boolean | null;
|
||||
llm_model_provider_override: string | null;
|
||||
llm_model_version_override: string | null;
|
||||
starter_messages: StarterMessage[] | null;
|
||||
@@ -22,7 +26,7 @@ interface PersonaUpsertRequest {
|
||||
uploaded_image_id: string | null;
|
||||
icon_name: string | null;
|
||||
search_start_date: Date | null;
|
||||
featured: boolean;
|
||||
is_default_persona: boolean;
|
||||
display_priority: number | null;
|
||||
label_ids: number[] | null;
|
||||
user_file_ids: string[] | null;
|
||||
@@ -41,7 +45,9 @@ export interface PersonaUpsertParameters {
|
||||
task_prompt: string;
|
||||
datetime_aware: boolean;
|
||||
document_set_ids: number[];
|
||||
num_chunks: number | null;
|
||||
is_public: boolean;
|
||||
llm_relevance_filter: boolean | null;
|
||||
llm_model_provider_override: string | null;
|
||||
llm_model_version_override: string | null;
|
||||
starter_messages: StarterMessage[] | null;
|
||||
@@ -52,7 +58,7 @@ export interface PersonaUpsertParameters {
|
||||
search_start_date: Date | null;
|
||||
uploaded_image_id: string | null;
|
||||
icon_name: string | null;
|
||||
featured: boolean;
|
||||
is_default_persona: boolean;
|
||||
label_ids: number[] | null;
|
||||
user_file_ids: string[];
|
||||
// Hierarchy nodes (folders, spaces, channels) for scoped search
|
||||
@@ -67,6 +73,7 @@ function buildPersonaUpsertRequest({
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
document_set_ids,
|
||||
num_chunks,
|
||||
is_public,
|
||||
groups,
|
||||
datetime_aware,
|
||||
@@ -79,7 +86,8 @@ function buildPersonaUpsertRequest({
|
||||
document_ids,
|
||||
icon_name,
|
||||
uploaded_image_id,
|
||||
featured,
|
||||
is_default_persona,
|
||||
llm_relevance_filter,
|
||||
llm_model_provider_override,
|
||||
llm_model_version_override,
|
||||
starter_messages,
|
||||
@@ -92,6 +100,7 @@ function buildPersonaUpsertRequest({
|
||||
system_prompt,
|
||||
task_prompt,
|
||||
document_set_ids,
|
||||
num_chunks,
|
||||
is_public,
|
||||
uploaded_image_id,
|
||||
icon_name,
|
||||
@@ -101,7 +110,10 @@ function buildPersonaUpsertRequest({
|
||||
remove_image,
|
||||
search_start_date,
|
||||
datetime_aware,
|
||||
featured: featured ?? false,
|
||||
is_default_persona: is_default_persona ?? false,
|
||||
recency_bias: "base_decay",
|
||||
llm_filter_extraction: false,
|
||||
llm_relevance_filter: llm_relevance_filter ?? null,
|
||||
llm_model_provider_override: llm_model_provider_override ?? null,
|
||||
llm_model_version_override: llm_model_version_override ?? null,
|
||||
starter_messages: starter_messages ?? null,
|
||||
@@ -214,17 +226,17 @@ export function personaComparator(
|
||||
return closerToZeroNegativesFirstComparator(a.id, b.id);
|
||||
}
|
||||
|
||||
export async function togglePersonaFeatured(
|
||||
export async function togglePersonaDefault(
|
||||
personaId: number,
|
||||
featured: boolean
|
||||
isDefault: boolean
|
||||
) {
|
||||
const response = await fetch(`/api/admin/persona/${personaId}/featured`, {
|
||||
const response = await fetch(`/api/admin/persona/${personaId}/default`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({
|
||||
featured: !featured,
|
||||
is_default_persona: !isDefault,
|
||||
}),
|
||||
credentials: "include",
|
||||
});
|
||||
|
||||
@@ -74,10 +74,10 @@ export default function ProjectChatSessionList() {
|
||||
<div className="flex gap-3 min-w-0 w-full">
|
||||
<div className="flex h-full w-fit pt-1 pl-1">
|
||||
{(() => {
|
||||
const personaIdToFeatured =
|
||||
currentProjectDetails?.persona_id_to_featured || {};
|
||||
const isFeatured = personaIdToFeatured[chat.persona_id];
|
||||
if (isFeatured === false) {
|
||||
const personaIdToDefault =
|
||||
currentProjectDetails?.persona_id_to_is_default || {};
|
||||
const isDefault = personaIdToDefault[chat.persona_id];
|
||||
if (isDefault === false) {
|
||||
const assistant = assistants.find(
|
||||
(a) => a.id === chat.persona_id
|
||||
);
|
||||
|
||||
@@ -59,7 +59,7 @@ export enum UserFileStatus {
|
||||
export type ProjectDetails = {
|
||||
project: Project;
|
||||
files?: ProjectFile[];
|
||||
persona_id_to_featured?: Record<number, boolean>;
|
||||
persona_id_to_is_default?: Record<number, boolean>;
|
||||
};
|
||||
|
||||
export async function fetchProjects(): Promise<Project[]> {
|
||||
|
||||
@@ -20,7 +20,7 @@ export function constructMiniFiedPersona(name: string, id: number): Persona {
|
||||
owner: null,
|
||||
starter_messages: null,
|
||||
builtin_persona: false,
|
||||
featured: false,
|
||||
is_default_persona: false,
|
||||
users: [],
|
||||
groups: [],
|
||||
user_file_ids: [],
|
||||
|
||||
@@ -52,7 +52,7 @@ export function ChatSessionMorePopup({
|
||||
}: ChatSessionMorePopupProps) {
|
||||
const [popoverOpen, setPopoverOpen] = useState(false);
|
||||
const [isDeleteModalOpen, setIsDeleteModalOpen] = useState(false);
|
||||
const { refreshChatSessions, removeSession } = useChatSessions();
|
||||
const { refreshChatSessions } = useChatSessions();
|
||||
const { fetchProjects, projects } = useProjectsContext();
|
||||
|
||||
const [pendingMoveProjectId, setPendingMoveProjectId] = useState<
|
||||
@@ -79,20 +79,13 @@ export function ChatSessionMorePopup({
|
||||
async (e: React.MouseEvent<HTMLButtonElement>) => {
|
||||
e.stopPropagation();
|
||||
await deleteChatSession(chatSession.id);
|
||||
removeSession(chatSession.id);
|
||||
await refreshChatSessions();
|
||||
await fetchProjects();
|
||||
setIsDeleteModalOpen(false);
|
||||
setPopoverOpen(false);
|
||||
afterDelete?.();
|
||||
},
|
||||
[
|
||||
chatSession,
|
||||
refreshChatSessions,
|
||||
removeSession,
|
||||
fetchProjects,
|
||||
afterDelete,
|
||||
]
|
||||
[chatSession, refreshChatSessions, fetchProjects, afterDelete]
|
||||
);
|
||||
|
||||
const performMove = useCallback(
|
||||
|
||||
@@ -109,11 +109,13 @@ export function usePinnedAgents() {
|
||||
const serverPinnedAgents = useMemo(() => {
|
||||
if (agents.length === 0) return [];
|
||||
|
||||
// If pinned_assistants is null/undefined (never set), show featured personas
|
||||
// If pinned_assistants is null/undefined (never set), show default personas
|
||||
// If it's an empty array (user explicitly unpinned all), show nothing
|
||||
const pinnedIds = user?.preferences.pinned_assistants;
|
||||
if (pinnedIds === null || pinnedIds === undefined) {
|
||||
return agents.filter((agent) => agent.featured && agent.id !== 0);
|
||||
return agents.filter(
|
||||
(agent) => agent.is_default_persona && agent.id !== 0
|
||||
);
|
||||
}
|
||||
|
||||
return pinnedIds
|
||||
|
||||
@@ -66,15 +66,15 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
* };
|
||||
* ```
|
||||
*/
|
||||
export default function useCCPairs(enabled: boolean = true) {
|
||||
export default function useCCPairs() {
|
||||
const { data, error, isLoading, mutate } = useSWR<CCPairBasicInfo[]>(
|
||||
enabled ? "/api/manage/connector-status" : null,
|
||||
"/api/manage/connector-status",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
return {
|
||||
ccPairs: data ?? [],
|
||||
isLoading: enabled && isLoading,
|
||||
isLoading,
|
||||
error,
|
||||
refetch: mutate,
|
||||
};
|
||||
|
||||
@@ -1,13 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
useCallback,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useState,
|
||||
useSyncExternalStore,
|
||||
} from "react";
|
||||
import useSWRInfinite from "swr/infinite";
|
||||
import { useCallback, useEffect, useMemo, useSyncExternalStore } from "react";
|
||||
import useSWR, { KeyedMutator } from "swr";
|
||||
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
@@ -15,12 +9,8 @@ import useAppFocus from "./useAppFocus";
|
||||
import { useAgents } from "./useAgents";
|
||||
import { DEFAULT_ASSISTANT_ID } from "@/lib/constants";
|
||||
|
||||
const PAGE_SIZE = 50;
|
||||
const MIN_LOADING_DURATION_MS = 500;
|
||||
|
||||
interface ChatSessionsResponse {
|
||||
sessions: ChatSession[];
|
||||
has_more: boolean;
|
||||
}
|
||||
|
||||
export interface PendingChatSessionParams {
|
||||
@@ -36,24 +26,17 @@ interface UseChatSessionsOutput {
|
||||
agentForCurrentChatSession: MinimalPersonaSnapshot | null;
|
||||
isLoading: boolean;
|
||||
error: any;
|
||||
refreshChatSessions: () => Promise<ChatSessionsResponse[] | undefined>;
|
||||
refreshChatSessions: KeyedMutator<ChatSessionsResponse>;
|
||||
addPendingChatSession: (params: PendingChatSessionParams) => void;
|
||||
removeSession: (sessionId: string) => void;
|
||||
hasMore: boolean;
|
||||
isLoadingMore: boolean;
|
||||
loadMore: () => void;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Shared module-level store for pending chat sessions
|
||||
// ---------------------------------------------------------------------------
|
||||
// Pending sessions are optimistic new sessions shown in the sidebar before
|
||||
// the server returns them. This must be module-level so all hook instances
|
||||
// (sidebar, ChatButton, etc.) share the same state.
|
||||
|
||||
// Module-level store for pending chat sessions
|
||||
// This persists across SWR revalidations and component re-renders
|
||||
// Pending sessions are shown in the sidebar until the server returns them
|
||||
const pendingSessionsStore = {
|
||||
sessions: new Map<string, ChatSession>(),
|
||||
listeners: new Set<() => void>(),
|
||||
// Cached snapshot to avoid creating new array references on every call
|
||||
cachedSnapshot: [] as ChatSession[],
|
||||
|
||||
add(session: ChatSession) {
|
||||
@@ -91,7 +74,7 @@ const pendingSessionsStore = {
|
||||
},
|
||||
};
|
||||
|
||||
// Stable empty array for SSR
|
||||
// Stable empty array for SSR - must be defined outside component to avoid infinite loop
|
||||
const EMPTY_SESSIONS: ChatSession[] = [];
|
||||
|
||||
function usePendingSessions(): ChatSession[] {
|
||||
@@ -102,10 +85,6 @@ function usePendingSessions(): ChatSession[] {
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helper hooks
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function useFindAgentForCurrentChatSession(
|
||||
currentChatSession: ChatSession | null
|
||||
): MinimalPersonaSnapshot | null {
|
||||
@@ -132,102 +111,35 @@ function useFindAgentForCurrentChatSession(
|
||||
return agents.find((agent) => agent.id === agentIdToFind) ?? null;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main hook
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function useChatSessions(): UseChatSessionsOutput {
|
||||
const getKey = (
|
||||
pageIndex: number,
|
||||
previousPageData: ChatSessionsResponse | null
|
||||
): string | null => {
|
||||
// No more pages
|
||||
if (previousPageData && !previousPageData.has_more) return null;
|
||||
|
||||
// First page — no cursor
|
||||
if (pageIndex === 0) {
|
||||
return `/api/chat/get-user-chat-sessions?page_size=${PAGE_SIZE}`;
|
||||
}
|
||||
|
||||
// Subsequent pages — cursor from the last session of the previous page
|
||||
const lastSession =
|
||||
previousPageData!.sessions[previousPageData!.sessions.length - 1];
|
||||
if (!lastSession) return null;
|
||||
|
||||
const params = new URLSearchParams({
|
||||
page_size: PAGE_SIZE.toString(),
|
||||
before: lastSession.time_updated,
|
||||
});
|
||||
return `/api/chat/get-user-chat-sessions?${params.toString()}`;
|
||||
};
|
||||
|
||||
const { data, error, setSize, mutate } = useSWRInfinite<ChatSessionsResponse>(
|
||||
getKey,
|
||||
const { data, error, mutate } = useSWR<ChatSessionsResponse>(
|
||||
"/api/chat/get-user-chat-sessions",
|
||||
errorHandlingFetcher,
|
||||
{
|
||||
revalidateOnFocus: false,
|
||||
revalidateFirstPage: true,
|
||||
revalidateAll: false,
|
||||
dedupingInterval: 30000,
|
||||
}
|
||||
);
|
||||
|
||||
const appFocus = useAppFocus();
|
||||
const pendingSessions = usePendingSessions();
|
||||
|
||||
// Flatten all pages into a single session list
|
||||
const allFetchedSessions = useMemo(
|
||||
() => (data ? data.flatMap((page) => page.sessions) : []),
|
||||
[data]
|
||||
);
|
||||
|
||||
// hasMore: check the last loaded page
|
||||
const hasMore = useMemo(() => {
|
||||
if (!data || data.length === 0) return false;
|
||||
const lastPage = data[data.length - 1];
|
||||
return lastPage ? lastPage.has_more : false;
|
||||
}, [data]);
|
||||
|
||||
const [isLoadingMore, setIsLoadingMore] = useState(false);
|
||||
|
||||
const loadMore = useCallback(async () => {
|
||||
if (isLoadingMore || !hasMore) return;
|
||||
|
||||
setIsLoadingMore(true);
|
||||
const loadStart = Date.now();
|
||||
|
||||
try {
|
||||
await setSize((s) => s + 1);
|
||||
|
||||
// Enforce minimum loading duration to avoid skeleton flash
|
||||
const elapsed = Date.now() - loadStart;
|
||||
if (elapsed < MIN_LOADING_DURATION_MS) {
|
||||
await new Promise((r) =>
|
||||
setTimeout(r, MIN_LOADING_DURATION_MS - elapsed)
|
||||
);
|
||||
}
|
||||
} catch (err) {
|
||||
console.error("Failed to load more chat sessions:", err);
|
||||
} finally {
|
||||
setIsLoadingMore(false);
|
||||
}
|
||||
}, [isLoadingMore, hasMore, setSize]);
|
||||
const fetchedSessions = data?.sessions ?? [];
|
||||
|
||||
// Clean up pending sessions that now appear in fetched data
|
||||
// (they now have messages and the server returns them)
|
||||
useEffect(() => {
|
||||
const fetchedIds = new Set(allFetchedSessions.map((s) => s.id));
|
||||
const fetchedIds = new Set(fetchedSessions.map((s) => s.id));
|
||||
pendingSessions.forEach((pending) => {
|
||||
if (fetchedIds.has(pending.id)) {
|
||||
pendingSessionsStore.remove(pending.id);
|
||||
}
|
||||
});
|
||||
}, [allFetchedSessions, pendingSessions]);
|
||||
}, [fetchedSessions, pendingSessions]);
|
||||
|
||||
// Merge fetched sessions with pending sessions.
|
||||
// This ensures pending sessions persist across SWR revalidations.
|
||||
// Merge fetched sessions with pending sessions
|
||||
// This ensures pending sessions persist across SWR revalidations
|
||||
const chatSessions = useMemo(() => {
|
||||
const fetchedIds = new Set(allFetchedSessions.map((s) => s.id));
|
||||
const fetchedIds = new Set(fetchedSessions.map((s) => s.id));
|
||||
|
||||
// Get pending sessions that are not yet in fetched data
|
||||
const remainingPending = pendingSessions.filter(
|
||||
@@ -235,8 +147,8 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
);
|
||||
|
||||
// Pending sessions go first (most recent), then fetched sessions
|
||||
return [...remainingPending, ...allFetchedSessions];
|
||||
}, [allFetchedSessions, pendingSessions]);
|
||||
return [...remainingPending, ...fetchedSessions];
|
||||
}, [fetchedSessions, pendingSessions]);
|
||||
|
||||
const currentChatSessionId = appFocus.isChat() ? appFocus.getId() : null;
|
||||
const currentChatSession =
|
||||
@@ -247,18 +159,27 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
const agentForCurrentChatSession =
|
||||
useFindAgentForCurrentChatSession(currentChatSession);
|
||||
|
||||
// Add a pending chat session that will persist across SWR revalidations.
|
||||
// The session will be automatically removed once it appears in the server response.
|
||||
// Add a pending chat session that will persist across SWR revalidations
|
||||
// The session will be automatically removed once it appears in the server response
|
||||
const addPendingChatSession = useCallback(
|
||||
({ chatSessionId, personaId, projectId }: PendingChatSessionParams) => {
|
||||
// Don't add sessions that belong to a project
|
||||
if (projectId != null) return;
|
||||
if (projectId != null) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Don't add if already in pending store (duplicates are also filtered during merge)
|
||||
if (pendingSessionsStore.has(chatSessionId)) return;
|
||||
if (pendingSessionsStore.has(chatSessionId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
// Note: This check uses stale fetchedSessions due to empty deps, but is defensive
|
||||
if (fetchedSessions.some((s) => s.id === chatSessionId)) {
|
||||
return;
|
||||
}
|
||||
|
||||
const now = new Date().toISOString();
|
||||
pendingSessionsStore.add({
|
||||
const pendingSession: ChatSession = {
|
||||
id: chatSessionId,
|
||||
name: "", // Empty name will display as "New Chat" via UNNAMED_CHAT constant
|
||||
persona_id: personaId,
|
||||
@@ -268,29 +189,13 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
project_id: projectId ?? null,
|
||||
current_alternate_model: "",
|
||||
current_temperature_override: null,
|
||||
});
|
||||
};
|
||||
|
||||
pendingSessionsStore.add(pendingSession);
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const removeSession = useCallback(
|
||||
(sessionId: string) => {
|
||||
pendingSessionsStore.remove(sessionId);
|
||||
// Optimistically remove from all loaded pages
|
||||
mutate(
|
||||
(pages) =>
|
||||
pages?.map((page) => ({
|
||||
...page,
|
||||
sessions: page.sessions.filter((s) => s.id !== sessionId),
|
||||
})),
|
||||
{ revalidate: false }
|
||||
);
|
||||
},
|
||||
[mutate]
|
||||
);
|
||||
|
||||
const refreshChatSessions = useCallback(() => mutate(), [mutate]);
|
||||
|
||||
return {
|
||||
chatSessions,
|
||||
currentChatSessionId,
|
||||
@@ -298,11 +203,7 @@ export default function useChatSessions(): UseChatSessionsOutput {
|
||||
agentForCurrentChatSession,
|
||||
isLoading: !error && !data,
|
||||
error,
|
||||
refreshChatSessions,
|
||||
refreshChatSessions: mutate,
|
||||
addPendingChatSession,
|
||||
removeSession,
|
||||
hasMore,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -4,14 +4,8 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
const HEALTH_ENDPOINT = "/api/admin/code-interpreter/health";
|
||||
const STATUS_ENDPOINT = "/api/admin/code-interpreter";
|
||||
|
||||
export type CodeInterpreterHealthStatus =
|
||||
| "healthy"
|
||||
| "unhealthy"
|
||||
| "connection_lost";
|
||||
|
||||
interface CodeInterpreterHealth {
|
||||
connected: boolean;
|
||||
error: string;
|
||||
healthy: boolean;
|
||||
}
|
||||
|
||||
interface CodeInterpreterStatus {
|
||||
@@ -21,7 +15,7 @@ interface CodeInterpreterStatus {
|
||||
export default function useCodeInterpreter() {
|
||||
const {
|
||||
data: healthData,
|
||||
error: healthFetchError,
|
||||
error: healthError,
|
||||
isLoading: isHealthLoading,
|
||||
mutate: refetchHealth,
|
||||
} = useSWR<CodeInterpreterHealth>(HEALTH_ENDPOINT, errorHandlingFetcher, {
|
||||
@@ -40,22 +34,11 @@ export default function useCodeInterpreter() {
|
||||
refetchStatus();
|
||||
}
|
||||
|
||||
const status: CodeInterpreterHealthStatus = healthFetchError
|
||||
? "connection_lost"
|
||||
: !healthData?.connected
|
||||
? "connection_lost"
|
||||
: healthData.error
|
||||
? "unhealthy"
|
||||
: "healthy";
|
||||
|
||||
const error = healthFetchError?.message || healthData?.error || undefined;
|
||||
|
||||
return {
|
||||
status,
|
||||
error,
|
||||
isHealthy: healthData?.healthy ?? false,
|
||||
isEnabled: statusData?.enabled ?? false,
|
||||
isLoading: isHealthLoading || isStatusLoading,
|
||||
fetchError: healthFetchError || statusError,
|
||||
error: healthError || statusError,
|
||||
refetch,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -104,8 +104,7 @@ function Header() {
|
||||
refreshCurrentProjectDetails,
|
||||
currentProjectId,
|
||||
} = useProjectsContext();
|
||||
const { currentChatSession, refreshChatSessions, removeSession } =
|
||||
useChatSessions();
|
||||
const { currentChatSession, refreshChatSessions } = useChatSessions();
|
||||
const router = useRouter();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
@@ -187,7 +186,6 @@ function Header() {
|
||||
if (!response.ok) {
|
||||
throw new Error("Failed to delete chat session");
|
||||
}
|
||||
removeSession(currentChatSession.id);
|
||||
await Promise.all([refreshChatSessions(), fetchProjects()]);
|
||||
router.replace("/app");
|
||||
setDeleteModalOpen(false);
|
||||
@@ -195,13 +193,7 @@ function Header() {
|
||||
console.error("Failed to delete chat:", error);
|
||||
showErrorNotification("Failed to delete chat. Please try again.");
|
||||
}
|
||||
}, [
|
||||
currentChatSession,
|
||||
refreshChatSessions,
|
||||
removeSession,
|
||||
fetchProjects,
|
||||
router,
|
||||
]);
|
||||
}, [currentChatSession, refreshChatSessions, fetchProjects, router]);
|
||||
|
||||
const setDeleteConfirmationModalOpen = useCallback((open: boolean) => {
|
||||
setDeleteModalOpen(open);
|
||||
|
||||
@@ -211,10 +211,10 @@ export async function updateAgentFeaturedStatus(
|
||||
isFeatured: boolean
|
||||
): Promise<string | null> {
|
||||
try {
|
||||
const response = await fetch(`/api/admin/persona/${agentId}/featured`, {
|
||||
const response = await fetch(`/api/admin/persona/${agentId}/default`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ featured: isFeatured }),
|
||||
body: JSON.stringify({ is_default_persona: isFeatured }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
|
||||
@@ -222,12 +222,9 @@ export const useConnectorStatus = (refreshInterval = 30000) => {
|
||||
};
|
||||
};
|
||||
|
||||
export const useBasicConnectorStatus = (enabled: boolean = true) => {
|
||||
export const useBasicConnectorStatus = () => {
|
||||
const url = "/api/manage/connector-status";
|
||||
const swrResponse = useSWR<CCPairBasicInfo[]>(
|
||||
enabled ? url : null,
|
||||
errorHandlingFetcher
|
||||
);
|
||||
const swrResponse = useSWR<CCPairBasicInfo[]>(url, errorHandlingFetcher);
|
||||
return {
|
||||
...swrResponse,
|
||||
refreshIndexingStatus: () => mutate(url),
|
||||
|
||||
@@ -19,8 +19,7 @@ export function SettingsProvider({
|
||||
settings: CombinedSettings;
|
||||
}) {
|
||||
const [isMobile, setIsMobile] = useState<boolean | undefined>();
|
||||
const vectorDbEnabled = settings.settings.vector_db_enabled !== false;
|
||||
const { ccPairs } = useCCPairs(vectorDbEnabled);
|
||||
const { ccPairs } = useCCPairs();
|
||||
|
||||
useEffect(() => {
|
||||
const checkMobile = () => {
|
||||
|
||||
@@ -3,9 +3,9 @@
|
||||
import React from "react";
|
||||
import * as DialogPrimitive from "@radix-ui/react-dialog";
|
||||
import { cn } from "@/lib/utils";
|
||||
import type { IconProps } from "@opal/types";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import { Button } from "@opal/components";
|
||||
import { Content } from "@opal/layouts";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import { WithoutStyles } from "@/types";
|
||||
import { Section, SectionProps } from "@/layouts/general-layouts";
|
||||
@@ -407,7 +407,7 @@ ModalContent.displayName = DialogPrimitive.Content.displayName;
|
||||
* ```
|
||||
*/
|
||||
interface ModalHeaderProps extends WithoutStyles<SectionProps> {
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
icon?: IconFunctionComponent;
|
||||
title: string;
|
||||
description?: string;
|
||||
onClose?: () => void;
|
||||
@@ -416,8 +416,6 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
({ icon: Icon, title, description, onClose, children, ...props }, ref) => {
|
||||
const { closeButtonRef, setHasDescription } = useModalContext();
|
||||
|
||||
// useLayoutEffect ensures aria-describedby is set before paint,
|
||||
// so screen readers announce the description when the dialog opens
|
||||
React.useLayoutEffect(() => {
|
||||
setHasDescription(!!description);
|
||||
}, [description, setHasDescription]);
|
||||
@@ -440,52 +438,38 @@ const ModalHeader = React.forwardRef<HTMLDivElement, ModalHeaderProps>(
|
||||
|
||||
return (
|
||||
<Section ref={ref} padding={1} alignItems="start" height="fit" {...props}>
|
||||
<Section gap={0.5}>
|
||||
{Icon && (
|
||||
<Section
|
||||
gap={0}
|
||||
padding={0}
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
>
|
||||
{/*
|
||||
The `h-[1.5rem]` and `w-[1.5rem]` were added as backups here.
|
||||
However, prop-resolution technically resolves to choosing classNames over size props, so technically the `size={24}` is the backup.
|
||||
We specify both to be safe.
|
||||
|
||||
# Note
|
||||
1.5rem === 24px
|
||||
*/}
|
||||
<Icon
|
||||
className="stroke-text-04 h-[1.5rem] w-[1.5rem]"
|
||||
size={24}
|
||||
/>
|
||||
{closeButton}
|
||||
</Section>
|
||||
)}
|
||||
|
||||
<Section
|
||||
alignItems="start"
|
||||
gap={0}
|
||||
padding={0}
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
>
|
||||
<Section alignItems="start" padding={0} gap={0}>
|
||||
<DialogPrimitive.Title asChild>
|
||||
<Text headingH3>{title}</Text>
|
||||
</DialogPrimitive.Title>
|
||||
{description && (
|
||||
<DialogPrimitive.Description asChild>
|
||||
<Text secondaryBody text03>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
justifyContent="between"
|
||||
alignItems="start"
|
||||
gap={0}
|
||||
padding={0}
|
||||
>
|
||||
<div className="relative w-full">
|
||||
{/* Close button is absolutely positioned because:
|
||||
1. Figma mocks place it overlapping the top-right of the content area
|
||||
2. Using ContentAction with rightChildren causes the description
|
||||
to wrap to the second line early due to the button reserving space */}
|
||||
<div className="absolute top-0 right-0">{closeButton}</div>
|
||||
<DialogPrimitive.Title asChild>
|
||||
<div>
|
||||
<Content
|
||||
icon={Icon}
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="section"
|
||||
variant="heading"
|
||||
/>
|
||||
{description && (
|
||||
<DialogPrimitive.Description className="hidden">
|
||||
{description}
|
||||
</Text>
|
||||
</DialogPrimitive.Description>
|
||||
)}
|
||||
</Section>
|
||||
{!Icon && closeButton}
|
||||
</Section>
|
||||
</DialogPrimitive.Description>
|
||||
)}
|
||||
</div>
|
||||
</DialogPrimitive.Title>
|
||||
</div>
|
||||
</Section>
|
||||
|
||||
{children}
|
||||
</Section>
|
||||
);
|
||||
|
||||
@@ -72,8 +72,6 @@ export interface LineItemProps
|
||||
description?: string;
|
||||
rightChildren?: React.ReactNode;
|
||||
href?: string;
|
||||
rel?: string;
|
||||
target?: string;
|
||||
ref?: React.Ref<HTMLDivElement>;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
@@ -143,8 +141,6 @@ export default function LineItem({
|
||||
children,
|
||||
rightChildren,
|
||||
href,
|
||||
rel,
|
||||
target,
|
||||
ref,
|
||||
...props
|
||||
}: LineItemProps) {
|
||||
@@ -245,9 +241,5 @@ export default function LineItem({
|
||||
);
|
||||
|
||||
if (!href) return content;
|
||||
return (
|
||||
<Link href={href as Route} rel={rel} target={target}>
|
||||
{content}
|
||||
</Link>
|
||||
);
|
||||
return <Link href={href as Route}>{content}</Link>;
|
||||
}
|
||||
|
||||
@@ -29,7 +29,6 @@ import { SourceMetadata } from "@/lib/search/interfaces";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { useAvailableTools } from "@/hooks/useAvailableTools";
|
||||
import useCCPairs from "@/hooks/useCCPairs";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import { useToolOAuthStatus } from "@/lib/hooks/useToolOAuthStatus";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
@@ -275,11 +274,9 @@ export default function ActionsPopover({
|
||||
}, [selectedAssistant.id, setForcedToolIds]);
|
||||
|
||||
const { isAdmin, isCurator } = useUser();
|
||||
const settings = useSettingsContext();
|
||||
const vectorDbEnabled = settings?.settings.vector_db_enabled !== false;
|
||||
|
||||
const { tools: availableTools } = useAvailableTools();
|
||||
const { ccPairs } = useCCPairs(vectorDbEnabled);
|
||||
const { ccPairs } = useCCPairs();
|
||||
const { currentProjectId, allCurrentProjectFiles } = useProjectsContext();
|
||||
const availableToolIdSet = new Set(availableTools.map((tool) => tool.id));
|
||||
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
interface SidebarTabSkeletonProps {
|
||||
textWidth?: string;
|
||||
}
|
||||
|
||||
export default function SidebarTabSkeleton({
|
||||
textWidth = "w-2/3",
|
||||
}: SidebarTabSkeletonProps) {
|
||||
return (
|
||||
<div className="w-full rounded-08 p-1.5">
|
||||
<div className="h-[1.5rem] flex flex-row items-center px-1 py-0.5">
|
||||
<div
|
||||
className={cn(
|
||||
"h-3 rounded bg-background-tint-04 animate-pulse",
|
||||
textWidth
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -24,6 +24,7 @@ import {
|
||||
STARTER_MESSAGES_EXAMPLES,
|
||||
MAX_CHARACTERS_STARTER_MESSAGE,
|
||||
MAX_CHARACTERS_AGENT_DESCRIPTION,
|
||||
MAX_CHUNKS_FED_TO_CHAT,
|
||||
} from "@/lib/constants";
|
||||
import {
|
||||
IMAGE_GENERATION_TOOL_ID,
|
||||
@@ -560,12 +561,9 @@ export default function AgentEditorPage({
|
||||
(_, i) => existingAgent?.starter_messages?.[i]?.message ?? ""
|
||||
),
|
||||
|
||||
// Knowledge - enabled if agent has any knowledge sources attached
|
||||
enable_knowledge:
|
||||
(existingAgent?.document_sets?.length ?? 0) > 0 ||
|
||||
(existingAgent?.hierarchy_nodes?.length ?? 0) > 0 ||
|
||||
(existingAgent?.attached_documents?.length ?? 0) > 0 ||
|
||||
(existingAgent?.user_file_ids?.length ?? 0) > 0,
|
||||
// Knowledge - enabled if num_chunks is greater than 0
|
||||
// (num_chunks of 0 or null means knowledge is disabled)
|
||||
enable_knowledge: (existingAgent?.num_chunks ?? 0) > 0,
|
||||
document_set_ids: existingAgent?.document_sets?.map((ds) => ds.id) ?? [],
|
||||
// Individual document IDs from hierarchy browsing
|
||||
document_ids: existingAgent?.attached_documents?.map((doc) => doc.id) ?? [],
|
||||
@@ -655,7 +653,7 @@ export default function AgentEditorPage({
|
||||
shared_group_ids: existingAgent?.groups ?? [],
|
||||
is_public: existingAgent?.is_public ?? true,
|
||||
label_ids: existingAgent?.labels?.map((l) => l.id) ?? [],
|
||||
featured: existingAgent?.featured ?? false,
|
||||
is_default_persona: existingAgent?.is_default_persona ?? false,
|
||||
};
|
||||
|
||||
const validationSchema = Yup.object().shape({
|
||||
@@ -687,6 +685,19 @@ export default function AgentEditorPage({
|
||||
hierarchy_node_ids: Yup.array().of(Yup.number()),
|
||||
user_file_ids: Yup.array().of(Yup.string()),
|
||||
selected_sources: Yup.array().of(Yup.string()),
|
||||
num_chunks: Yup.number()
|
||||
.nullable()
|
||||
.transform((value, originalValue) =>
|
||||
originalValue === "" || originalValue === null ? null : value
|
||||
)
|
||||
.test(
|
||||
"is-non-negative-integer",
|
||||
"The number of chunks must be a non-negative integer (0, 1, 2, etc.)",
|
||||
(value) =>
|
||||
value === null ||
|
||||
value === undefined ||
|
||||
(Number.isInteger(value) && value >= 0)
|
||||
),
|
||||
|
||||
// Advanced
|
||||
llm_model_provider_override: Yup.string().nullable().optional(),
|
||||
@@ -726,6 +737,9 @@ export default function AgentEditorPage({
|
||||
const finalStarterMessages =
|
||||
starterMessages.length > 0 ? starterMessages : null;
|
||||
|
||||
// Determine knowledge settings
|
||||
const numChunks = values.enable_knowledge ? MAX_CHUNKS_FED_TO_CHAT : 0;
|
||||
|
||||
// Always look up tools in availableTools to ensure we can find all tools
|
||||
|
||||
const toolIds = [];
|
||||
@@ -785,7 +799,11 @@ export default function AgentEditorPage({
|
||||
document_set_ids: values.enable_knowledge
|
||||
? values.document_set_ids
|
||||
: [],
|
||||
num_chunks: numChunks,
|
||||
is_public: values.is_public,
|
||||
// recency_bias: ...,
|
||||
// llm_filter_extraction: ...,
|
||||
llm_relevance_filter: false,
|
||||
llm_model_provider_override: values.llm_model_provider_override || null,
|
||||
llm_model_version_override: values.llm_model_version_override || null,
|
||||
starter_messages: finalStarterMessages,
|
||||
@@ -798,7 +816,7 @@ export default function AgentEditorPage({
|
||||
icon_name: values.icon_name,
|
||||
search_start_date: values.knowledge_cutoff_date || null,
|
||||
label_ids: values.label_ids,
|
||||
featured: values.featured,
|
||||
is_default_persona: values.is_default_persona,
|
||||
// display_priority: ...,
|
||||
|
||||
user_file_ids: values.enable_knowledge ? values.user_file_ids : [],
|
||||
@@ -1043,7 +1061,7 @@ export default function AgentEditorPage({
|
||||
userIds={values.shared_user_ids}
|
||||
groupIds={values.shared_group_ids}
|
||||
isPublic={values.is_public}
|
||||
isFeatured={values.featured}
|
||||
isFeatured={values.is_default_persona}
|
||||
labelIds={values.label_ids}
|
||||
onShare={(
|
||||
userIds,
|
||||
@@ -1055,7 +1073,7 @@ export default function AgentEditorPage({
|
||||
setFieldValue("shared_user_ids", userIds);
|
||||
setFieldValue("shared_group_ids", groupIds);
|
||||
setFieldValue("is_public", isPublic);
|
||||
setFieldValue("featured", isFeatured);
|
||||
setFieldValue("is_default_persona", isFeatured);
|
||||
setFieldValue("label_ids", labelIds);
|
||||
shareAgentModal.toggle(false);
|
||||
}}
|
||||
@@ -1372,13 +1390,13 @@ export default function AgentEditorPage({
|
||||
{canUpdateFeaturedStatus && (
|
||||
<>
|
||||
<InputLayouts.Horizontal
|
||||
name="featured"
|
||||
name="is_default_persona"
|
||||
title="Feature This Agent"
|
||||
description="Show this agent at the top of the explore agents list and automatically pin it to the sidebar for new users with access."
|
||||
>
|
||||
<SwitchField name="featured" />
|
||||
<SwitchField name="is_default_persona" />
|
||||
</InputLayouts.Horizontal>
|
||||
{values.featured && !isShared && (
|
||||
{values.is_default_persona && !isShared && (
|
||||
<Message
|
||||
static
|
||||
close={false}
|
||||
|
||||
@@ -361,10 +361,12 @@ export default function AgentsNavigationPage() {
|
||||
]);
|
||||
|
||||
const featuredAgents = [
|
||||
...memoizedCurrentlyVisibleAgents.filter((agent) => agent.featured),
|
||||
...memoizedCurrentlyVisibleAgents.filter(
|
||||
(agent) => agent.is_default_persona
|
||||
),
|
||||
];
|
||||
const allAgents = memoizedCurrentlyVisibleAgents.filter(
|
||||
(agent) => !agent.featured
|
||||
(agent) => !agent.is_default_persona
|
||||
);
|
||||
|
||||
const agentCount = featuredAgents.length + allAgents.length;
|
||||
|
||||
@@ -138,13 +138,7 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
currentChatSessionId,
|
||||
isLoading: isLoadingChatSessions,
|
||||
} = useChatSessions();
|
||||
// handle redirect if chat page is disabled
|
||||
// NOTE: this must be done here, in a client component since
|
||||
// settings are passed in via Context and therefore aren't
|
||||
// available in server-side components
|
||||
const settings = useSettingsContext();
|
||||
const vectorDbEnabled = settings?.settings.vector_db_enabled !== false;
|
||||
const { ccPairs } = useCCPairs(vectorDbEnabled);
|
||||
const { ccPairs } = useCCPairs();
|
||||
const { tags } = useTags();
|
||||
const { documentSets } = useDocumentSets();
|
||||
const {
|
||||
@@ -162,6 +156,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
setForcedToolIds([]);
|
||||
}, [currentProjectId, setForcedToolIds]);
|
||||
|
||||
// handle redirect if chat page is disabled
|
||||
// NOTE: this must be done here, in a client component since
|
||||
// settings are passed in via Context and therefore aren't
|
||||
// available in server-side components
|
||||
const settings = useSettingsContext();
|
||||
|
||||
const isInitialLoad = useRef(true);
|
||||
|
||||
const { agents, isLoading: isLoadingAgents } = useAgents();
|
||||
|
||||
@@ -4,7 +4,6 @@ import React, { useState } from "react";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Card, type CardProps } from "@/refresh-components/cards";
|
||||
import {
|
||||
SvgAlertCircle,
|
||||
SvgArrowExchange,
|
||||
SvgCheckCircle,
|
||||
SvgRefreshCw,
|
||||
@@ -17,11 +16,9 @@ import { Button } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import useCodeInterpreter, {
|
||||
type CodeInterpreterHealthStatus,
|
||||
} from "@/hooks/useCodeInterpreter";
|
||||
import useCodeInterpreter from "@/hooks/useCodeInterpreter";
|
||||
import { updateCodeInterpreter } from "@/lib/admin/code-interpreter/svc";
|
||||
import { Content, ContentAction } from "@opal/layouts";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
interface CodeInterpreterCardProps {
|
||||
@@ -71,44 +68,19 @@ function CheckingStatus() {
|
||||
);
|
||||
}
|
||||
|
||||
const STATUS_CONFIG: Record<
|
||||
CodeInterpreterHealthStatus,
|
||||
{ label: string; icon: typeof SvgCheckCircle; iconColor: string }
|
||||
> = {
|
||||
healthy: {
|
||||
label: "Connected",
|
||||
icon: SvgCheckCircle,
|
||||
iconColor: "text-status-success-05",
|
||||
},
|
||||
unhealthy: {
|
||||
label: "Unhealthy",
|
||||
icon: SvgAlertCircle,
|
||||
iconColor: "text-status-warning-05",
|
||||
},
|
||||
connection_lost: {
|
||||
label: "Connection Lost",
|
||||
icon: SvgXOctagon,
|
||||
iconColor: "text-status-error-05",
|
||||
},
|
||||
};
|
||||
|
||||
interface ConnectionStatusProps {
|
||||
status: CodeInterpreterHealthStatus;
|
||||
healthy: boolean;
|
||||
isLoading: boolean;
|
||||
onIconHover: (hovered: boolean) => void;
|
||||
}
|
||||
|
||||
function ConnectionStatus({
|
||||
status,
|
||||
isLoading,
|
||||
onIconHover,
|
||||
}: ConnectionStatusProps) {
|
||||
function ConnectionStatus({ healthy, isLoading }: ConnectionStatusProps) {
|
||||
if (isLoading) {
|
||||
return <CheckingStatus />;
|
||||
}
|
||||
|
||||
const { label, icon: Icon, iconColor } = STATUS_CONFIG[status];
|
||||
const hasError = status !== "healthy";
|
||||
const label = healthy ? "Connected" : "Connection Lost";
|
||||
const Icon = healthy ? SvgCheckCircle : SvgXOctagon;
|
||||
const iconColor = healthy ? "text-status-success-05" : "text-status-error-05";
|
||||
|
||||
return (
|
||||
<Section
|
||||
@@ -121,13 +93,7 @@ function ConnectionStatus({
|
||||
<Text mainUiAction text03>
|
||||
{label}
|
||||
</Text>
|
||||
<div
|
||||
onMouseEnter={() => hasError && onIconHover(true)}
|
||||
onMouseLeave={() => onIconHover(false)}
|
||||
className={hasError ? "cursor-pointer" : undefined}
|
||||
>
|
||||
<Icon size={16} className={iconColor} />
|
||||
</div>
|
||||
<Icon size={16} className={iconColor} />
|
||||
</Section>
|
||||
);
|
||||
}
|
||||
@@ -172,11 +138,9 @@ function ActionButtons({
|
||||
}
|
||||
|
||||
export default function CodeInterpreterPage() {
|
||||
const { status, error, isEnabled, isLoading, refetch } = useCodeInterpreter();
|
||||
const isHealthy = status === "healthy";
|
||||
const { isHealthy, isEnabled, isLoading, refetch } = useCodeInterpreter();
|
||||
const [showDisconnectModal, setShowDisconnectModal] = useState(false);
|
||||
const [isReconnecting, setIsReconnecting] = useState(false);
|
||||
const [showErrorMenu, setShowErrorMenu] = useState(false);
|
||||
|
||||
async function handleToggle(enabled: boolean) {
|
||||
const action = enabled ? "reconnect" : "disconnect";
|
||||
@@ -204,76 +168,51 @@ export default function CodeInterpreterPage() {
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<Section flexDirection="column" padding={0} gap={0.2}>
|
||||
{isEnabled || isLoading ? (
|
||||
<CodeInterpreterCard
|
||||
title="Code Interpreter"
|
||||
variant={isHealthy ? "primary" : "secondary"}
|
||||
strikethrough={!isHealthy}
|
||||
rightContent={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
justifyContent="center"
|
||||
alignItems="end"
|
||||
gap={0}
|
||||
padding={0}
|
||||
>
|
||||
<ConnectionStatus
|
||||
status={status}
|
||||
isLoading={isLoading}
|
||||
onIconHover={setShowErrorMenu}
|
||||
/>
|
||||
<ActionButtons
|
||||
onDisconnect={() => setShowDisconnectModal(true)}
|
||||
onRefresh={refetch}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<CodeInterpreterCard
|
||||
variant="secondary"
|
||||
title="Code Interpreter"
|
||||
middleText="(Disconnected)"
|
||||
strikethrough={true}
|
||||
rightContent={
|
||||
<Section flexDirection="row" alignItems="center" padding={0.5}>
|
||||
{isReconnecting ? (
|
||||
<CheckingStatus />
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={() => handleToggle(true)}
|
||||
>
|
||||
Reconnect
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
{showErrorMenu && (
|
||||
<Section flexDirection="row" justifyContent="end">
|
||||
<Card className="w-[15rem]">
|
||||
<Content
|
||||
icon={(props) => (
|
||||
<SvgXOctagon {...props} className="text-status-error-05" />
|
||||
)}
|
||||
title={
|
||||
status === "connection_lost"
|
||||
? "Connection Lost Error"
|
||||
: "Code Interpreter Error"
|
||||
}
|
||||
description={error}
|
||||
variant="section"
|
||||
sizePreset="main-ui"
|
||||
{isEnabled || isLoading ? (
|
||||
<CodeInterpreterCard
|
||||
title="Code Interpreter"
|
||||
variant={isHealthy ? "primary" : "secondary"}
|
||||
strikethrough={!isHealthy}
|
||||
rightContent={
|
||||
<Section
|
||||
flexDirection="column"
|
||||
justifyContent="center"
|
||||
alignItems="end"
|
||||
gap={0}
|
||||
padding={0}
|
||||
>
|
||||
<ConnectionStatus healthy={isHealthy} isLoading={isLoading} />
|
||||
<ActionButtons
|
||||
onDisconnect={() => setShowDisconnectModal(true)}
|
||||
onRefresh={refetch}
|
||||
disabled={isLoading}
|
||||
/>
|
||||
</Card>
|
||||
</Section>
|
||||
)}
|
||||
</Section>
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<CodeInterpreterCard
|
||||
variant="secondary"
|
||||
title="Code Interpreter"
|
||||
middleText="(Disconnected)"
|
||||
strikethrough={true}
|
||||
rightContent={
|
||||
<Section flexDirection="row" alignItems="center" padding={0.5}>
|
||||
{isReconnecting ? (
|
||||
<CheckingStatus />
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
rightIcon={SvgArrowExchange}
|
||||
onClick={() => handleToggle(true)}
|
||||
>
|
||||
Reconnect
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
|
||||
{showDisconnectModal && (
|
||||
|
||||
@@ -117,7 +117,7 @@ export default function AgentCard({ agent }: AgentCardProps) {
|
||||
userIds={fullAgent?.users?.map((u) => u.id) ?? []}
|
||||
groupIds={fullAgent?.groups ?? []}
|
||||
isPublic={fullAgent?.is_public ?? false}
|
||||
isFeatured={fullAgent?.featured ?? false}
|
||||
isFeatured={fullAgent?.is_default_persona ?? false}
|
||||
labelIds={fullAgent?.labels?.map((l) => l.id) ?? []}
|
||||
onShare={handleShare}
|
||||
/>
|
||||
|
||||
@@ -294,9 +294,7 @@ const AppInputBar = React.memo(
|
||||
);
|
||||
|
||||
const { activePromptShortcuts } = usePromptShortcuts();
|
||||
const vectorDbEnabled =
|
||||
combinedSettings?.settings.vector_db_enabled !== false;
|
||||
const { ccPairs, isLoading: ccPairsLoading } = useCCPairs(vectorDbEnabled);
|
||||
const { ccPairs, isLoading: ccPairsLoading } = useCCPairs();
|
||||
const { data: federatedConnectorsData, isLoading: federatedLoading } =
|
||||
useFederatedConnectors();
|
||||
|
||||
|
||||
@@ -892,7 +892,7 @@ export default function AgentKnowledgePane({
|
||||
}, [enableKnowledge]);
|
||||
|
||||
// Get connected sources from CC pairs
|
||||
const { ccPairs } = useCCPairs(vectorDbEnabled);
|
||||
const { ccPairs } = useCCPairs();
|
||||
const connectedSources: ConnectedSource[] = useMemo(() => {
|
||||
if (!ccPairs || ccPairs.length === 0) return [];
|
||||
const sourceSet = new Set<ValidSources>();
|
||||
|
||||
@@ -248,7 +248,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
|
||||
bottomSlot={<AgentChatInput agent={agent} onSubmit={handleStartChat} />}
|
||||
>
|
||||
<Modal.Header
|
||||
icon={(props) => <AgentAvatar agent={agent} {...props} size={24} />}
|
||||
icon={(props) => <AgentAvatar agent={agent} {...props} size={28} />}
|
||||
title={agent.name}
|
||||
onClose={() => agentViewerModal.toggle(false)}
|
||||
/>
|
||||
@@ -256,7 +256,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
|
||||
<Modal.Body>
|
||||
{/* Metadata */}
|
||||
<Section flexDirection="row" justifyContent="start">
|
||||
{agent.featured && (
|
||||
{!agent.is_default_persona && (
|
||||
<Content
|
||||
icon={SvgStar}
|
||||
title="Featured"
|
||||
|
||||
@@ -67,7 +67,6 @@ import {
|
||||
SvgSearchMenu,
|
||||
SvgSettings,
|
||||
} from "@opal/icons";
|
||||
import SidebarTabSkeleton from "@/refresh-components/skeletons/SidebarTabSkeleton";
|
||||
import BuildModeIntroBackground from "@/app/craft/components/IntroBackground";
|
||||
import BuildModeIntroContent from "@/app/craft/components/IntroContent";
|
||||
import { CRAFT_PATH } from "@/app/craft/v1/constants";
|
||||
@@ -100,25 +99,11 @@ function buildVisibleAgents(
|
||||
return [visibleAgents, currentAgentIsPinned];
|
||||
}
|
||||
|
||||
const SKELETON_WIDTHS_BASE = ["w-4/5", "w-4/5", "w-3/5"];
|
||||
|
||||
function shuffleWidths(): string[] {
|
||||
return [...SKELETON_WIDTHS_BASE].sort(() => Math.random() - 0.5);
|
||||
}
|
||||
|
||||
interface RecentsSectionProps {
|
||||
chatSessions: ChatSession[];
|
||||
hasMore: boolean;
|
||||
isLoadingMore: boolean;
|
||||
onLoadMore: () => void;
|
||||
}
|
||||
|
||||
function RecentsSection({
|
||||
chatSessions,
|
||||
hasMore,
|
||||
isLoadingMore,
|
||||
onLoadMore,
|
||||
}: RecentsSectionProps) {
|
||||
function RecentsSection({ chatSessions }: RecentsSectionProps) {
|
||||
const { setNodeRef, isOver } = useDroppable({
|
||||
id: DRAG_TYPES.RECENTS,
|
||||
data: {
|
||||
@@ -126,33 +111,6 @@ function RecentsSection({
|
||||
},
|
||||
});
|
||||
|
||||
// Re-shuffle skeleton widths each time loaded session count changes
|
||||
const skeletonWidths = useMemo(shuffleWidths, [chatSessions.length]);
|
||||
|
||||
// Sentinel ref for IntersectionObserver-based infinite scroll
|
||||
const sentinelRef = useRef<HTMLDivElement | null>(null);
|
||||
const onLoadMoreRef = useRef(onLoadMore);
|
||||
onLoadMoreRef.current = onLoadMore;
|
||||
|
||||
useEffect(() => {
|
||||
if (!hasMore || isLoadingMore) return;
|
||||
|
||||
const sentinel = sentinelRef.current;
|
||||
if (!sentinel) return;
|
||||
|
||||
const observer = new IntersectionObserver(
|
||||
(entries) => {
|
||||
if (entries[0]?.isIntersecting) {
|
||||
onLoadMoreRef.current();
|
||||
}
|
||||
},
|
||||
{ threshold: 0 }
|
||||
);
|
||||
|
||||
observer.observe(sentinel);
|
||||
return () => observer.disconnect();
|
||||
}, [hasMore, isLoadingMore]);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={setNodeRef}
|
||||
@@ -167,28 +125,13 @@ function RecentsSection({
|
||||
Try sending a message! Your chat history will appear here.
|
||||
</Text>
|
||||
) : (
|
||||
<>
|
||||
{chatSessions.map((chatSession) => (
|
||||
<ChatButton
|
||||
key={chatSession.id}
|
||||
chatSession={chatSession}
|
||||
draggable
|
||||
/>
|
||||
))}
|
||||
{hasMore &&
|
||||
skeletonWidths.map((width, i) => (
|
||||
<div
|
||||
key={i}
|
||||
ref={i === 0 ? sentinelRef : undefined}
|
||||
className={cn(
|
||||
"transition-opacity duration-300",
|
||||
isLoadingMore ? "opacity-100" : "opacity-40"
|
||||
)}
|
||||
>
|
||||
<SidebarTabSkeleton textWidth={width} />
|
||||
</div>
|
||||
))}
|
||||
</>
|
||||
chatSessions.map((chatSession) => (
|
||||
<ChatButton
|
||||
key={chatSession.id}
|
||||
chatSession={chatSession}
|
||||
draggable
|
||||
/>
|
||||
))
|
||||
)}
|
||||
</SidebarSection>
|
||||
</div>
|
||||
@@ -214,9 +157,6 @@ const MemoizedAppSidebarInner = memo(
|
||||
chatSessions,
|
||||
refreshChatSessions,
|
||||
isLoading: isLoadingChatSessions,
|
||||
hasMore,
|
||||
isLoadingMore,
|
||||
loadMore,
|
||||
} = useChatSessions();
|
||||
const {
|
||||
projects,
|
||||
@@ -760,12 +700,7 @@ const MemoizedAppSidebarInner = memo(
|
||||
</SidebarSection>
|
||||
|
||||
{/* Recents */}
|
||||
<RecentsSection
|
||||
chatSessions={chatSessions}
|
||||
hasMore={hasMore}
|
||||
isLoadingMore={isLoadingMore}
|
||||
onLoadMore={loadMore}
|
||||
/>
|
||||
<RecentsSection chatSessions={chatSessions} />
|
||||
</DndContext>
|
||||
</>
|
||||
)}
|
||||
|
||||
@@ -122,7 +122,7 @@ const ChatButton = memo(
|
||||
const [showShareModal, setShowShareModal] = useState(false);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [popoverItems, setPopoverItems] = useState<React.ReactNode[]>([]);
|
||||
const { refreshChatSessions, removeSession } = useChatSessions();
|
||||
const { refreshChatSessions } = useChatSessions();
|
||||
const {
|
||||
refreshCurrentProjectDetails,
|
||||
projects,
|
||||
@@ -302,7 +302,6 @@ const ChatButton = memo(
|
||||
async function handleChatDelete() {
|
||||
try {
|
||||
await deleteChatSession(chatSession.id);
|
||||
removeSession(chatSession.id);
|
||||
|
||||
if (project) {
|
||||
await fetchProjects();
|
||||
|
||||
@@ -25,7 +25,6 @@ import {
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
|
||||
function getDisplayName(email?: string, personalName?: string): string {
|
||||
// Prioritize custom personal name if set
|
||||
@@ -103,11 +102,7 @@ function SettingsPopover({
|
||||
<PopoverMenu>
|
||||
{[
|
||||
<div key="user-settings" data-testid="Settings/user-settings">
|
||||
<LineItem
|
||||
icon={SvgUser}
|
||||
href="/app/settings"
|
||||
onClick={onUserSettingsClick}
|
||||
>
|
||||
<LineItem icon={SvgUser} onClick={onUserSettingsClick}>
|
||||
User Settings
|
||||
</LineItem>
|
||||
</div>,
|
||||
@@ -123,9 +118,13 @@ function SettingsPopover({
|
||||
<LineItem
|
||||
key="help-faq"
|
||||
icon={SvgExternalLink}
|
||||
href="https://docs.onyx.app"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
onClick={() =>
|
||||
window.open(
|
||||
"https://docs.onyx.app",
|
||||
"_blank",
|
||||
"noopener,noreferrer"
|
||||
)
|
||||
}
|
||||
>
|
||||
Help & FAQ
|
||||
</LineItem>,
|
||||
@@ -166,8 +165,6 @@ export default function UserAvatarPopover({
|
||||
const { user } = useUser();
|
||||
const router = useRouter();
|
||||
const appFocus = useAppFocus();
|
||||
const settings = useSettingsContext();
|
||||
const vectorDbEnabled = settings?.settings.vector_db_enabled !== false;
|
||||
|
||||
// Fetch notifications for display
|
||||
// The GET endpoint also triggers a refresh if release notes are stale
|
||||
@@ -186,9 +183,7 @@ export default function UserAvatarPopover({
|
||||
// Prefetch user settings data when popover opens for instant modal display
|
||||
preload("/api/user/pats", errorHandlingFetcher);
|
||||
preload("/api/federated/oauth-status", errorHandlingFetcher);
|
||||
if (vectorDbEnabled) {
|
||||
preload("/api/manage/connector-status", errorHandlingFetcher);
|
||||
}
|
||||
preload("/api/manage/connector-status", errorHandlingFetcher);
|
||||
preload("/api/llm/provider", errorHandlingFetcher);
|
||||
setPopupState("Settings");
|
||||
} else {
|
||||
@@ -238,6 +233,7 @@ export default function UserAvatarPopover({
|
||||
<SettingsPopover
|
||||
onUserSettingsClick={() => {
|
||||
setPopupState(undefined);
|
||||
router.push("/app/settings");
|
||||
}}
|
||||
onOpenNotifications={() => setPopupState("Notifications")}
|
||||
/>
|
||||
|
||||
@@ -19,12 +19,7 @@ const API_HEALTH_URL = "**/api/admin/code-interpreter/health";
|
||||
*/
|
||||
async function mockCodeInterpreterApi(
|
||||
page: Page,
|
||||
opts: {
|
||||
enabled: boolean;
|
||||
connected: boolean;
|
||||
error?: string;
|
||||
putStatus?: number;
|
||||
}
|
||||
opts: { enabled: boolean; healthy: boolean; putStatus?: number }
|
||||
) {
|
||||
const putStatus = opts.putStatus ?? 200;
|
||||
|
||||
@@ -32,10 +27,7 @@ async function mockCodeInterpreterApi(
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({
|
||||
connected: opts.connected,
|
||||
error: opts.error ?? "",
|
||||
}),
|
||||
body: JSON.stringify({ healthy: opts.healthy }),
|
||||
});
|
||||
});
|
||||
|
||||
@@ -81,7 +73,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
});
|
||||
|
||||
test("page loads with header and description", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, connected: true });
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.locator('[aria-label="admin-page-title"]')).toHaveText(
|
||||
@@ -93,14 +85,14 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
});
|
||||
|
||||
test("shows Connected status when enabled and healthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, connected: true });
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
|
||||
test("shows Connection Lost when enabled but unhealthy", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, connected: false });
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connection Lost")).toBeVisible({
|
||||
@@ -109,7 +101,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
});
|
||||
|
||||
test("shows Reconnect button when disabled", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, connected: false });
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
@@ -121,7 +113,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
test("disconnect flow opens modal and sends PUT request", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, connected: true });
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
@@ -148,7 +140,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
test("disconnect modal can be closed without disconnecting", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: true, connected: true });
|
||||
await mockCodeInterpreterApi(page, { enabled: true, healthy: true });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByText("Connected")).toBeVisible({ timeout: 10000 });
|
||||
@@ -169,7 +161,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
});
|
||||
|
||||
test("reconnect flow sends PUT with enabled=true", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, { enabled: false, connected: false });
|
||||
await mockCodeInterpreterApi(page, { enabled: false, healthy: false });
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
await expect(page.getByRole("button", { name: "Reconnect" })).toBeVisible({
|
||||
@@ -195,7 +187,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
await route.fulfill({
|
||||
status: 200,
|
||||
contentType: "application/json",
|
||||
body: JSON.stringify({ connected: false, error: "" }),
|
||||
body: JSON.stringify({ healthy: false }),
|
||||
});
|
||||
});
|
||||
|
||||
@@ -231,7 +223,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
test("shows error toast when disconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
connected: true,
|
||||
healthy: true,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
@@ -252,7 +244,7 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
test("shows error toast when reconnect fails", async ({ page }) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: false,
|
||||
connected: false,
|
||||
healthy: false,
|
||||
putStatus: 500,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
@@ -273,83 +265,4 @@ test.describe("Code Interpreter Admin Page", () => {
|
||||
timeout: 5000,
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error popover on hover over unhealthy status icon", async ({
|
||||
page,
|
||||
}) => {
|
||||
const errorMessage = "Sandbox runtime crashed unexpectedly";
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
connected: true,
|
||||
error: errorMessage,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
// Wait for the Unhealthy status to render
|
||||
const statusText = page.getByText("Unhealthy");
|
||||
await expect(statusText).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Error popover should not be visible initially
|
||||
await expect(page.getByText("Code Interpreter Error")).not.toBeVisible();
|
||||
|
||||
// Hover over the status icon — sibling of the status label
|
||||
const statusIcon = statusText.locator("..").locator(".cursor-pointer");
|
||||
await statusIcon.hover();
|
||||
|
||||
// Error popover should appear with the title and error description
|
||||
await expect(page.getByText("Code Interpreter Error")).toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
await expect(page.getByText(errorMessage)).toBeVisible();
|
||||
|
||||
// Move mouse away — error popover should disappear
|
||||
await page.mouse.move(0, 0);
|
||||
await expect(page.getByText("Code Interpreter Error")).not.toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error popover on hover over connection lost status icon", async ({
|
||||
page,
|
||||
}) => {
|
||||
const errorMessage = "Failed to reach code interpreter service";
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
connected: false,
|
||||
error: errorMessage,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
// Wait for the Connection Lost status to render
|
||||
const statusText = page.getByText("Connection Lost");
|
||||
await expect(statusText).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// Hover over the status icon — sibling of the status label
|
||||
const statusIcon = statusText.locator("..").locator(".cursor-pointer");
|
||||
await statusIcon.hover();
|
||||
|
||||
// Error popover should show the connection lost title and error
|
||||
await expect(page.getByText("Connection Lost Error")).toBeVisible({
|
||||
timeout: 3000,
|
||||
});
|
||||
await expect(page.getByText(errorMessage)).toBeVisible();
|
||||
});
|
||||
|
||||
test("does not show error popover on hover when healthy", async ({
|
||||
page,
|
||||
}) => {
|
||||
await mockCodeInterpreterApi(page, {
|
||||
enabled: true,
|
||||
connected: true,
|
||||
});
|
||||
await page.goto(CODE_INTERPRETER_URL);
|
||||
|
||||
const statusText = page.getByText("Connected");
|
||||
await expect(statusText).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// The status icon wrapper should not have cursor-pointer when healthy
|
||||
await expect(
|
||||
statusText.locator("..").locator(".cursor-pointer")
|
||||
).toHaveCount(0);
|
||||
});
|
||||
});
|
||||
|
||||
Reference in New Issue
Block a user