Compare commits

..

14 Commits

Author SHA1 Message Date
Evan Lohn
d972fb08fc feat: configurable sharepoint endpoints 2026-02-19 15:19:27 -08:00
Evan Lohn
4433d130db test 2026-02-19 14:32:29 -08:00
Evan Lohn
f108232415 feat: azure ad group pagination 2026-02-19 14:32:29 -08:00
Evan Lohn
6cd7c59a1c feat: sharepoint scalability 3 2026-02-19 14:32:26 -08:00
Evan Lohn
d31d8092ce nit 2026-02-19 14:32:23 -08:00
Evan Lohn
2cdeecc844 feat: delta sync sharepoint
Co-authored-by: CE11-Kishan <CE11-Kishan@users.noreply.github.com>
2026-02-19 14:32:23 -08:00
Evan Lohn
8f910a187b shouldve trusted claude 2026-02-19 14:32:19 -08:00
Evan Lohn
cfe5b95cc4 pr comments and fixes 2026-02-19 14:32:19 -08:00
Evan Lohn
760c4fae6a more test fixes 2026-02-19 14:32:19 -08:00
Evan Lohn
b769ee530f fix test 2026-02-19 14:32:19 -08:00
Evan Lohn
a42114a932 more comments 2026-02-19 14:32:18 -08:00
Evan Lohn
e709a9dd0e address pr comments 2026-02-19 14:32:18 -08:00
Evan Lohn
e1c16c2391 pr comments 2026-02-19 14:32:18 -08:00
Evan Lohn
9ff29b8879 feat: sharepoint scalability 1 2026-02-19 14:32:18 -08:00
315 changed files with 5073 additions and 11738 deletions

View File

@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Build chart dependencies

View File

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000
# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"
jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF
- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \

View File

@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update
- name: Install Redis operator

View File

@@ -1,28 +0,0 @@
"""add scim_username to scim_user_mapping
Revision ID: 0bb4558f35df
Revises: 631fd2504136
Create Date: 2026-02-20 10:45:30.340188
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "0bb4558f35df"
down_revision = "631fd2504136"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"scim_user_mapping",
sa.Column("scim_username", sa.String(), nullable=True),
)
def downgrade() -> None:
op.drop_column("scim_user_mapping", "scim_username")

View File

@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration
Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)
def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")

View File

@@ -1,31 +0,0 @@
"""add sharing_scope to build_session
Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)
def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")

View File

@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)
def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)
def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def construct_document_id_select_by_usergroup(

View File

@@ -69,7 +69,7 @@ def _graph_api_get(
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
except (_requests.ConnectionError, _requests.Timeout):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(

View File

@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]

View File

@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential

View File

@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"
def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1
def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled
def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return
if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return
whitelist = get_invited_users()
if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")
try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")
for email_whitelist in whitelist:
try:
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1671,10 +1650,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))
next_url = request.query_params.get("next", "/")

View File

@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)
__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]

View File

@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15
# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4
# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)

View File

@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -53,13 +47,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000
# shared_task allows this task to be shared across celery app instances.
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.
The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.
We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.
Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -169,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -199,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
f"Continuation token: {continuation_token}"
)
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token map: {next_continuation_token_map}"
f"seconds. Next continuation token: {next_continuation_token}"
)
opensearch_document_chunks, errored_chunks = (
@@ -241,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")

View File

@@ -37,35 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)
FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)
def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.

View File

@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""
from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)
__all__ = ["check_for_pruning", "connector_pruning_generator_task"]

View File

@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"
def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.
The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"
def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"
@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.
Three mechanisms prevent queue runaway:
1. **Queue depth backpressure** if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.
2. **Per-file queued guard** before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.
3. **Task expiry** every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")
@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None
enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None
with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)
for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue
# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1
finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()
task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None
@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,

View File

@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time
logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"
def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))
def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue
llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)
if llm_image_context:
return json.dumps(llm_image_context)
if tool_call_response:
return tool_call_response
return TOOL_CALL_RESPONSE_CROSS_MESSAGE
def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(
# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,

View File

@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -652,7 +651,6 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL
@@ -763,7 +761,6 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)
@@ -903,18 +900,6 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True
# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass
# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}

View File

@@ -10,7 +10,6 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -126,7 +125,6 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -134,8 +132,6 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None
@@ -190,7 +186,7 @@ def _build_user_information_section(
if not sections:
return ""
return USER_INFORMATION_HEADER + "\n".join(sections)
return USER_INFORMATION_HEADER + "".join(sections)
def build_system_prompt(
@@ -228,21 +224,23 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE
if include_all_guidance:
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt
if tools:
system_prompt += TOOL_SECTION_HEADER
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -252,14 +250,12 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
tool_guidance_sections: list[str] = []
if has_web_search or has_internal_search or include_all_guidance:
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
system_prompt += INTERNAL_SEARCH_GUIDANCE
if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -269,23 +265,20 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
)
if has_open_urls or include_all_guidance:
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
system_prompt += OPEN_URLS_GUIDANCE
if has_python or include_all_guidance:
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
system_prompt += PYTHON_TOOL_GUIDANCE
if has_generate_image or include_all_guidance:
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
system_prompt += GENERATE_IMAGE_GUIDANCE
if has_memory or include_all_guidance:
tool_guidance_sections.append(MEMORY_GUIDANCE)
if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
system_prompt += MEMORY_GUIDANCE
return system_prompt

View File

@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -265,18 +263,6 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)
# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)
# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"
# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
@@ -284,9 +270,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.

View File

@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min
CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)
# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)
# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500
CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"

View File

@@ -71,7 +71,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
SLIM_BATCH_SIZE = 1000
_EPOCH = datetime.fromtimestamp(0, tz=timezone.utc)
SHARED_DOCUMENTS_MAP = {
@@ -244,12 +243,6 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
current_drive_name: str | None = None
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
current_drive_web_url: str | None = None
# Resolved drive ID — avoids re-resolving on checkpoint resume
current_drive_id: str | None = None
# Next delta API page URL for per-page checkpointing within a drive.
# When set, Phase 3b fetches one page at a time so progress is persisted
# between pages. None means BFS path or no active delta traversal.
current_drive_delta_next_link: str | None = None
process_site_pages: bool = False
@@ -1273,8 +1266,7 @@ class SharepointConnector(
"""
base = f"{self.graph_api_base}/drives/{drive_id}"
if folder_path:
encoded_path = quote(folder_path, safe="/")
start_url = f"{base}/root:/{encoded_path}:/children"
start_url = f"{base}/root:/{folder_path}:/children"
else:
start_url = f"{base}/root/children"
@@ -1330,12 +1322,13 @@ class SharepointConnector(
Falls back to full enumeration if the API returns 410 Gone (expired token).
"""
use_timestamp_token = start is not None and start > _EPOCH
EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
use_timestamp_token = start is not None and start > EPOCH
initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
if use_timestamp_token:
assert start is not None # mypy
token = quote(start.isoformat(timespec="seconds"))
assert start is not None # purely for mypy
token = quote(start.strftime("%Y-%m-%dT%H:%M:%SZ"))
initial_url += f"?token={token}"
yield from self._iter_delta_pages(
@@ -1368,7 +1361,6 @@ class SharepointConnector(
try:
data = self._graph_api_get_json(page_url, params)
except requests.HTTPError as e:
# 410 means the delta token expired, so we need to fall back to full enumeration
if e.response is not None and e.response.status_code == 410:
if not allow_full_resync:
raise
@@ -1405,91 +1397,10 @@ class SharepointConnector(
yield DriveItemData.from_graph_json(item)
page_url = data.get("@odata.nextLink")
if not page_url:
page_url = data.get("@odata.nextLink") or data.get("@odata.deltaLink")
if "@odata.deltaLink" in data and "@odata.nextLink" not in data:
break
def _build_delta_start_url(
self,
drive_id: str,
start: datetime | None = None,
page_size: int = 200,
) -> str:
"""Build the initial delta API URL with query parameters embedded.
Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
URL so that the returned string is fully self-contained and can be
stored in a checkpoint without needing a separate params dict.
"""
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
params = [f"$top={page_size}"]
if start is not None and start > _EPOCH:
token = quote(start.isoformat(timespec="seconds"))
params.append(f"token={token}")
return f"{base_url}?{'&'.join(params)}"
def _fetch_one_delta_page(
self,
page_url: str,
drive_id: str,
start: datetime | None = None,
end: datetime | None = None,
page_size: int = 200,
) -> tuple[list[DriveItemData], str | None]:
"""Fetch a single page of delta API results.
Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
the delta enumeration is complete (deltaLink with no nextLink).
On 410 Gone (expired token) returns ``([], full_resync_url)`` so
the caller can store the resync URL in the checkpoint and retry on
the next cycle.
"""
try:
data = self._graph_api_get_json(page_url)
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 410:
logger.warning(
"Delta token expired (410 Gone) for drive '%s'. "
"Will restart with full delta enumeration.",
drive_id,
)
full_url = (
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
f"?$top={page_size}"
)
return [], full_url
raise
items: list[DriveItemData] = []
for item in data.get("value", []):
if "folder" in item or "deleted" in item:
continue
if start is not None or end is not None:
raw_ts = item.get("lastModifiedDateTime")
if raw_ts:
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
if start is not None and mod_dt < start:
continue
if end is not None and mod_dt > end:
continue
items.append(DriveItemData.from_graph_json(item))
next_url = data.get("@odata.nextLink")
if next_url:
return items, next_url
return items, None
@staticmethod
def _clear_drive_checkpoint_state(
checkpoint: "SharepointConnectorCheckpoint",
) -> None:
"""Reset all drive-level fields in the checkpoint."""
checkpoint.current_drive_name = None
checkpoint.current_drive_id = None
checkpoint.current_drive_web_url = None
checkpoint.current_drive_delta_next_link = None
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()
@@ -1931,13 +1842,14 @@ class SharepointConnector(
# Return checkpoint to allow persistence after drive initialization
return checkpoint
# Phase 3a: Initialize the next drive for processing
# Phase 3: Process documents from current drive
if (
checkpoint.current_site_descriptor
and checkpoint.cached_drive_names
and len(checkpoint.cached_drive_names) > 0
and checkpoint.current_drive_name is None
):
checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
@@ -1945,8 +1857,7 @@ class SharepointConnector(
site_descriptor = checkpoint.current_site_descriptor
logger.info(
f"Processing drive '{checkpoint.current_drive_name}' "
f"in site: {site_descriptor.url}"
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
)
logger.debug(f"Time range: {start_dt} to {end_dt}")
@@ -1955,35 +1866,35 @@ class SharepointConnector(
logger.warning("Current drive name is None, skipping")
return checkpoint
driveitems: Iterable[DriveItemData] = iter(())
drive_web_url: str | None = None
try:
logger.info(
f"Fetching drive items for drive name: {current_drive_name}"
)
result = self._resolve_drive(site_descriptor, current_drive_name)
if result is None:
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
drive_id, drive_web_url = result
checkpoint.current_drive_id = drive_id
checkpoint.current_drive_web_url = drive_web_url
if result is not None:
drive_id, drive_web_url = result
driveitems = self._get_drive_items_for_drive_id(
site_descriptor, drive_id, start_dt, end_dt
)
checkpoint.current_drive_web_url = drive_web_url
except Exception as e:
logger.error(
f"Failed to retrieve items from drive '{current_drive_name}' "
f"in site: {site_descriptor.url}: {e}"
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to access drive '{current_drive_name}' "
f"in site '{site_descriptor.url}': {str(e)}",
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
return checkpoint
display_drive_name = SHARED_DOCUMENTS_MAP.get(
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
current_drive_name = SHARED_DOCUMENTS_MAP.get(
current_drive_name, current_drive_name
)
@@ -1991,74 +1902,10 @@ class SharepointConnector(
yield from self._yield_drive_hierarchy_node(
site_descriptor.url,
drive_web_url,
display_drive_name,
current_drive_name,
checkpoint,
)
# For non-folder-scoped drives, use delta API with per-page
# checkpointing. Build the initial URL and fall through to 3b.
if not site_descriptor.folder_path:
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
drive_id, start_dt
)
# else: BFS path — delta_next_link stays None;
# Phase 3b will use _iter_drive_items_paged.
# Phase 3b: Process items from the current drive
if (
checkpoint.current_site_descriptor
and checkpoint.current_drive_name is not None
and checkpoint.current_drive_id is not None
):
site_descriptor = checkpoint.current_site_descriptor
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
current_drive_name = SHARED_DOCUMENTS_MAP.get(
checkpoint.current_drive_name, checkpoint.current_drive_name
)
drive_web_url = checkpoint.current_drive_web_url
# --- determine item source ---
driveitems: Iterable[DriveItemData]
has_more_delta_pages = False
if checkpoint.current_drive_delta_next_link:
# Delta path: fetch one page at a time for checkpointing
try:
page_items, next_url = self._fetch_one_delta_page(
page_url=checkpoint.current_drive_delta_next_link,
drive_id=checkpoint.current_drive_id,
start=start_dt,
end=end_dt,
)
except Exception as e:
logger.error(
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint
driveitems = page_items
has_more_delta_pages = next_url is not None
if next_url:
checkpoint.current_drive_delta_next_link = next_url
else:
# BFS path (folder-scoped): process all items at once
driveitems = self._iter_drive_items_paged(
drive_id=checkpoint.current_drive_id,
folder_path=site_descriptor.folder_path,
start=start_dt,
end=end_dt,
)
item_count = 0
for driveitem in driveitems:
item_count += 1
@@ -2100,6 +1947,8 @@ class SharepointConnector(
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)
# Re-acquire token in case it expired during a long traversal
# MSAL has a cache that returns the same token while still valid.
access_token = self._get_graph_access_token()
doc_or_failure = _convert_driveitem_to_document_with_permissions(
driveitem,
@@ -2135,11 +1984,8 @@ class SharepointConnector(
)
logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")
if has_more_delta_pages:
return checkpoint
self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
# Phase 4: Progression logic - determine next step
# If we have more drives in current site, continue with current site

View File

@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
from onyx.context.search.models import ChunkIndexRequest
from onyx.context.search.models import InferenceChunk
from onyx.db.document import DocumentSource
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.document_index_utils import (
get_multipass_config,
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
def slack_retrieval(
query: ChunkIndexRequest,
access_token: str,
db_session: Session | None = None,
db_session: Session,
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
entities: dict[str, Any] | None = None,
limit: int | None = None,
slack_event_context: SlackContext | None = None,
bot_token: str | None = None, # Add bot token parameter
team_id: str | None = None,
# Pre-fetched data — when provided, avoids DB query (no session needed)
search_settings: SearchSettings | None = None,
) -> list[InferenceChunk]:
"""
Main entry point for Slack federated search with entity filtering.
@@ -928,7 +925,7 @@ def slack_retrieval(
Args:
query: Search query object
access_token: User OAuth access token
db_session: Database session (optional if search_settings provided)
db_session: Database session
connector: Federated connector detail (unused, kept for backwards compat)
entities: Connector-level config (entity filtering configuration)
limit: Maximum number of results
@@ -1156,10 +1153,7 @@ def slack_retrieval(
# chunk index docs into doc aware chunks
# a single index doc can get split into multiple chunks
if search_settings is None:
if db_session is None:
raise ValueError("Either db_session or search_settings must be provided")
search_settings = get_current_search_settings(db_session)
search_settings = get_current_search_settings(db_session)
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)

View File

@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.llm.interfaces import LLM
from onyx.natural_language_processing.english_stopwords import strip_stopwords
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.secondary_llm_flows.source_filter import extract_source_filter
from onyx.secondary_llm_flows.time_filter import extract_time_filter
from onyx.utils.logger import setup_logger
@@ -43,7 +41,7 @@ def _build_index_filters(
user_file_ids: list[UUID] | None,
persona_document_sets: list[str] | None,
persona_time_cutoff: datetime | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
query: str | None = None,
llm: LLM | None = None,
@@ -51,8 +49,6 @@ def _build_index_filters(
# Assistant knowledge filters
attached_document_ids: list[str] | None = None,
hierarchy_node_ids: list[int] | None = None,
# Pre-fetched ACL filters (skips DB query when provided)
acl_filters: list[str] | None = None,
) -> IndexFilters:
if auto_detect_filters and (llm is None or query is None):
raise RuntimeError("LLM and query are required for auto detect filters")
@@ -107,14 +103,9 @@ def _build_index_filters(
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
logger.debug("Added USER_FILE to source_filter for user knowledge search")
if bypass_acl:
user_acl_filters = None
elif acl_filters is not None:
user_acl_filters = acl_filters
else:
if db_session is None:
raise ValueError("Either db_session or acl_filters must be provided")
user_acl_filters = build_access_filters_for_user(user, db_session)
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,
@@ -261,15 +252,11 @@ def search_pipeline(
user: User,
# Used for default filters and settings
persona: Persona | None,
db_session: Session | None = None,
db_session: Session,
auto_detect_filters: bool = False,
llm: LLM | None = None,
# If a project ID is provided, it will be exclusively scoped to that project
project_id: int | None = None,
# Pre-fetched data — when provided, avoids DB queries (no session needed)
acl_filters: list[str] | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
user_uploaded_persona_files: list[UUID] | None = (
[user_file.id for user_file in persona.user_files] if persona else None
@@ -310,7 +297,6 @@ def search_pipeline(
bypass_acl=chunk_search_request.bypass_acl,
attached_document_ids=attached_document_ids,
hierarchy_node_ids=hierarchy_node_ids,
acl_filters=acl_filters,
)
query_keywords = strip_stopwords(chunk_search_request.query)
@@ -329,8 +315,6 @@ def search_pipeline(
user_id=user.id if user else None,
document_index=document_index,
db_session=db_session,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
)
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean

View File

@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
from onyx.context.search.utils import inference_section_from_chunks
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
from onyx.federated_connectors.federated_retrieval import (
get_federated_retrieval_functions,
)
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -52,14 +50,9 @@ def combine_retrieval_results(
def _embed_and_search(
query_request: ChunkIndexRequest,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
db_session: Session,
) -> list[InferenceChunk]:
query_embedding = get_query_embedding(
query_request.query,
db_session=db_session,
embedding_model=embedding_model,
)
query_embedding = get_query_embedding(query_request.query, db_session)
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
@@ -85,9 +78,7 @@ def search_chunks(
query_request: ChunkIndexRequest,
user_id: UUID | None,
document_index: DocumentIndex,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
db_session: Session,
) -> list[InferenceChunk]:
run_queries: list[tuple[Callable, tuple]] = []
@@ -97,22 +88,14 @@ def search_chunks(
else None
)
# Federated retrieval — use pre-fetched if available, otherwise query DB
if prefetched_federated_retrieval_infos is not None:
federated_retrieval_infos = prefetched_federated_retrieval_infos
else:
if db_session is None:
raise ValueError(
"Either db_session or prefetched_federated_retrieval_infos "
"must be provided"
)
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
# Federated retrieval
federated_retrieval_infos = get_federated_retrieval_functions(
db_session=db_session,
user_id=user_id,
source_types=list(source_filters) if source_filters else None,
document_set_names=query_request.filters.document_set,
user_file_ids=query_request.filters.user_file_ids,
)
federated_sources = set(
federated_retrieval_info.source.to_non_federated_source()
@@ -131,10 +114,7 @@ def search_chunks(
if normal_search_enabled:
run_queries.append(
(
_embed_and_search,
(query_request, document_index, db_session, embedding_model),
)
(_embed_and_search, (query_request, document_index, db_session))
)
parallel_search_results = run_functions_tuples_in_parallel(run_queries)

View File

@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
)
def get_query_embeddings(
queries: list[str],
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> list[Embedding]:
if embedding_model is None:
if db_session is None:
raise ValueError("Either db_session or embedding_model must be provided")
search_settings = get_current_search_settings(db_session)
embedding_model = EmbeddingModel.from_db_model(
search_settings=search_settings,
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
search_settings = get_current_search_settings(db_session)
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
return query_embedding
@log_function_time(print_only=True, debug_only=True)
def get_query_embedding(
query: str,
db_session: Session | None = None,
embedding_model: EmbeddingModel | None = None,
) -> Embedding:
return get_query_embeddings(
[query], db_session=db_session, embedding_model=embedding_model
)[0]
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
def convert_inference_sections_to_search_docs(

View File

@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.api_key import ApiKeyDescriptor
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
select(User)
.join(ApiKey, ApiKey.user_id == User.id)
.where(ApiKey.hashed_api_key == hashed_api_key)
.options(selectinload(User.memories))
)

View File

@@ -13,7 +13,6 @@ from sqlalchemy import func
from sqlalchemy import Select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.future import select
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
async def _get_user(self, statement: Select) -> UP | None:
statement = statement.options(selectinload(User.memories))
results = await self.session.execute(statement)
return results.unique().scalar_one_or_none()
async def create(
self,
create_dict: Dict[str, Any],

View File

@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
"""Get connector credential pairs for a user.
Args:
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
defer_connector_config: If True, skips loading Connector.connector_specific_config
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
assert (
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
stmt = select(ConnectorCredentialPair).distinct()
if eager_load_connector:
connector_load = selectinload(ConnectorCredentialPair.connector)
if defer_connector_config:
connector_load = connector_load.defer(Connector.connector_specific_config)
stmt = stmt.options(connector_load)
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
if eager_load_credential:
load_opts = selectinload(ConnectorCredentialPair.credential)
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc: bool = False,
source: DocumentSource | None = None,
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
defer_connector_config: bool = False,
) -> list[ConnectorCredentialPair]:
with get_session_with_current_tenant() as db_session:
return get_connector_credential_pairs_for_user(
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
order_by_desc=order_by_desc,
source=source,
processing_mode=processing_mode,
defer_connector_config=defer_connector_config,
)

View File

@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
stmt = (
select(DocumentSetDBModel)
.distinct()
.options(
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSetDBModel.users),
selectinload(DocumentSetDBModel.groups),
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
)
.options(selectinload(DocumentSetDBModel.federated_connectors))
)
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
def fetch_documents_for_document_set_paginated(

View File

@@ -232,12 +232,6 @@ class BuildSessionStatus(str, PyEnum):
IDLE = "idle"
class SharingScope(str, PyEnum):
PRIVATE = "private"
PUBLIC_ORG = "public_org"
PUBLIC_GLOBAL = "public_global"
class SandboxStatus(str, PyEnum):
PROVISIONING = "provisioning"
RUNNING = "running"

View File

@@ -77,7 +77,6 @@ from onyx.db.enums import (
ThemePreference,
DefaultAppMode,
SwitchoverType,
SharingScope,
)
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
@@ -287,7 +286,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# relationships
credentials: Mapped[list["Credential"]] = relationship(
"Credential", back_populates="user"
"Credential", back_populates="user", lazy="joined"
)
chat_sessions: Mapped[list["ChatSession"]] = relationship(
"ChatSession", back_populates="user"
@@ -321,6 +320,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
"Memory",
back_populates="user",
cascade="all, delete-orphan",
lazy="selectin",
order_by="desc(Memory.id)",
)
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
@@ -1040,9 +1040,7 @@ class OpenSearchTenantMigrationRecord(Base):
nullable=False,
)
# Opaque continuation token from Vespa's Visit API.
# NULL means "not started".
# Otherwise contains a serialized mapping between slice ID and continuation
# token for that slice.
# NULL means "not started" or "visit completed".
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
Text, nullable=True
)
@@ -1066,9 +1064,6 @@ class OpenSearchTenantMigrationRecord(Base):
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
Integer, nullable=True
)
class KGEntityType(Base):
@@ -4717,12 +4712,6 @@ class BuildSession(Base):
demo_data_enabled: Mapped[bool] = mapped_column(
Boolean, nullable=False, server_default=text("true")
)
sharing_scope: Mapped[SharingScope] = mapped_column(
String,
nullable=False,
default=SharingScope.PRIVATE,
server_default="private",
)
# Relationships
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
@@ -4939,7 +4928,6 @@ class ScimUserMapping(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
)
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False

View File

@@ -4,7 +4,6 @@ This module provides functions to track the progress of migrating documents
from Vespa to OpenSearch.
"""
import json
from datetime import datetime
from datetime import timezone
@@ -13,9 +12,6 @@ from sqlalchemy import text
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_SLICE_COUNT,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
)
@@ -247,37 +243,29 @@ def should_document_migration_be_permanently_failed(
def get_vespa_visit_state(
db_session: Session,
) -> tuple[dict[int, str | None], int]:
) -> tuple[str | None, int]:
"""Gets the current Vespa migration state from the tenant migration record.
Requires the OpenSearchTenantMigrationRecord to exist.
Returns:
Tuple of (continuation_token_map, total_chunks_migrated).
Tuple of (continuation_token, total_chunks_migrated). continuation_token
is None if not started or completed.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
if record.vespa_visit_continuation_token is None:
continuation_token_map: dict[int, str | None] = {
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
}
else:
json_loaded_continuation_token_map = json.loads(
record.vespa_visit_continuation_token
)
continuation_token_map = {
int(key): value for key, value in json_loaded_continuation_token_map.items()
}
return continuation_token_map, record.total_chunks_migrated
return (
record.vespa_visit_continuation_token,
record.total_chunks_migrated,
)
def update_vespa_visit_progress_with_commit(
db_session: Session,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
chunks_processed: int,
chunks_errored: int,
approx_chunk_count_in_vespa: int | None,
) -> None:
"""Updates the Vespa migration progress and commits.
@@ -285,26 +273,19 @@ def update_vespa_visit_progress_with_commit(
Args:
db_session: SQLAlchemy session.
continuation_token_map: The new continuation token map. None entry means
the visit is complete for that slice.
continuation_token: The new continuation token. None means the visit
is complete.
chunks_processed: Number of chunks processed in this batch (added to
the running total).
chunks_errored: Number of chunks errored in this batch (added to the
running errored total).
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
None, the existing value is used.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
record.vespa_visit_continuation_token = continuation_token
record.total_chunks_migrated += chunks_processed
record.total_chunks_errored += chunks_errored
record.approx_chunk_count_in_vespa = (
approx_chunk_count_in_vespa
if approx_chunk_count_in_vespa is not None
else record.approx_chunk_count_in_vespa
)
db_session.commit()
@@ -372,27 +353,25 @@ def build_sanitized_to_original_doc_id_mapping(
def get_opensearch_migration_state(
db_session: Session,
) -> tuple[int, datetime | None, datetime | None, int | None]:
) -> tuple[int, datetime | None, datetime | None]:
"""Returns the state of the Vespa to OpenSearch migration.
If the tenant migration record is not found, returns defaults of 0, None,
None, None.
None.
Args:
db_session: SQLAlchemy session.
Returns:
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
approx_chunk_count_in_vespa).
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
if record is None:
return 0, None, None, None
return 0, None, None
return (
record.total_chunks_migrated,
record.created_at,
record.migration_completed_at,
record.approx_chunk_count_in_vespa,
)

View File

@@ -8,7 +8,6 @@ from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session
from onyx.auth.pat import build_displayable_pat
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
are loaded correctly — matching the pattern in fetch_user_for_api_key.
"""
# Single joined query with all filters pushed to database
now = datetime.now(timezone.utc)
user = await async_db_session.scalar(
select(User)
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
result = await async_db_session.execute(
select(PersonalAccessToken, User)
.join(User, PersonalAccessToken.user_id == User.id)
.where(PersonalAccessToken.hashed_token == hashed_token)
.where(User.is_active) # type: ignore
.where(
(PersonalAccessToken.expires_at.is_(None))
| (PersonalAccessToken.expires_at > now)
)
.options(selectinload(User.memories))
.limit(1)
)
if not user:
row = result.first()
if not row:
return None
_schedule_pat_last_used_update(hashed_token, now)
return user
pat, user = row
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
# For request-level auditing, use application logs or a dedicated audit table
should_update = (
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
)
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
async def _update() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(tenant_id) as session:
pat = await session.scalar(
select(PersonalAccessToken).where(
PersonalAccessToken.hashed_token == hashed_token
if should_update:
# Update in separate session to avoid transaction coupling (fire-and-forget)
async def _update_last_used() -> None:
try:
tenant_id = get_current_tenant_id()
async with get_async_session_context_manager(
tenant_id
) as separate_session:
await separate_session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
)
if not pat:
return
if (
pat.last_used_at is not None
and (now - pat.last_used_at).total_seconds() <= 300
):
return
await session.execute(
update(PersonalAccessToken)
.where(PersonalAccessToken.hashed_token == hashed_token)
.values(last_used_at=now)
)
await session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
await separate_session.commit()
except Exception as e:
logger.warning(f"Failed to update last_used_at for PAT: {e}")
asyncio.create_task(_update())
asyncio.create_task(_update_last_used())
return user
def create_pat(

View File

@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Document
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import HierarchyNode
from onyx.db.models import Persona
from onyx.db.models import Persona__User
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
stmt = stmt.options(
selectinload(Persona.tools),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets)
.selectinload(DocumentSet.connector_credential_pairs)
.selectinload(ConnectorCredentialPair.connector),
selectinload(Persona.user),
)
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.document_sets),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),

View File

@@ -54,9 +54,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
# Maps schema property name to a list of highlighted snippets with match
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
match_highlights: dict[str, list[str]] = {}
# Score explanation from OpenSearch when "explain": true is set in the query.
# Contains detailed breakdown of how the score was calculated.
explanation: dict[str, Any] | None = None
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
@@ -709,12 +706,10 @@ class OpenSearchClient:
)
document_chunk_score = hit.get("_score", None)
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
explanation: dict[str, Any] | None = hit.get("_explanation", None)
search_hit = SearchHit[DocumentChunk](
document_chunk=DocumentChunk.model_validate(document_chunk_source),
score=document_chunk_score,
match_highlights=match_highlights,
explanation=explanation,
)
search_hits.append(search_hit)
logger.debug(

View File

@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
# quality but increase memory footprint. Values typically range between 12 - 48.
M = 32 # Set relatively high for better accuracy.
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
# If we only fetch 10 candidates from each of keyword and vector, they would have to have perfect overlap to get a good hybrid
# ranking for the 10 results. If we fetch 1000 candidates from each, we have a much higher chance of all 10 of the final desired
# docs showing up and getting scored. In worse situations, the final 10 docs don't even show up as the final 10 (worse than just
# a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
# Number of vectors to examine for top k neighbors for the HNSW method.
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
# Bumped this to 1000, for dataset of low 10,000 docs, did not see improvement in recall.
EF_SEARCH = 256
# The default number of neighbors to consider for knn vector similarity search.
# We need this higher than the number of results because the scoring is hybrid.
# If there is only 1 query, setting k equal to the number of results is enough,
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
# rather than an independent scoring component.
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
SEARCH_KEYWORD_WEIGHT = 0.45
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
SEARCH_TITLE_VECTOR_WEIGHT,
SEARCH_TITLE_KEYWORD_WEIGHT,
SEARCH_CONTENT_VECTOR_WEIGHT,
SEARCH_KEYWORD_WEIGHT,
SEARCH_CONTENT_KEYWORD_WEIGHT,
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0

View File

@@ -842,8 +842,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
body=query_body,
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
)
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
search_hit.document_chunk, search_hit.score, search_hit.match_highlights

View File

@@ -11,7 +11,6 @@ from pydantic import model_serializer
from pydantic import model_validator
from pydantic import SerializerFunctionWrapHandler
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
@@ -55,11 +54,6 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
# Faiss was also tried but it didn't have any benefits
# NMSLIB is deprecated, not recommended
OPENSEARCH_KNN_ENGINE = "lucene"
def get_opensearch_doc_chunk_id(
tenant_state: TenantState,
document_id: str,
@@ -349,9 +343,6 @@ class DocumentSchema:
"properties": {
TITLE_FIELD_NAME: {
"type": "text",
# Language analyzer (e.g. english) stems at index and search time for variant matching.
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
"analyzer": OPENSEARCH_TEXT_ANALYZER,
"fields": {
# Subfield accessed as title.keyword. Not indexed for
# values longer than 256 chars.
@@ -366,7 +357,9 @@ class DocumentSchema:
CONTENT_FIELD_NAME: {
"type": "text",
"store": True,
"analyzer": OPENSEARCH_TEXT_ANALYZER,
# This makes highlighting text during queries more efficient
# at the cost of disk space. See
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
"index_options": "offsets",
},
TITLE_VECTOR_FIELD_NAME: {
@@ -375,7 +368,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},
@@ -387,7 +380,7 @@ class DocumentSchema:
"method": {
"name": "hnsw",
"space_type": "cosinesimil",
"engine": OPENSEARCH_KNN_ENGINE,
"engine": "lucene",
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
},
},

View File

@@ -6,16 +6,13 @@ from typing import Any
from uuid import UUID
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import (
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
)
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
@@ -243,9 +240,6 @@ class DocumentQuery:
Returns:
A dictionary representing the final hybrid search query.
"""
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
raise ValueError(
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
@@ -253,7 +247,7 @@ class DocumentQuery:
)
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
query_text, query_vector
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
)
hybrid_search_filters = DocumentQuery._get_search_filters(
tenant_state=tenant_state,
@@ -281,31 +275,25 @@ class DocumentQuery:
hybrid_search_query: dict[str, Any] = {
"hybrid": {
"queries": hybrid_search_subqueries,
# Max results per subquery per shard before aggregation. Ensures keyword and vector
# subqueries contribute equally to the candidate pool for hybrid fusion.
# Sources:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
# Applied to all the sub-queries independently (this avoids having subqueries having a lot of results thrown out).
# Sources:
# Applied to all the sub-queries. Source:
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
# Does AND for each filter in the list.
"filter": {"bool": {"filter": hybrid_search_filters}},
}
}
# NOTE: By default, hybrid search retrieves "size"-many results from
# each OpenSearch shard before aggregation. Source:
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
final_hybrid_search_body: dict[str, Any] = {
"query": hybrid_search_query,
"size": num_hits,
"highlight": match_highlights_configuration,
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
}
# Explain is for scoring breakdowns.
if OPENSEARCH_EXPLAIN_ENABLED:
final_hybrid_search_body["explain"] = True
# WARNING: Profiling does not work with hybrid search; do not add it at
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
return final_hybrid_search_body
@@ -367,12 +355,7 @@ class DocumentQuery:
@staticmethod
def _get_hybrid_search_subqueries(
query_text: str,
query_vector: list[float],
# The default number of neighbors to consider for knn vector similarity search.
# This is higher than the number of results because the scoring is hybrid.
# for a detailed breakdown, see where the default value is set.
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
query_text: str, query_vector: list[float], num_candidates: int
) -> list[dict[str, Any]]:
"""Returns subqueries for hybrid search.
@@ -384,8 +367,9 @@ class DocumentQuery:
Matches:
- Title vector
- Title keyword
- Content vector
- Keyword (title + content, match and phrase)
- Content keyword + phrase
Normalization is not performed here.
The weights of each of these subqueries should be configured in a search
@@ -406,9 +390,9 @@ class DocumentQuery:
NOTE: Options considered and rejected:
- minimum_should_match: Since it's hybrid search and users often provide semantic queries, there is often a lot of terms,
and very low number of meaningful keywords (and a low ratio of keywords).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It's mostly for typos as the analyzer ("english by
default") already does some stemming and tokenization. In testing datasets, this makes recall slightly worse. It also is
less performant so not really any reason to do it.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable but in reality seeing the
user usage patterns, this is not very common and people tend to not be confused when a miss happens for this reason.
In testing datasets, this makes recall slightly worse.
Args:
query_text: The text of the query to search for.
@@ -417,27 +401,19 @@ class DocumentQuery:
similarity search.
"""
# Build sub-queries for hybrid search. Order must match normalization
# pipeline weights: title vector, content vector, keyword (title + content).
# pipeline weights: title vector, title keyword, content vector,
# content keyword.
hybrid_search_queries: list[dict[str, Any]] = [
# 1. Title vector search
{
"knn": {
TITLE_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
"k": num_candidates,
}
}
},
# 2. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": vector_candidates,
}
}
},
# 3. Keyword (title + content) match and phrase search.
# 2. Title keyword + phrase search.
{
"bool": {
"should": [
@@ -445,10 +421,8 @@ class DocumentQuery:
"match": {
TITLE_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
# The title fields are strongly discounted as they are included in the content.
# It just acts as a minor boost
"boost": 0.1,
}
}
},
@@ -456,17 +430,35 @@ class DocumentQuery:
"match_phrase": {
TITLE_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
"boost": 0.2,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}
},
]
}
},
# 3. Content vector search
{
"knn": {
CONTENT_VECTOR_FIELD_NAME: {
"vector": query_vector,
"k": num_candidates,
}
}
},
# 4. Content keyword + phrase search.
{
"bool": {
"should": [
{
"match": {
CONTENT_FIELD_NAME: {
"query": query_text,
# operator "or" = match doc if any query term matches (default, explicit for clarity).
"operator": "or",
"boost": 1.0,
}
}
},
@@ -474,7 +466,9 @@ class DocumentQuery:
"match_phrase": {
CONTENT_FIELD_NAME: {
"query": query_text,
# Slop = 1 allows one extra word or transposition in phrase match.
"slop": 1,
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
"boost": 1.5,
}
}

View File

@@ -10,12 +10,6 @@ from typing import cast
import httpx
from retry import retry
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.transformer import (
FIELDS_NEEDED_FOR_TRANSFORMATION,
)
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
from onyx.context.search.models import IndexFilters
@@ -283,139 +277,54 @@ def get_chunks_via_visit_api(
def get_all_chunks_paginated(
index_name: str,
tenant_state: TenantState,
continuation_token_map: dict[int, str | None],
page_size: int,
) -> tuple[list[dict], dict[int, str | None]]:
continuation_token: str | None = None,
page_size: int = 1_000,
) -> tuple[list[dict], str | None]:
"""Gets all chunks in Vespa matching the filters, paginated.
Uses the Visit API with slicing. Each continuation token map entry is for a
different slice. The number of entries determines the number of slices.
Args:
index_name: The name of the Vespa index to visit.
tenant_state: The tenant state to filter by.
continuation_token_map: Map of slice ID to a token returned by Vespa
representing a page offset. None to start from the beginning of the
slice.
continuation_token: Token returned by Vespa representing a page offset.
None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
def _get_all_chunks_paginated_for_slice(
index_name: str,
tenant_state: TenantState,
slice_id: int,
total_slices: int,
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict], str | None]:
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
)
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
selection: str = f"{index_name}.large_chunk_reference_ids == null"
if MULTI_TENANT:
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
params: dict[str, str | int | None] = {
"selection": selection,
"fieldSet": field_set,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
"slices": total_slices,
"sliceId": slice_id,
}
if continuation_token is not None:
params["continuation"] = continuation_token
response: httpx.Response | None = None
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
error_message = (
response.json().get("message") if response else "No response"
)
logger.error("Error message from response: %s", error_message)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
# NOTE: If we see a falsey value for "continuation" in the response we
# assume we are done and return
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
next_continuation_token = (
response_data.get("continuation")
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
)
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
logger.debug(
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
)
return chunks, next_continuation_token
total_slices = len(continuation_token_map)
if total_slices < 1:
raise ValueError("continuation_token_map must have at least one entry.")
# We want to guarantee that these invocations are ordered by slice_id,
# because we read in the same order below when parsing parallel_results.
functions_with_args: list[tuple[Callable, tuple]] = [
(
_get_all_chunks_paginated_for_slice,
(
index_name,
tenant_state,
slice_id,
total_slices,
continuation_token,
page_size,
),
)
for slice_id, continuation_token in sorted(continuation_token_map.items())
]
parallel_results = run_functions_tuples_in_parallel(
functions_with_args, allow_failures=True
)
if len(parallel_results) != total_slices:
raise RuntimeError(
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
)
chunks: list[dict] = []
next_continuation_token_map: dict[int, str | None] = {
key: value for key, value in continuation_token_map.items()
params: dict[str, str | int | None] = {
"selection": selection,
"wantedDocumentCount": page_size,
"format.tensors": "short-value",
}
for i, parallel_result in enumerate(parallel_results):
if i not in next_continuation_token_map:
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
if parallel_result is None:
logger.error(
f"Failed to get chunks for slice {i} of {total_slices}. "
"The continuation token for this slice will not be updated."
)
continue
chunks.extend(parallel_result[0])
next_continuation_token_map[i] = parallel_result[1]
if continuation_token is not None:
params["continuation"] = continuation_token
return chunks, next_continuation_token_map
try:
with get_vespa_http_client() as http_client:
response = http_client.get(url, params=params)
response.raise_for_status()
except httpx.HTTPError as e:
error_base = "Failed to get chunks in Vespa."
logger.exception(
f"Request URL: {e.request.url}\n"
f"Request Headers: {e.request.headers}\n"
f"Request Payload: {params}\n"
)
raise httpx.HTTPError(error_base) from e
response_data = response.json()
return [
chunk["fields"] for chunk in response_data.get("documents", [])
], response_data.get("continuation") or None
# TODO(rkuo): candidate for removal if not being used

View File

@@ -56,7 +56,6 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
from onyx.document_index.vespa_constants import DOCUMENT_ID
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa_constants import NUM_THREADS
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
from onyx.document_index.vespa_constants import YQL_BASE
from onyx.indexing.models import DocMetadataAwareIndexChunk
@@ -653,9 +652,9 @@ class VespaDocumentIndex(DocumentIndex):
def get_all_raw_document_chunks_paginated(
self,
continuation_token_map: dict[int, str | None],
continuation_token: str | None,
page_size: int,
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
) -> tuple[list[dict[str, Any]], str | None]:
"""Gets all the chunks in Vespa, paginated.
Used in the chunk-level Vespa-to-OpenSearch migration task.
@@ -663,21 +662,21 @@ class VespaDocumentIndex(DocumentIndex):
Args:
continuation_token: Token returned by Vespa representing a page
offset. None to start from the beginning. Defaults to None.
page_size: Best-effort batch size for the visit.
page_size: Best-effort batch size for the visit. Defaults to 1,000.
Returns:
Tuple of (list of chunk dicts, next continuation token or None). The
continuation token is None when the visit is complete.
"""
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
raw_chunks, next_continuation_token = get_all_chunks_paginated(
index_name=self._index_name,
tenant_state=TenantState(
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
),
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=page_size,
)
return raw_chunks, next_continuation_token_map
return raw_chunks, next_continuation_token
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
"""Indexes raw document chunks into Vespa.
@@ -703,32 +702,3 @@ class VespaDocumentIndex(DocumentIndex):
json={"fields": chunk},
)
response.raise_for_status()
def get_chunk_count(self) -> int:
"""Returns the exact number of document chunks in Vespa for this tenant.
Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked`
to get an exact count without fetching any document data.
Includes large chunks. There is no way to filter these out using the
Search API.
"""
where_clause = (
f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true"
)
yql = (
f"select documentid from {self._index_name} "
f"where {where_clause} "
f"limit 0"
)
params: dict[str, str | int] = {
"yql": yql,
"ranking.profile": "unranked",
"timeout": VESPA_TIMEOUT,
}
with get_vespa_http_client() as http_client:
response = http_client.post(SEARCH_ENDPOINT, json=params)
response.raise_for_status()
response_data = response.json()
return response_data["root"]["fields"]["totalCount"]

View File

@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None
class ReferenceImage(BaseModel):
data: bytes
mime_type: str
class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False
@property
def max_reference_images(self) -> int:
return 0
@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
**kwargs: Any,
) -> ImageGenerationResponse:
from litellm import image_generation

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import base64
import json
from datetime import datetime
from typing import Any
from typing import TYPE_CHECKING
@@ -11,7 +9,6 @@ from pydantic import BaseModel
from onyx.image_gen.exceptions import ImageProviderCredentialsError
from onyx.image_gen.interfaces import ImageGenerationProvider
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
if TYPE_CHECKING:
from onyx.image_gen.interfaces import ImageGenerationResponse
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
vertex_credentials=vertex_credentials,
)
@property
def supports_reference_images(self) -> bool:
return True
@property
def max_reference_images(self) -> int:
# Gemini image editing supports up to 14 input images.
return 14
def generate_image(
self,
prompt: str,
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
if reference_images:
return self._generate_image_with_reference_images(
prompt=prompt,
model=model,
size=size,
n=n,
reference_images=reference_images,
)
from litellm import image_generation
return image_generation(
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
**kwargs,
)
def _generate_image_with_reference_images(
self,
prompt: str,
model: str,
size: str,
n: int,
reference_images: list[ReferenceImage],
) -> ImageGenerationResponse:
from google import genai
from google.genai import types as genai_types
from google.oauth2 import service_account
from litellm.types.utils import ImageObject
from litellm.types.utils import ImageResponse
service_account_info = json.loads(self._vertex_credentials)
credentials = service_account.Credentials.from_service_account_info(
service_account_info,
scopes=["https://www.googleapis.com/auth/cloud-platform"],
)
client = genai.Client(
vertexai=True,
project=self._vertex_project,
location=self._vertex_location,
credentials=credentials,
)
parts: list[genai_types.Part] = [
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
for image in reference_images
]
parts.append(genai_types.Part.from_text(text=prompt))
config = genai_types.GenerateContentConfig(
response_modalities=["TEXT", "IMAGE"],
candidate_count=max(1, n),
image_config=genai_types.ImageConfig(
aspect_ratio=_map_size_to_aspect_ratio(size)
),
)
model_name = model.replace("vertex_ai/", "")
response = client.models.generate_content(
model=model_name,
contents=genai_types.Content(
role="user",
parts=parts,
),
config=config,
)
generated_data: list[ImageObject] = []
for candidate in response.candidates or []:
candidate_content = candidate.content
if not candidate_content:
continue
for part in candidate_content.parts or []:
inline_data = part.inline_data
if not inline_data or inline_data.data is None:
continue
if isinstance(inline_data.data, bytes):
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
elif isinstance(inline_data.data, str):
b64_json = inline_data.data
else:
continue
generated_data.append(
ImageObject(
b64_json=b64_json,
revised_prompt=prompt,
)
)
if not generated_data:
raise RuntimeError("No image data returned from Vertex AI.")
return ImageResponse(
created=int(datetime.now().timestamp()),
data=generated_data,
)
def _map_size_to_aspect_ratio(size: str) -> str:
return {
"1024x1024": "1:1",
"1792x1024": "16:9",
"1024x1792": "9:16",
"1536x1024": "3:2",
"1024x1536": "2:3",
}.get(size, "1:1")
def _parse_to_vertex_credentials(
credentials: ImageGenerationProviderCredentials,

View File

@@ -64,6 +64,21 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -144,6 +159,11 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -1300,6 +1320,11 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1340,6 +1365,16 @@
"model_vendor": "anthropic",
"model_version": "20240620-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
"display_name": "Claude Sonnet 4.5",
"model_vendor": "anthropic",
@@ -1470,6 +1505,26 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet-20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-7-sonnet-latest": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"claude-4-opus-20250514": {
"display_name": "Claude Opus 4",
"model_vendor": "anthropic",
@@ -1650,6 +1705,16 @@
"model_vendor": "anthropic",
"model_version": "20241022-v2:0"
},
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219-v1:0"
},
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307-v1:0"
},
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3161,6 +3226,15 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"openrouter/anthropic/claude-3-haiku-20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"openrouter/anthropic/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"
@@ -3175,6 +3249,16 @@
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-3.7-sonnet:beta": {
"display_name": "Claude Sonnet 3.7:beta",
"model_vendor": "anthropic",
"model_version": "latest"
},
"openrouter/anthropic/claude-haiku-4.5": {
"display_name": "Claude Haiku 4.5",
"model_vendor": "anthropic",
@@ -3666,6 +3750,16 @@
"model_vendor": "anthropic",
"model_version": "20241022"
},
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"us.anthropic.claude-3-haiku-20240307-v1:0": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic",
@@ -3785,6 +3879,20 @@
"model_vendor": "anthropic",
"model_version": "20240620"
},
"vertex_ai/claude-3-7-sonnet@20250219": {
"display_name": "Claude Sonnet 3.7",
"model_vendor": "anthropic",
"model_version": "20250219"
},
"vertex_ai/claude-3-haiku": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic"
},
"vertex_ai/claude-3-haiku@20240307": {
"display_name": "Claude Haiku 3",
"model_vendor": "anthropic",
"model_version": "20240307"
},
"vertex_ai/claude-3-sonnet": {
"display_name": "Claude Sonnet 3",
"model_vendor": "anthropic"

View File

@@ -1,7 +1,5 @@
import json
import pathlib
import threading
import time
from onyx.llm.constants import LlmProviderNames
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
_recommendations_cache_lock = threading.Lock()
_cached_recommendations: LLMRecommendations | None = None
_cached_recommendations_time: float = 0.0
def _get_provider_to_models_map() -> dict[str, list[str]]:
"""Lazy-load provider model mappings to avoid importing litellm at module level.
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
}
def _load_bundled_recommendations() -> LLMRecommendations:
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations from the GitHub config."""
recommendations_from_github = fetch_llm_recommendations_from_github()
if recommendations_from_github:
return recommendations_from_github
# Fall back to json bundled with code
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
with open(json_path, "r") as f:
json_config = json.load(f)
return LLMRecommendations.model_validate(json_config)
def get_recommendations() -> LLMRecommendations:
"""Get the recommendations, with an in-memory cache to avoid
hitting GitHub on every API request."""
global _cached_recommendations, _cached_recommendations_time
now = time.monotonic()
if (
_cached_recommendations is not None
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
with _recommendations_cache_lock:
# Double-check after acquiring lock
if (
_cached_recommendations is not None
and (time.monotonic() - _cached_recommendations_time)
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
):
return _cached_recommendations
recommendations_from_github = fetch_llm_recommendations_from_github()
result = recommendations_from_github or _load_bundled_recommendations()
_cached_recommendations = result
_cached_recommendations_time = time.monotonic()
return result
recommendations_from_json = LLMRecommendations.model_validate(json_config)
return recommendations_from_json
def is_obsolete_model(model_name: str, provider: str) -> bool:
@@ -243,23 +215,6 @@ def model_configurations_for_provider(
) -> list[ModelConfigurationView]:
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
recommended_visible_models_names = [m.name for m in recommended_visible_models]
# Preserve provider-defined ordering while de-duplicating.
model_names: list[str] = []
seen_model_names: set[str] = set()
for model_name in (
fetch_models_for_provider(provider_name) + recommended_visible_models_names
):
if model_name in seen_model_names:
continue
seen_model_names.add(model_name)
model_names.append(model_name)
# Vertex model list can be large and mixed-vendor; alphabetical ordering
# makes model discovery easier in admin selection UIs.
if provider_name == VERTEXAI_PROVIDER_NAME:
model_names = sorted(model_names, key=str.lower)
return [
ModelConfigurationView(
name=model_name,
@@ -267,7 +222,8 @@ def model_configurations_for_provider(
max_input_tokens=get_max_input_tokens(model_name, provider_name),
supports_image_input=model_supports_image_input(model_name, provider_name),
)
for model_name in model_names
for model_name in set(fetch_models_for_provider(provider_name))
| set(recommended_visible_models_names)
]

View File

@@ -52,7 +52,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
from onyx.db.engine.connection_warmup import warm_up_connections
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
@@ -64,7 +63,7 @@ from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.build.api.api import public_build_router
from onyx.server.features.build.api.api import nextjs_assets_router
from onyx.server.features.build.api.api import router as build_router
from onyx.server.features.default_assistant.api import (
router as default_assistant_router,
@@ -115,16 +114,13 @@ from onyx.server.manage.users import router as user_router
from onyx.server.manage.web_search.api import (
admin_router as web_search_admin_router,
)
from onyx.server.metrics.postgres_connection_pool import (
setup_postgres_connection_pool_metrics,
)
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
from onyx.server.middleware.rate_limiting import close_auth_limiter
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
from onyx.server.middleware.rate_limiting import setup_auth_limiter
from onyx.server.onyx_api.ingestion import router as onyx_api_router
from onyx.server.pat.api import router as pat_router
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
from onyx.server.query_and_chat.chat_backend import router as chat_router
from onyx.server.query_and_chat.query_backend import (
admin_router as admin_query_router,
@@ -142,7 +138,6 @@ from onyx.setup import setup_onyx
from onyx.tracing.setup import setup_tracing
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_endpoint_context_middleware
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
@@ -271,17 +266,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
)
# Register pool metrics now that engines are created.
# HTTP instrumentation is set up earlier in get_application() since it
# adds middleware (which Starlette forbids after the app has started).
setup_postgres_connection_pool_metrics(
engines={
"sync": SqlEngine.get_engine(),
"async": get_sqlalchemy_async_engine(),
"readonly": SqlEngine.get_readonly_engine(),
},
)
verify_auth = fetch_versioned_implementation(
"onyx.auth.users", "verify_auth_setting"
)
@@ -394,8 +378,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
include_router_with_global_prefix_prepended(application, cc_pair_router)
include_router_with_global_prefix_prepended(application, projects_router)
include_router_with_global_prefix_prepended(application, public_build_router)
include_router_with_global_prefix_prepended(application, build_router)
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
include_router_with_global_prefix_prepended(application, document_set_router)
include_router_with_global_prefix_prepended(application, hierarchy_router)
include_router_with_global_prefix_prepended(application, search_settings_router)
@@ -576,18 +560,12 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
add_onyx_request_id_middleware(application, "API", logger)
# Set endpoint context for per-endpoint DB pool attribution metrics.
# Must be registered after all routes are added.
add_endpoint_context_middleware(application)
# HTTP request metrics (latency histograms, in-progress gauge, slow request
# counter). Must be called here — before the app starts — because the
# instrumentator adds middleware via app.add_middleware().
setup_prometheus_metrics(application)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app with production Prometheus config
setup_prometheus_metrics(application)
use_route_function_names_as_operation_ids(application)
return application

View File

@@ -69,12 +69,6 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()
FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()
# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "

View File

@@ -1,6 +1,6 @@
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included, the sections below are for the available tools
TOOL_SECTION_HEADER = "\n# Tools\n\n"
TOOL_SECTION_HEADER = "\n\n# Tools\n"
# This section is included if there are search type tools, currently internal_search and web_search
@@ -16,10 +16,11 @@ When searching for information, if the initial results cannot fully answer the u
Do not repeat the same or very similar queries if it already has been run in the chat history.
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
""".lstrip()
"""
INTERNAL_SEARCH_GUIDANCE = """
## internal_search
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
@@ -27,31 +28,34 @@ Use the `internal_search` tool to search connected applications for information.
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
- Ambiguity: questions about something that is not widely known or understood.
Never provide more than 3 queries at once to `internal_search`.
""".lstrip()
"""
WEB_SEARCH_GUIDANCE = """
## web_search
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
- Accuracy: if the cost of outdated/inaccurate information is high.
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
""".lstrip()
"""
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
Do not use the "site:" operator in your web search queries.
""".lstrip()
""".rstrip()
OPEN_URLS_GUIDANCE = """
## open_url
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
Do not open URLs that are image files like .png, .jpg, etc.
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
""".lstrip()
"""
PYTHON_TOOL_GUIDANCE = """
## python
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
@@ -60,21 +64,21 @@ Use this to give the user a way to download the file OR to display generated ima
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
""".lstrip()
"""
GENERATE_IMAGE_GUIDANCE = """
## generate_image
NEVER use generate_image unless the user specifically requests an image.
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
the `file_id` values returned by earlier `generate_image` tool results.
""".lstrip()
"""
MEMORY_GUIDANCE = """
## add_memory
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
Only add memories that are specific, likely to remain true, and likely to be useful later. \
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
""".lstrip()
"""
TOOL_CALL_FAILURE_PROMPT = """
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.

View File

@@ -1,36 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n# User Information\n\n"
USER_INFORMATION_HEADER = "\n\n# User Information\n"
BASIC_INFORMATION_PROMPT = """
## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
""".lstrip()
"""
# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
""".lstrip()
"""
# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """
## Team Information
{team_information}
""".lstrip()
"""
# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """
## User Preferences
{user_preferences}
""".lstrip()
"""
# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """
## User Memories
{user_memories}
""".lstrip()
"""
# ruff: noqa: E501, W605 end

View File

@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
"unlock",
"get",
"set",
"setex",
"delete",
"exists",
"incrby",

View File

@@ -59,9 +59,6 @@ PUBLIC_ENDPOINT_SPECS = [
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
# craft webapp proxy — access enforced per-session via sharing_scope in handler
("/build/sessions/{session_id}/webapp", {"GET"}),
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
]

View File

@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingMode
from onyx.db.enums import ProcessingMode
from onyx.db.federated import fetch_all_federated_connectors_parallel
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
from onyx.db.index_attempt import get_latest_index_attempts_by_status
@@ -988,7 +987,6 @@ def get_connector_status(
user=user,
eager_load_connector=True,
eager_load_credential=True,
eager_load_user=True,
get_editable=False,
)
@@ -1002,23 +1000,11 @@ def get_connector_status(
relationship.user_group_id
)
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
connector_to_credential_ids: dict[int, list[int]] = {}
for cc_pair in cc_pairs:
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
cc_pair.credential_id
)
return [
ConnectorStatus(
cc_pair_id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=connector_to_credential_ids.get(
cc_pair.connector_id, []
),
),
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
access_type=cc_pair.access_type,
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
# Get editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, True, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, True, None, True, True, True, True, request.source),
),
# Get federated connectors
(fetch_all_federated_connectors_parallel, ()),
# Get most recent index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, False
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
# Get most recent finished index attempts
(
lambda: get_latest_index_attempts_parallel(
request.secondary_index, True, True
),
(),
),
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
]
if user and user.role == UserRole.ADMIN:
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
parallel_functions.append(
# Get non-editable connector/credential pairs
(
lambda: get_connector_credential_pairs_for_user_parallel(
user, False, None, True, True, False, True, request.source
),
(),
get_connector_credential_pairs_for_user_parallel,
(user, False, None, True, True, True, True, request.source),
),
)
@@ -1939,7 +1911,6 @@ Tenant ID: {tenant_id}
class BasicCCPairInfo(BaseModel):
has_successful_run: bool
source: DocumentSource
status: ConnectorCredentialPairStatus
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
@@ -1953,17 +1924,13 @@ def get_basic_connector_indexing_status(
get_editable=False,
user=user,
)
# NOTE: This endpoint excludes Craft connectors
return [
BasicCCPairInfo(
has_successful_run=cc_pair.last_successful_index_time is not None,
source=cc_pair.connector.source,
status=cc_pair.status,
)
for cc_pair in cc_pairs
if cc_pair.connector.source != DocumentSource.INGESTION_API
and cc_pair.processing_mode == ProcessingMode.REGULAR
]

View File

@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
num_docs_indexed=num_docs_indexed,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_model.connector,
credential_ids=[cc_pair_model.credential_id],
cc_pair_model.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_model.credential

View File

@@ -1,5 +1,4 @@
from collections.abc import Iterator
from pathlib import Path
from uuid import UUID
import httpx
@@ -8,19 +7,16 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Request
from fastapi import Response
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.constants import DocumentSource
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import IndexingStatus
from onyx.db.enums import ProcessingMode
from onyx.db.enums import SharingScope
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -221,15 +217,12 @@ def get_build_connectors(
return BuildConnectorListResponse(connectors=connectors)
# Headers to skip when proxying.
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
# Headers to skip when proxying (hop-by-hop headers)
EXCLUDED_HEADERS = {
"content-encoding",
"content-length",
"transfer-encoding",
"connection",
"set-cookie",
}
@@ -287,7 +280,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
db_session: Database session
Returns:
Internal URL to proxy requests to
The internal URL to proxy requests to
Raises:
HTTPException: If session not found, port not allocated, or sandbox not found
@@ -301,10 +294,12 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
if session.user_id is None:
raise HTTPException(status_code=404, detail="User not found")
# Get the user's sandbox to get the sandbox_id
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
if sandbox is None:
raise HTTPException(status_code=404, detail="Sandbox not found")
# Use sandbox manager to get the correct internal URL
sandbox_manager = get_sandbox_manager()
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
@@ -370,73 +365,71 @@ def _proxy_request(
raise HTTPException(status_code=502, detail="Bad gateway")
def _check_webapp_access(
session_id: UUID, user: User | None, db_session: Session
) -> BuildSession:
"""Check if user can access a session's webapp.
- public_global: accessible by anyone (no auth required)
- public_org: accessible by any authenticated user
- private: only accessible by the session owner
"""
session = db_session.get(BuildSession, session_id)
if not session:
raise HTTPException(status_code=404, detail="Session not found")
if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
return session
if user is None:
raise HTTPException(status_code=401, detail="Authentication required")
if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
raise HTTPException(status_code=404, detail="Session not found")
return session
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
def _offline_html_response() -> Response:
"""Return a branded Craft HTML page when the sandbox is not reachable.
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
terminal window aesthetic with Minecraft-themed typing animation.
"""
html = _OFFLINE_HTML_PATH.read_text()
return Response(content=html, status_code=503, media_type="text/html")
# Public router for webapp proxy — no authentication required
# (access controlled per-session via sharing_scope)
public_build_router = APIRouter(prefix="/build")
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
@public_build_router.get(
"/sessions/{session_id}/webapp/{path:path}", response_model=None
)
def get_webapp(
@router.get("/sessions/{session_id}/webapp", response_model=None)
def get_webapp_root(
session_id: UUID,
request: Request,
path: str = "",
user: User | None = Depends(optional_user),
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy the webapp for a specific session (root and subpaths).
"""Proxy the root path of the webapp for a specific session."""
return _proxy_request("", request, session_id, db_session)
Accessible without authentication when sharing_scope is public_global.
Returns a friendly offline page when the sandbox is not running.
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
def get_webapp_path(
session_id: UUID,
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
return _proxy_request(path, request, session_id, db_session)
# Separate router for Next.js static assets at /_next/*
# This is needed because Next.js apps may reference assets with root-relative paths
# that don't get rewritten. The session_id is extracted from the Referer header.
nextjs_assets_router = APIRouter()
def _extract_session_from_referer(request: Request) -> UUID | None:
"""Extract session_id from the Referer header.
Expects Referer to contain /api/build/sessions/{session_id}/webapp
"""
try:
_check_webapp_access(session_id, user, db_session)
except HTTPException as e:
if e.status_code == 401:
return RedirectResponse(url="/auth/login", status_code=302)
raise
try:
return _proxy_request(path, request, session_id, db_session)
except HTTPException as e:
if e.status_code in (502, 503, 504):
return _offline_html_response()
raise
import re
referer = request.headers.get("referer", "")
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
if match:
try:
return UUID(match.group(1))
except ValueError:
return None
return None
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy Next.js static assets requested at root /_next/ path.
The session_id is extracted from the Referer header since these requests
come from within the iframe context.
"""
session_id = _extract_session_from_referer(request)
if not session_id:
raise HTTPException(
status_code=400,
detail="Could not determine session from request context",
)
return _proxy_request(f"_next/{path}", request, session_id, db_session)
# =============================================================================

View File

@@ -10,7 +10,6 @@ from onyx.configs.constants import MessageType
from onyx.db.enums import ArtifactType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.server.features.build.sandbox.models import (
FilesystemEntry as FileSystemEntry,
)
@@ -108,7 +107,6 @@ class SessionResponse(BaseModel):
nextjs_port: int | None
sandbox: SandboxResponse | None
artifacts: list[ArtifactResponse]
sharing_scope: SharingScope
@classmethod
def from_model(
@@ -131,7 +129,6 @@ class SessionResponse(BaseModel):
nextjs_port=session.nextjs_port,
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
sharing_scope=session.sharing_scope,
)
@@ -162,19 +159,6 @@ class SessionListResponse(BaseModel):
sessions: list[SessionResponse]
class SetSessionSharingRequest(BaseModel):
"""Request to set the sharing scope of a session."""
sharing_scope: SharingScope
class SetSessionSharingResponse(BaseModel):
"""Response after setting session sharing scope."""
session_id: str
sharing_scope: SharingScope
# ===== Message Models =====
class MessageRequest(BaseModel):
"""Request to send a message to the CLI agent."""
@@ -260,7 +244,6 @@ class WebappInfo(BaseModel):
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
status: str # Sandbox status (running, terminated, etc.)
ready: bool # Whether the NextJS dev server is actually responding
sharing_scope: SharingScope
# ===== File Upload Models =====

View File

@@ -30,8 +30,6 @@ from onyx.server.features.build.api.models import SessionListResponse
from onyx.server.features.build.api.models import SessionNameGenerateResponse
from onyx.server.features.build.api.models import SessionResponse
from onyx.server.features.build.api.models import SessionUpdateRequest
from onyx.server.features.build.api.models import SetSessionSharingRequest
from onyx.server.features.build.api.models import SetSessionSharingResponse
from onyx.server.features.build.api.models import SuggestionBubble
from onyx.server.features.build.api.models import SuggestionTheme
from onyx.server.features.build.api.models import UploadResponse
@@ -40,7 +38,6 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
from onyx.server.features.build.configs import SandboxBackend
from onyx.server.features.build.db.build_session import allocate_nextjs_port
from onyx.server.features.build.db.build_session import get_build_session
from onyx.server.features.build.db.build_session import set_build_session_sharing_scope
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
@@ -297,25 +294,6 @@ def update_session_name(
return SessionResponse.from_model(session, sandbox)
@router.patch("/{session_id}/public")
def set_session_public(
session_id: UUID,
request: SetSessionSharingRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SetSessionSharingResponse:
"""Set the sharing scope of a build session's webapp."""
updated = set_build_session_sharing_scope(
session_id, user.id, request.sharing_scope, db_session
)
if not updated:
raise HTTPException(status_code=404, detail="Session not found")
return SetSessionSharingResponse(
session_id=str(session_id),
sharing_scope=updated.sharing_scope,
)
@router.delete("/{session_id}", response_model=None)
def delete_session(
session_id: UUID,

View File

@@ -1,110 +0,0 @@
<!doctype html>
<html lang="en">
<head>
<meta charset="UTF-8" />
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
<meta http-equiv="refresh" content="15" />
<title>Craft — Starting up</title>
<style>
*,
*::before,
*::after {
box-sizing: border-box;
margin: 0;
padding: 0;
}
body {
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas,
monospace;
background: linear-gradient(to bottom right, #030712, #111827, #030712);
min-height: 100vh;
display: flex;
flex-direction: column;
align-items: center;
justify-content: center;
gap: 1.5rem;
padding: 2rem;
}
.terminal {
width: 100%;
max-width: 580px;
border: 2px solid #374151;
border-radius: 2px;
}
.titlebar {
background: #1f2937;
padding: 0.5rem 0.75rem;
display: flex;
align-items: center;
gap: 0.5rem;
border-bottom: 1px solid #374151;
}
.btn {
width: 12px;
height: 12px;
border-radius: 2px;
flex-shrink: 0;
}
.btn-red {
background: #ef4444;
}
.btn-yellow {
background: #eab308;
}
.btn-green {
background: #22c55e;
}
.title-label {
flex: 1;
text-align: center;
font-size: 0.75rem;
color: #6b7280;
margin-right: 36px;
}
.body {
background: #111827;
padding: 1.5rem;
min-height: 200px;
font-size: 0.875rem;
color: #d1d5db;
display: flex;
align-items: flex-start;
gap: 0.375rem;
}
.prompt {
color: #10b981;
user-select: none;
}
.tagline {
font-size: 0.8125rem;
color: #4b5563;
text-align: center;
}
</style>
</head>
<body>
<div class="terminal">
<div class="titlebar">
<div class="btn btn-red"></div>
<div class="btn btn-yellow"></div>
<div class="btn btn-green"></div>
<span class="title-label">crafting_table</span>
</div>
<div class="body">
<span class="prompt">/&gt;</span>
<span>Sandbox is asleep...</span>
</div>
</div>
<p class="tagline">
Ask the owner to open their Craft session to wake it up.
</p>
</body>
</html>

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import MessageType
from onyx.db.enums import BuildSessionStatus
from onyx.db.enums import SandboxStatus
from onyx.db.enums import SharingScope
from onyx.db.models import Artifact
from onyx.db.models import BuildMessage
from onyx.db.models import BuildSession
@@ -160,26 +159,6 @@ def update_session_status(
logger.info(f"Updated build session {session_id} status to {status}")
def set_build_session_sharing_scope(
session_id: UUID,
user_id: UUID,
sharing_scope: SharingScope,
db_session: Session,
) -> BuildSession | None:
"""Set the sharing scope of a build session.
Only the session owner can change this setting.
Returns the updated session, or None if not found/unauthorized.
"""
session = get_build_session(session_id, user_id, db_session)
if not session:
return None
session.sharing_scope = sharing_scope
db_session.commit()
logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}")
return session
def delete_build_session__no_commit(
session_id: UUID,
user_id: UUID,

View File

@@ -474,23 +474,6 @@ class SandboxManager(ABC):
"""
...
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Ensure the Next.js server is running for a session.
Default is a no-op — only meaningful for local backends that manage
process lifecycles directly (e.g., LocalSandboxManager).
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port the Next.js server should be listening on
"""
# Singleton instance cache for the factory
_sandbox_manager_instance: SandboxManager | None = None

View File

@@ -15,8 +15,6 @@ from collections.abc import Generator
from pathlib import Path
from uuid import UUID
import httpx
from onyx.db.enums import SandboxStatus
from onyx.file_store.file_store import get_default_file_store
from onyx.server.features.build.configs import DEMO_DATA_PATH
@@ -37,7 +35,6 @@ from onyx.server.features.build.sandbox.models import LLMProviderConfig
from onyx.server.features.build.sandbox.models import SandboxInfo
from onyx.server.features.build.sandbox.models import SnapshotResult
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import ThreadSafeSet
logger = setup_logger()
@@ -92,17 +89,9 @@ class LocalSandboxManager(SandboxManager):
self._acp_clients: dict[tuple[UUID, UUID], ACPAgentClient] = {}
# Track Next.js processes - keyed by (sandbox_id, session_id) tuple
# Used for clean shutdown when sessions are deleted.
# Mutated from background threads; all access must hold _nextjs_lock.
# Used for clean shutdown when sessions are deleted
self._nextjs_processes: dict[tuple[UUID, UUID], subprocess.Popen[bytes]] = {}
# Track sessions currently being (re)started - prevents concurrent restarts.
# ThreadSafeSet allows atomic check-and-add without holding _nextjs_lock.
self._nextjs_starting: ThreadSafeSet[tuple[UUID, UUID]] = ThreadSafeSet()
# Lock guarding _nextjs_processes (shared across sessions; hold briefly only)
self._nextjs_lock = threading.Lock()
# Validate templates exist (raises RuntimeError if missing)
self._validate_templates()
@@ -337,18 +326,16 @@ class LocalSandboxManager(SandboxManager):
RuntimeError: If termination fails
"""
# Stop all Next.js processes for this sandbox (keyed by (sandbox_id, session_id))
with self._nextjs_lock:
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
processes_to_stop = [
(key, process)
for key, process in self._nextjs_processes.items()
if key[0] == sandbox_id
]
for key, process in processes_to_stop:
session_id = key[1]
try:
self._stop_nextjs_process(process, session_id)
with self._nextjs_lock:
self._nextjs_processes.pop(key, None)
del self._nextjs_processes[key]
except Exception as e:
logger.warning(
f"Failed to stop Next.js for sandbox {sandbox_id}, "
@@ -529,8 +516,7 @@ class LocalSandboxManager(SandboxManager):
web_dir, nextjs_port
)
# Store process for clean shutdown on session delete
with self._nextjs_lock:
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
logger.info("Next.js server started successfully")
# Setup venv and AGENTS.md
@@ -589,8 +575,7 @@ class LocalSandboxManager(SandboxManager):
"""
# Stop Next.js dev server - try stored process first, then fallback to port lookup
process_key = (sandbox_id, session_id)
with self._nextjs_lock:
nextjs_process = self._nextjs_processes.pop(process_key, None)
nextjs_process = self._nextjs_processes.pop(process_key, None)
if nextjs_process is not None:
self._stop_nextjs_process(nextjs_process, session_id)
elif nextjs_port is not None:
@@ -781,85 +766,6 @@ class LocalSandboxManager(SandboxManager):
outputs_path = session_path / "outputs"
return outputs_path.exists()
def ensure_nextjs_running(
self,
sandbox_id: UUID,
session_id: UUID,
nextjs_port: int,
) -> None:
"""Start Next.js server for a session if not already running.
Called when the server is detected as unreachable (e.g., after API server restart).
Returns immediately — the actual startup runs in a background daemon thread.
A per-session guard prevents concurrent restarts from racing.
Lock design: _nextjs_lock is shared across ALL sessions. Holding it during
httpx (1s) or start_nextjs_server (several seconds) would block every other
session's status checks and restarts. We only hold the lock for fast
in-memory ops (dict get, check_and_add). The slow I/O runs in the background
thread without holding any lock.
Args:
sandbox_id: The sandbox ID
session_id: The session ID
nextjs_port: The port number for the Next.js server
"""
process_key = (sandbox_id, session_id)
with self._nextjs_lock:
existing = self._nextjs_processes.get(process_key)
if existing is not None and existing.poll() is None:
return
# Atomic check-and-add: returns True if already in set (another thread is starting)
if self._nextjs_starting.check_and_add(process_key):
return
def _start_in_background() -> None:
try:
# Port check in background to avoid blocking the main thread
try:
with httpx.Client(timeout=1.0) as client:
client.get(f"http://localhost:{nextjs_port}")
logger.info(
f"Port {nextjs_port} already alive for session {session_id} "
"(orphan process) — skipping restart"
)
return
except Exception:
pass # Port is dead; proceed with restart
logger.info(
f"Starting Next.js for session {session_id} on port {nextjs_port}"
)
sandbox_path = self._get_sandbox_path(sandbox_id)
web_dir = self._directory_manager.get_web_path(
sandbox_path, str(session_id)
)
if not web_dir.exists():
logger.warning(
f"Web dir missing for session {session_id}: {web_dir}"
"cannot restart Next.js"
)
return
process = self._process_manager.start_nextjs_server(
web_dir, nextjs_port
)
with self._nextjs_lock:
self._nextjs_processes[process_key] = process
logger.info(
f"Auto-restarted Next.js for session {session_id} "
f"on port {nextjs_port}"
)
except Exception as e:
logger.error(
f"Failed to auto-restart Next.js for session {session_id}: {e}"
)
finally:
self._nextjs_starting.discard(process_key)
threading.Thread(target=_start_in_background, daemon=True).start()
def restore_snapshot(
self,
sandbox_id: UUID,

View File

@@ -0,0 +1,10 @@
"""Celery tasks for sandbox management."""
from onyx.server.features.build.sandbox.tasks.tasks import (
cleanup_idle_sandboxes_task,
) # noqa: F401
from onyx.server.features.build.sandbox.tasks.tasks import (
sync_sandbox_files,
) # noqa: F401
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]

View File

@@ -1765,7 +1765,6 @@ class SessionManager:
"webapp_url": None,
"status": "no_sandbox",
"ready": False,
"sharing_scope": session.sharing_scope,
}
# Return the proxy URL - the proxy handles routing to the correct sandbox
@@ -1778,21 +1777,11 @@ class SessionManager:
# Quick health check: can the API server reach the NextJS dev server?
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
# If not ready, ask the sandbox manager to ensure Next.js is running.
# For the local backend this triggers a background restart so that the
# frontend poll loop eventually sees ready=True without the user having
# to manually recreate the session.
if not ready:
self._sandbox_manager.ensure_nextjs_running(
sandbox.id, session_id, session.nextjs_port
)
return {
"has_webapp": session.nextjs_port is not None,
"webapp_url": webapp_url,
"status": sandbox.status.value,
"ready": ready,
"sharing_scope": session.sharing_scope,
}
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:

View File

@@ -111,8 +111,7 @@ class DocumentSet(BaseModel):
id=cc_pair.id,
name=cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair.connector,
credential_ids=[cc_pair.credential_id],
cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair.credential

View File

@@ -26,17 +26,13 @@ def get_opensearch_migration_status(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> OpenSearchMigrationStatusResponse:
(
total_chunks_migrated,
created_at,
migration_completed_at,
approx_chunk_count_in_vespa,
) = get_opensearch_migration_state(db_session)
total_chunks_migrated, created_at, migration_completed_at = (
get_opensearch_migration_state(db_session)
)
return OpenSearchMigrationStatusResponse(
total_chunks_migrated=total_chunks_migrated,
created_at=created_at,
migration_completed_at=migration_completed_at,
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)

View File

@@ -8,7 +8,6 @@ class OpenSearchMigrationStatusResponse(BaseModel):
total_chunks_migrated: int
created_at: datetime | None
migration_completed_at: datetime | None
approx_chunk_count_in_vespa: int | None
class OpenSearchRetrievalStatusRequest(BaseModel):

View File

@@ -608,8 +608,7 @@ def list_all_users_basic_info(
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
and (include_api_keys or not is_api_key_email_address(user.email))
if include_api_keys or not is_api_key_email_address(user.email)
]

View File

@@ -1,241 +0,0 @@
"""SQLAlchemy connection pool Prometheus metrics.
Provides production-grade visibility into database connection pool state:
- Pool state gauges (checked-out, idle, overflow, configured size)
- Pool lifecycle counters (checkouts, checkins, creates, invalidations, timeouts)
- Per-endpoint connection attribution (which endpoints hold connections, for how long)
Metrics are collected via two mechanisms:
1. A custom Prometheus Collector that reads pool snapshots on each /metrics scrape
2. SQLAlchemy pool event listeners (checkout, checkin, connect, invalidate) for
counters, histograms, and attribution
"""
import time
from fastapi import Request
from fastapi.responses import JSONResponse
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY
from sqlalchemy import event
from sqlalchemy.engine import Engine
from sqlalchemy.engine.interfaces import DBAPIConnection
from sqlalchemy.ext.asyncio import AsyncEngine
from sqlalchemy.pool import ConnectionPoolEntry
from sqlalchemy.pool import PoolProxiedConnection
from sqlalchemy.pool import QueuePool
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
logger = setup_logger()
# --- Pool lifecycle counters (event-driven) ---
_checkout_total = Counter(
"onyx_db_pool_checkout_total",
"Total connection checkouts from the pool",
["engine"],
)
_checkin_total = Counter(
"onyx_db_pool_checkin_total",
"Total connection checkins to the pool",
["engine"],
)
_connections_created_total = Counter(
"onyx_db_pool_connections_created_total",
"Total new database connections created",
["engine"],
)
_invalidations_total = Counter(
"onyx_db_pool_invalidations_total",
"Total connection invalidations",
["engine"],
)
_checkout_timeout_total = Counter(
"onyx_db_pool_checkout_timeout_total",
"Total connection checkout timeouts",
["engine"],
)
# --- Per-endpoint attribution (event-driven) ---
_connections_held = Gauge(
"onyx_db_connections_held_by_endpoint",
"Number of DB connections currently held, by endpoint and engine",
["handler", "engine"],
)
_hold_seconds = Histogram(
"onyx_db_connection_hold_seconds",
"Duration a DB connection is held by an endpoint",
["handler", "engine"],
)
def pool_timeout_handler(
request: Request, # noqa: ARG001
exc: Exception,
) -> JSONResponse:
"""Increment the checkout timeout counter and return 503."""
_checkout_timeout_total.labels(engine="unknown").inc()
return JSONResponse(
status_code=503,
content={
"detail": "Database connection pool timeout",
"error": str(exc),
},
)
class PoolStateCollector(Collector):
"""Custom Prometheus collector that reads QueuePool state on each scrape.
Uses pool.checkedout(), pool.checkedin(), pool.overflow(), and pool.size()
for an atomic snapshot of pool state. Registered engines are stored as
(label, pool) tuples to avoid holding references to the full Engine.
"""
def __init__(self) -> None:
self._pools: list[tuple[str, QueuePool]] = []
def add_pool(self, label: str, pool: QueuePool) -> None:
self._pools.append((label, pool))
def collect(self) -> list[GaugeMetricFamily]:
checked_out = GaugeMetricFamily(
"onyx_db_pool_checked_out",
"Currently checked-out connections",
labels=["engine"],
)
checked_in = GaugeMetricFamily(
"onyx_db_pool_checked_in",
"Idle connections available in the pool",
labels=["engine"],
)
overflow = GaugeMetricFamily(
"onyx_db_pool_overflow",
"Current overflow connections beyond pool_size",
labels=["engine"],
)
size = GaugeMetricFamily(
"onyx_db_pool_size",
"Configured pool size",
labels=["engine"],
)
for label, pool in self._pools:
checked_out.add_metric([label], pool.checkedout())
checked_in.add_metric([label], pool.checkedin())
overflow.add_metric([label], pool.overflow())
size.add_metric([label], pool.size())
return [checked_out, checked_in, overflow, size]
def describe(self) -> list[GaugeMetricFamily]:
# Return empty to mark this as an "unchecked" collector. Prometheus
# skips upfront descriptor validation and just calls collect() at
# scrape time. Required because our metrics are dynamic (engine
# labels depend on which engines are registered at runtime).
return []
def _register_pool_events(engine: Engine, label: str) -> None:
"""Attach pool event listeners for metrics collection.
Listens to checkout, checkin, connect, and invalidate events.
Stores per-connection metadata on connection_record.info for attribution.
"""
@event.listens_for(engine, "checkout")
def on_checkout(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
conn_proxy: PoolProxiedConnection, # noqa: ARG001
) -> None:
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
conn_record.info["_metrics_endpoint"] = handler
conn_record.info["_metrics_checkout_time"] = time.monotonic()
_checkout_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).inc()
@event.listens_for(engine, "checkin")
def on_checkin(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
) -> None:
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
start = conn_record.info.pop("_metrics_checkout_time", None)
_checkin_total.labels(engine=label).inc()
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler, engine=label).observe(
time.monotonic() - start
)
@event.listens_for(engine, "connect")
def on_connect(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry, # noqa: ARG001
) -> None:
_connections_created_total.labels(engine=label).inc()
@event.listens_for(engine, "invalidate")
def on_invalidate(
dbapi_conn: DBAPIConnection, # noqa: ARG001
conn_record: ConnectionPoolEntry,
exception: BaseException | None, # noqa: ARG001
) -> None:
_invalidations_total.labels(engine=label).inc()
# Defensively clean up the held-connections gauge in case checkin
# doesn't fire after invalidation (e.g. hard pool shutdown).
handler = conn_record.info.pop("_metrics_endpoint", None)
start = conn_record.info.pop("_metrics_checkout_time", None)
if handler:
_connections_held.labels(handler=handler, engine=label).dec()
if start is not None:
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
time.monotonic() - start
)
def setup_postgres_connection_pool_metrics(
engines: dict[str, Engine | AsyncEngine],
) -> None:
"""Register pool metrics for all provided engines.
Args:
engines: Mapping of engine label to Engine or AsyncEngine.
Example: {"sync": sync_engine, "async": async_engine, "readonly": ro_engine}
Engines using NullPool are skipped (no pool state to monitor).
For AsyncEngine, events are registered on the underlying sync_engine.
"""
collector = PoolStateCollector()
for label, engine in engines.items():
# Resolve async engines to their underlying sync engine
sync_engine = engine.sync_engine if isinstance(engine, AsyncEngine) else engine
pool = sync_engine.pool
if not isinstance(pool, QueuePool):
logger.info(
f"Skipping pool metrics for engine '{label}' "
f"({type(pool).__name__} — no pool state)"
)
continue
collector.add_pool(label, pool)
_register_pool_events(sync_engine, label)
logger.info(f"Registered pool metrics for engine '{label}'")
REGISTRY.register(collector)

View File

@@ -1,64 +0,0 @@
"""Prometheus metrics setup for the Onyx API server.
Orchestrates HTTP request instrumentation via ``prometheus-fastapi-instrumentator``:
- Request count, latency histograms, in-progress gauges
- Pool checkout timeout exception handler
- Custom metric callbacks (e.g. slow request counting)
SQLAlchemy connection pool metrics are registered separately via
``setup_postgres_connection_pool_metrics`` during application lifespan
(after engines are created).
"""
from prometheus_fastapi_instrumentator import Instrumentator
from sqlalchemy.exc import TimeoutError as SATimeoutError
from starlette.applications import Starlette
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
from onyx.server.metrics.slow_requests import slow_request_callback
_EXCLUDED_HANDLERS = [
"/health",
"/metrics",
"/openapi.json",
]
# Denser buckets for per-handler latency histograms. The instrumentator's
# default (0.1, 0.5, 1) is too coarse for meaningful P95/P99 computation.
_LATENCY_BUCKETS = (
0.01,
0.025,
0.05,
0.1,
0.25,
0.5,
1.0,
2.5,
5.0,
10.0,
)
def setup_prometheus_metrics(app: Starlette) -> None:
"""Initialize HTTP request metrics for the Onyx API server.
Must be called in ``get_application()`` BEFORE the app starts, because
the instrumentator adds middleware via ``app.add_middleware()``.
Args:
app: The FastAPI/Starlette application to instrument.
"""
app.add_exception_handler(SATimeoutError, pool_timeout_handler)
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=_EXCLUDED_HANDLERS,
)
instrumentator.add(slow_request_callback)
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)

View File

@@ -1,31 +0,0 @@
"""Slow request counter metric.
Increments a counter whenever a request exceeds a configurable duration
threshold. Useful for identifying endpoints that regularly take too long.
"""
import os
from prometheus_client import Counter
from prometheus_fastapi_instrumentator.metrics import Info
SLOW_REQUEST_THRESHOLD_SECONDS: float = max(
0.0,
float(os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")),
)
_slow_requests = Counter(
"onyx_api_slow_requests_total",
"Total requests exceeding the slow request threshold",
["method", "handler", "status"],
)
def slow_request_callback(info: Info) -> None:
"""Increment slow request counter when duration exceeds threshold."""
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
_slow_requests.labels(
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()

View File

@@ -0,0 +1,63 @@
"""Prometheus instrumentation for the Onyx API server.
Provides a production-grade metrics configuration with:
- Exact HTTP status codes (no grouping into 2xx/3xx)
- In-progress request gauge broken down by handler and method
- Custom latency histogram buckets tuned for API workloads
- Request/response size tracking
- Slow request counter with configurable threshold
"""
import os
from prometheus_client import Counter
from prometheus_fastapi_instrumentator import Instrumentator
from prometheus_fastapi_instrumentator.metrics import Info
from starlette.applications import Starlette
SLOW_REQUEST_THRESHOLD_SECONDS: float = float(
os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")
)
_EXCLUDED_HANDLERS = [
"/health",
"/metrics",
"/openapi.json",
]
_slow_requests = Counter(
"onyx_api_slow_requests_total",
"Total requests exceeding the slow request threshold",
["method", "handler", "status"],
)
def _slow_request_callback(info: Info) -> None:
"""Increment slow request counter when duration exceeds threshold."""
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
_slow_requests.labels(
method=info.method,
handler=info.modified_handler,
status=info.modified_status,
).inc()
def setup_prometheus_metrics(app: Starlette) -> None:
"""Configure and attach Prometheus instrumentation to the FastAPI app.
Records exact status codes, tracks in-progress requests per handler,
and counts slow requests exceeding a configurable threshold.
"""
instrumentator = Instrumentator(
should_group_status_codes=False,
should_ignore_untemplated=False,
should_group_untemplated=True,
should_instrument_requests_inprogress=True,
inprogress_labels=True,
excluded_handlers=_EXCLUDED_HANDLERS,
)
instrumentator.add(_slow_request_callback)
instrumentator.instrument(app).expose(app)

View File

@@ -36,8 +36,6 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
from onyx.server.query_and_chat.streaming_models import OverallStop
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
from onyx.server.query_and_chat.streaming_models import PythonToolStart
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
from onyx.server.query_and_chat.streaming_models import ReasoningStart
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
@@ -52,7 +50,6 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
)
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
from onyx.utils.logger import setup_logger
@@ -380,37 +377,6 @@ def create_memory_packets(
return packets
def create_python_tool_packets(
code: str,
stdout: str,
stderr: str,
file_ids: list[str],
turn_index: int,
tab_index: int = 0,
) -> list[Packet]:
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
tool call data so the frontend can display both the code and its output
on page reload."""
packets: list[Packet] = []
placement = Placement(turn_index=turn_index, tab_index=tab_index)
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
packets.append(
Packet(
placement=placement,
obj=PythonToolDelta(
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
),
)
)
packets.append(Packet(placement=placement, obj=SectionEnd()))
return packets
def create_search_packets(
search_queries: list[str],
search_docs: list[SavedSearchDoc],
@@ -620,41 +586,6 @@ def translate_assistant_message_to_packets(
)
)
elif tool.in_code_tool_id == PythonTool.__name__:
code = cast(
str,
tool_call.tool_call_arguments.get("code", ""),
)
stdout = ""
stderr = ""
file_ids: list[str] = []
if tool_call.tool_call_response:
try:
response_data = json.loads(tool_call.tool_call_response)
stdout = response_data.get("stdout", "")
stderr = response_data.get("stderr", "")
generated_files = response_data.get(
"generated_files", []
)
file_ids = [
f.get("file_link", "").split("/")[-1]
for f in generated_files
if f.get("file_link")
]
except (json.JSONDecodeError, KeyError):
# Fall back to raw response as stdout
stdout = tool_call.tool_call_response
turn_tool_packets.extend(
create_python_tool_packets(
code=code,
stdout=stdout,
stderr=stderr,
file_ids=file_ids,
turn_index=turn_num,
tab_index=tool_call.tab_index,
)
)
else:
# Custom tool or unknown tool
turn_tool_packets.extend(

View File

@@ -24,7 +24,6 @@ from onyx.auth.users import get_user_manager
from onyx.auth.users import UserManager
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
from onyx.configs.app_configs import SAML_CONF_DIR
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.db.auth import get_user_count
from onyx.db.auth import get_user_db
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
@@ -124,12 +123,9 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
if request.client is None:
raise ValueError("Invalid request for SAML")
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
# to poison SAML redirect URLs (host header poisoning).
parsed_domain = urlparse(WEB_DOMAIN)
http_host = parsed_domain.hostname or request.client.host
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
# Use X-Forwarded headers if available
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
rv: dict[str, Any] = {
"http_host": http_host,

View File

@@ -55,9 +55,7 @@ class Settings(BaseModel):
gpu_enabled: bool | None = None
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
anonymous_user_enabled: bool | None = None
invite_only_enabled: bool = False
deep_research_enabled: bool | None = None
search_ui_enabled: bool | None = None
# Enterprise features flag - set by license enforcement at runtime
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status

View File

@@ -199,12 +199,6 @@ class PythonToolOverrideKwargs(BaseModel):
chat_files: list[ChatFile] = []
class ImageGenerationToolOverrideKwargs(BaseModel):
"""Override kwargs for image generation tool calls."""
recent_generated_image_file_ids: list[str] = []
class SearchToolRunContext(BaseModel):
emitter: Emitter

View File

@@ -171,8 +171,10 @@ def construct_tools(
if not search_tool_config:
search_tool_config = SearchToolConfig()
# TODO concerning passing the db_session here.
search_tool = SearchTool(
tool_id=db_tool_model.id,
db_session=db_session,
emitter=emitter,
user=user,
persona=persona,
@@ -420,6 +422,7 @@ def construct_tools(
search_tool = SearchTool(
tool_id=search_tool_db_model.id,
db_session=db_session,
emitter=emitter,
user=user,
persona=persona,

View File

@@ -11,14 +11,11 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import IMAGE_MODEL_NAME
from onyx.configs.app_configs import IMAGE_MODEL_PROVIDER
from onyx.db.image_generation import get_default_image_generation_config
from onyx.file_store.models import ChatFileType
from onyx.file_store.utils import build_frontend_file_url
from onyx.file_store.utils import load_chat_file_by_id
from onyx.file_store.utils import save_files
from onyx.image_gen.factory import get_image_generation_provider
from onyx.image_gen.factory import validate_credentials
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
from onyx.image_gen.interfaces import ReferenceImage
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
@@ -26,7 +23,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.tools.interface import Tool
from onyx.tools.models import ImageGenerationToolOverrideKwargs
from onyx.tools.models import ToolCallException
from onyx.tools.models import ToolExecutionException
from onyx.tools.models import ToolResponse
@@ -35,7 +31,6 @@ from onyx.tools.tool_implementations.images.models import (
)
from onyx.tools.tool_implementations.images.models import ImageGenerationResponse
from onyx.tools.tool_implementations.images.models import ImageShape
from onyx.utils.b64 import get_image_type_from_bytes
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -45,10 +40,10 @@ logger = setup_logger()
HEARTBEAT_INTERVAL = 5.0
PROMPT_FIELD = "prompt"
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
# override_kwargs is not supported for image generation tools
class ImageGenerationTool(Tool[None]):
NAME = "generate_image"
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
DISPLAY_NAME = "Image Generation"
@@ -64,7 +59,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
) -> None:
super().__init__(emitter=emitter)
self.model = model
self.provider = provider
self.num_imgs = num_imgs
self.img_provider = get_image_generation_provider(
@@ -139,16 +133,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
),
"enum": [shape.value for shape in ImageShape],
},
REFERENCE_IMAGE_FILE_IDS_FIELD: {
"type": "array",
"description": (
"Optional image file IDs to use as reference context for edits/variations. "
"Use the file_id values returned by previous generate_image calls."
),
"items": {
"type": "string",
},
},
},
"required": [PROMPT_FIELD],
},
@@ -164,10 +148,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
def _generate_image(
self,
prompt: str,
shape: ImageShape,
reference_images: list[ReferenceImage] | None = None,
self, prompt: str, shape: ImageShape
) -> tuple[ImageGenerationResponse, Any]:
if shape == ImageShape.LANDSCAPE:
if "gpt-image-1" in self.model:
@@ -188,7 +169,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
model=self.model,
size=size,
n=1,
reference_images=reference_images,
# response_format parameter is not supported for gpt-image-1
response_format=None if "gpt-image-1" in self.model else "b64_json",
)
@@ -251,117 +231,10 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
emit_error_packet=True,
)
def _resolve_reference_image_file_ids(
self,
llm_kwargs: dict[str, Any],
override_kwargs: ImageGenerationToolOverrideKwargs | None,
) -> list[str]:
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
if raw_reference_ids is not None:
if not isinstance(raw_reference_ids, list) or not all(
isinstance(file_id, str) for file_id in raw_reference_ids
):
raise ToolCallException(
message=(
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, "
f"got {type(raw_reference_ids)}"
),
llm_facing_message=(
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
),
)
reference_image_file_ids = [
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
]
elif (
override_kwargs
and override_kwargs.recent_generated_image_file_ids
and self.img_provider.supports_reference_images
):
# If no explicit reference was provided, default to the most recently generated image.
reference_image_file_ids = [
override_kwargs.recent_generated_image_file_ids[-1]
]
else:
reference_image_file_ids = []
# Deduplicate while preserving order.
deduped_reference_image_ids: list[str] = []
seen_ids: set[str] = set()
for file_id in reference_image_file_ids:
if file_id in seen_ids:
continue
seen_ids.add(file_id)
deduped_reference_image_ids.append(file_id)
if not deduped_reference_image_ids:
return []
if not self.img_provider.supports_reference_images:
raise ToolCallException(
message=(
f"Reference images requested but provider '{self.provider}' "
"does not support image-editing context."
),
llm_facing_message=(
"This image provider does not support editing from previous image context. "
"Try text-only generation, or switch to a provider/model that supports image edits."
),
)
max_reference_images = self.img_provider.max_reference_images
if max_reference_images > 0:
return deduped_reference_image_ids[-max_reference_images:]
return deduped_reference_image_ids
def _load_reference_images(
self,
reference_image_file_ids: list[str],
) -> list[ReferenceImage]:
reference_images: list[ReferenceImage] = []
for file_id in reference_image_file_ids:
try:
loaded_file = load_chat_file_by_id(file_id)
except Exception as e:
raise ToolCallException(
message=f"Could not load reference image file '{file_id}': {e}",
llm_facing_message=(
f"Reference image file '{file_id}' could not be loaded. "
"Use file_id values returned by previous generate_image calls."
),
)
if loaded_file.file_type != ChatFileType.IMAGE:
raise ToolCallException(
message=f"Reference file '{file_id}' is not an image",
llm_facing_message=f"Reference file '{file_id}' is not an image.",
)
try:
mime_type = get_image_type_from_bytes(loaded_file.content)
except Exception as e:
raise ToolCallException(
message=f"Unsupported reference image format for '{file_id}': {e}",
llm_facing_message=(
f"Reference image '{file_id}' has an unsupported format. "
"Only PNG, JPEG, GIF, and WEBP are supported."
),
)
reference_images.append(
ReferenceImage(
data=loaded_file.content,
mime_type=mime_type,
)
)
return reference_images
def run(
self,
placement: Placement,
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
override_kwargs: None = None, # noqa: ARG002
**llm_kwargs: Any,
) -> ToolResponse:
if PROMPT_FIELD not in llm_kwargs:
@@ -374,11 +247,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
)
prompt = cast(str, llm_kwargs[PROMPT_FIELD])
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
reference_image_file_ids = self._resolve_reference_image_file_ids(
llm_kwargs=llm_kwargs,
override_kwargs=override_kwargs,
)
reference_images = self._load_reference_images(reference_image_file_ids)
# Use threading to generate images in parallel while emitting heartbeats
results: list[tuple[ImageGenerationResponse, Any] | None] = [
@@ -399,7 +267,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
(
prompt,
shape,
reference_images or None,
),
)
for _ in range(self.num_imgs)
@@ -480,7 +347,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
llm_facing_response = json.dumps(
[
{
"file_id": img.file_id,
"revised_prompt": img.revised_prompt,
}
for img in generated_images_metadata

View File

@@ -1,7 +1,6 @@
from __future__ import annotations
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from onyx.file_processing.html_utils import ParsedHTML
from onyx.file_processing.html_utils import web_html_cleanup
@@ -22,22 +21,10 @@ from onyx.utils.web_content import title_from_url
logger = setup_logger()
DEFAULT_READ_TIMEOUT_SECONDS = 15
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
DEFAULT_TIMEOUT_SECONDS = 15
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
DEFAULT_MAX_WORKERS = 5
def _failed_result(url: str) -> WebContent:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
class OnyxWebCrawler(WebContentProvider):
@@ -50,14 +37,12 @@ class OnyxWebCrawler(WebContentProvider):
def __init__(
self,
*,
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
user_agent: str = DEFAULT_USER_AGENT,
max_pdf_size_bytes: int | None = None,
max_html_size_bytes: int | None = None,
) -> None:
self._read_timeout_seconds = timeout_seconds
self._connect_timeout_seconds = connect_timeout_seconds
self._timeout_seconds = timeout_seconds
self._max_pdf_size_bytes = max_pdf_size_bytes
self._max_html_size_bytes = max_html_size_bytes
self._headers = {
@@ -66,68 +51,75 @@ class OnyxWebCrawler(WebContentProvider):
}
def contents(self, urls: Sequence[str]) -> list[WebContent]:
if not urls:
return []
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
with ThreadPoolExecutor(max_workers=max_workers) as executor:
return list(executor.map(self._fetch_url_safe, urls))
def _fetch_url_safe(self, url: str) -> WebContent:
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
try:
return self._fetch_url(url)
except Exception as exc:
logger.warning(
"Onyx crawler unexpected error for %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
results: list[WebContent] = []
for url in urls:
results.append(self._fetch_url(url))
return results
def _fetch_url(self, url: str) -> WebContent:
try:
# Use SSRF-safe request to prevent DNS rebinding attacks
response = ssrf_safe_get(
url,
headers=self._headers,
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
url, headers=self._headers, timeout=self._timeout_seconds
)
except SSRFException as exc:
logger.error(
"SSRF protection blocked request to %s (%s)",
"SSRF protection blocked request to %s: %s",
url,
exc.__class__.__name__,
str(exc),
)
return _failed_result(url)
except Exception as exc:
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
except Exception as exc: # pragma: no cover - network failures vary
logger.warning(
"Onyx crawler failed to fetch %s (%s)",
url,
exc.__class__.__name__,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
if response.status_code >= 400:
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
content_type = response.headers.get("Content-Type", "")
content = response.content
content_sniff = content[:1024] if content else None
content_sniff = response.content[:1024] if response.content else None
if is_pdf_resource(url, content_type, content_sniff):
if (
self._max_pdf_size_bytes is not None
and len(content) > self._max_pdf_size_bytes
and len(response.content) > self._max_pdf_size_bytes
):
logger.warning(
"PDF content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_pdf_size_bytes,
)
return _failed_result(url)
text_content, metadata = extract_pdf_text(content)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
text_content, metadata = extract_pdf_text(response.content)
title = title_from_pdf_metadata(metadata) or title_from_url(url)
return WebContent(
title=title,
@@ -139,19 +131,25 @@ class OnyxWebCrawler(WebContentProvider):
if (
self._max_html_size_bytes is not None
and len(content) > self._max_html_size_bytes
and len(response.content) > self._max_html_size_bytes
):
logger.warning(
"HTML content too large (%d bytes) for %s, max is %d",
len(content),
len(response.content),
url,
self._max_html_size_bytes,
)
return _failed_result(url)
return WebContent(
title="",
link=url,
full_content="",
published_date=None,
scrape_successful=False,
)
try:
decoded_html = decode_html_bytes(
content,
response.content,
content_type=content_type,
fallback_encoding=response.apparent_encoding or response.encoding,
)

Some files were not shown because too many files have changed in this diff Show More