mirror of https://github.com/onyx-dot-app/onyx.git
synced 2026-02-23 02:35:45 +00:00

Compare commits: csv_render ... sharepoint (14 commits)
| Author | SHA1 | Date |
|---|---|---|
| | d972fb08fc | |
| | 4433d130db | |
| | f108232415 | |
| | 6cd7c59a1c | |
| | d31d8092ce | |
| | 2cdeecc844 | |
| | 8f910a187b | |
| | cfe5b95cc4 | |
| | 760c4fae6a | |
| | b769ee530f | |
| | a42114a932 | |
| | e709a9dd0e | |
| | e1c16c2391 | |
| | 9ff29b8879 | |
.github/workflows/helm-chart-releases.yml (vendored, 2 changes)
@@ -33,7 +33,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update

- name: Build chart dependencies

@@ -45,6 +45,9 @@ env:
# TODO: debug why this is failing and enable
CODE_INTERPRETER_BASE_URL: http://localhost:8000

# OpenSearch
OPENSEARCH_ADMIN_PASSWORD: "StrongPassword123!"

jobs:
discover-test-dirs:
# NOTE: Github-hosted runners have about 20s faster queue times and are preferred here.
@@ -115,10 +118,9 @@ jobs:
- name: Create .env file for Docker Compose
run: |
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore,opensearch-enabled
COMPOSE_PROFILES=s3-filestore
CODE_INTERPRETER_BETA_ENABLED=true
DISABLE_TELEMETRY=true
OPENSEARCH_FOR_ONYX_ENABLED=true
EOF

- name: Set up Standard Dependencies
@@ -127,6 +129,7 @@ jobs:
docker compose \
-f docker-compose.yml \
-f docker-compose.dev.yml \
-f docker-compose.opensearch.yml \
up -d \
minio \
relational_db \
.github/workflows/pr-helm-chart-testing.yml (vendored, 2 changes)
@@ -91,7 +91,7 @@ jobs:
helm repo add cloudnative-pg https://cloudnative-pg.github.io/charts
helm repo add ot-container-kit https://ot-container-kit.github.io/helm-charts
helm repo add minio https://charts.min.io/
helm repo add code-interpreter https://onyx-dot-app.github.io/python-sandbox/
helm repo add code-interpreter https://onyx-dot-app.github.io/code-interpreter/
helm repo update

- name: Install Redis operator
@@ -1,28 +0,0 @@
|
||||
"""add scim_username to scim_user_mapping
|
||||
|
||||
Revision ID: 0bb4558f35df
|
||||
Revises: 631fd2504136
|
||||
Create Date: 2026-02-20 10:45:30.340188
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "0bb4558f35df"
|
||||
down_revision = "631fd2504136"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"scim_user_mapping",
|
||||
sa.Column("scim_username", sa.String(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("scim_user_mapping", "scim_username")
|
||||
@@ -1,32 +0,0 @@
"""add approx_chunk_count_in_vespa to opensearch tenant migration

Revision ID: 631fd2504136
Revises: c7f2e1b4a9d3
Create Date: 2026-02-18 21:07:52.831215

"""

from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "631fd2504136"
down_revision = "c7f2e1b4a9d3"
branch_labels = None
depends_on = None

def upgrade() -> None:
op.add_column(
"opensearch_tenant_migration_record",
sa.Column(
"approx_chunk_count_in_vespa",
sa.Integer(),
nullable=True,
),
)

def downgrade() -> None:
op.drop_column("opensearch_tenant_migration_record", "approx_chunk_count_in_vespa")
@@ -1,31 +0,0 @@
"""add sharing_scope to build_session

Revision ID: c7f2e1b4a9d3
Revises: 19c0ccb01687
Create Date: 2026-02-17 12:00:00.000000

"""

from alembic import op
import sqlalchemy as sa

revision = "c7f2e1b4a9d3"
down_revision = "19c0ccb01687"
branch_labels = None
depends_on = None

def upgrade() -> None:
op.add_column(
"build_session",
sa.Column(
"sharing_scope",
sa.String(),
nullable=False,
server_default="private",
),
)

def downgrade() -> None:
op.drop_column("build_session", "sharing_scope")
@@ -9,7 +9,6 @@ from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import selectinload
from sqlalchemy.orm import Session

from ee.onyx.server.user_group.models import SetCuratorRequest
@@ -19,15 +18,11 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
from onyx.db.models import Document
from onyx.db.models import DocumentByConnectorCredentialPair
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
from onyx.db.models import User
@@ -200,60 +195,8 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
return db_session.scalar(stmt)

def _add_user_group_snapshot_eager_loads(
stmt: Select,
) -> Select:
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
return stmt.options(
selectinload(UserGroup.users),
selectinload(UserGroup.user_group_relationships),
selectinload(UserGroup.cc_pair_relationships)
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
.options(
selectinload(ConnectorCredentialPair.connector),
selectinload(ConnectorCredentialPair.credential).selectinload(
Credential.user
),
),
selectinload(UserGroup.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(UserGroup.personas).options(
selectinload(Persona.tools),
selectinload(Persona.hierarchy_nodes),
selectinload(Persona.attached_documents).selectinload(
Document.parent_hierarchy_node
),
selectinload(Persona.labels),
selectinload(Persona.document_sets).options(
selectinload(DocumentSet.connector_credential_pairs).selectinload(
ConnectorCredentialPair.connector
),
selectinload(DocumentSet.users),
selectinload(DocumentSet.groups),
selectinload(DocumentSet.federated_connectors).selectinload(
FederatedConnector__DocumentSet.federated_connector
),
),
selectinload(Persona.user),
selectinload(Persona.user_files),
selectinload(Persona.users),
selectinload(Persona.groups),
),
)

def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
db_session: Session, only_up_to_date: bool = True
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -266,8 +209,6 @@ def fetch_user_groups(
db_session (Session): The SQLAlchemy session used to query the database.
only_up_to_date (bool, optional): Flag to determine whether to filter the results
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.

Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -275,16 +216,11 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()

def fetch_user_groups_for_user(
db_session: Session,
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
db_session: Session, user_id: UUID, only_curator_groups: bool = False
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -294,9 +230,7 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()

def construct_document_id_select_by_usergroup(
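The new `eager_load_for_snapshot` flag above attaches `selectinload` options so that `UserGroup.from_model` can build its snapshot without firing per-relationship lazy loads. A minimal, self-contained sketch of the same pattern follows; the `Group`/`Member` models and the SQLite engine are illustrative stand-ins, not the Onyx schema.

```python
# Sketch of opt-in eager loading for snapshot creation (hypothetical models).
from collections.abc import Sequence

from sqlalchemy import ForeignKey, Select, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase, Mapped, Session, mapped_column, relationship, selectinload,
)


class Base(DeclarativeBase):
    pass


class Group(Base):
    __tablename__ = "group"
    id: Mapped[int] = mapped_column(primary_key=True)
    members: Mapped[list["Member"]] = relationship(back_populates="group")


class Member(Base):
    __tablename__ = "member"
    id: Mapped[int] = mapped_column(primary_key=True)
    group_id: Mapped[int] = mapped_column(ForeignKey("group.id"))
    group: Mapped[Group] = relationship(back_populates="members")


def _add_snapshot_eager_loads(stmt: Select) -> Select:
    # Load group.members with one extra SELECT instead of one query per group.
    return stmt.options(selectinload(Group.members))


def fetch_groups(db: Session, eager_load_for_snapshot: bool = False) -> Sequence[Group]:
    stmt = select(Group)
    if eager_load_for_snapshot:
        stmt = _add_snapshot_eager_loads(stmt)
    return db.scalars(stmt).unique().all()


if __name__ == "__main__":
    engine = create_engine("sqlite://")
    Base.metadata.create_all(engine)
    with Session(engine) as db:
        db.add(Group(members=[Member(), Member()]))
        db.commit()
        print([len(g.members) for g in fetch_groups(db, eager_load_for_snapshot=True)])
```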
@@ -69,7 +69,7 @@ def _graph_api_get(
continue
resp.raise_for_status()
return resp.json()
except (_requests.ConnectionError, _requests.Timeout, _requests.HTTPError):
except (_requests.ConnectionError, _requests.Timeout):
if attempt < GRAPH_API_MAX_RETRIES:
wait = min(2**attempt, 60)
logger.warning(
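For context on the `_graph_api_get` hunk above, here is a hedged, standalone sketch of the bounded retry-with-exponential-backoff shape it uses, in the variant that retries only transport errors and lets HTTP errors from `raise_for_status()` propagate. `MAX_RETRIES` and the request timeout are illustrative values, not Onyx constants.

```python
# Bounded retry with exponential backoff for transient transport errors only.
import time

import requests

MAX_RETRIES = 5  # illustrative, compare GRAPH_API_MAX_RETRIES


def get_json_with_retries(url: str, headers: dict[str, str] | None = None) -> dict:
    for attempt in range(1, MAX_RETRIES + 1):
        try:
            resp = requests.get(url, headers=headers, timeout=30)
            resp.raise_for_status()  # HTTP errors propagate immediately
            return resp.json()
        except (requests.ConnectionError, requests.Timeout):
            if attempt < MAX_RETRIES:
                wait = min(2**attempt, 60)  # 2s, 4s, 8s, ... capped at 60s
                time.sleep(wait)
                continue
            raise
    raise RuntimeError("unreachable")  # keeps type checkers happy
```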
@@ -37,15 +37,12 @@ def list_user_groups(
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
)
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@@ -53,8 +53,7 @@ class UserGroup(BaseModel):
id=cc_pair_relationship.cc_pair.id,
name=cc_pair_relationship.cc_pair.name,
connector=ConnectorSnapshot.from_connector_db_model(
cc_pair_relationship.cc_pair.connector,
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
cc_pair_relationship.cc_pair.connector
),
credential=CredentialSnapshot.from_credential_db_model(
cc_pair_relationship.cc_pair.credential
@@ -121,7 +121,6 @@ from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.settings.store import load_settings
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import mt_cloud_telemetry
@@ -138,8 +137,6 @@ from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

REGISTER_INVITE_ONLY_CODE = "REGISTER_INVITE_ONLY"

def is_user_admin(user: User) -> bool:
return user.role == UserRole.ADMIN
@@ -211,34 +208,22 @@ def anonymous_user_enabled(*, tenant_id: str | None = None) -> bool:
return int(value.decode("utf-8")) == 1

def workspace_invite_only_enabled() -> bool:
settings = load_settings()
return settings.invite_only_enabled

def verify_email_is_invited(email: str) -> None:
if AUTH_TYPE in {AuthType.SAML, AuthType.OIDC}:
# SSO providers manage membership; allow JIT provisioning regardless of invites
return

if not workspace_invite_only_enabled():
whitelist = get_invited_users()
if not whitelist:
return

whitelist = get_invited_users()

if not email:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email must be specified"},
)
raise PermissionError("Email must be specified")

try:
email_info = validate_email(email, check_deliverability=False)
except EmailUndeliverableError:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail={"reason": "Email is not valid"},
)
raise PermissionError("Email is not valid")

for email_whitelist in whitelist:
try:
@@ -255,13 +240,7 @@ def verify_email_is_invited(email: str) -> None:
if email_info.normalized.lower() == email_info_whitelist.normalized.lower():
return

raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail={
"code": REGISTER_INVITE_ONLY_CODE,
"reason": "This workspace is invite-only. Please ask your admin to invite you.",
},
)
raise PermissionError("User not on allowed user whitelist")

def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
@@ -1671,10 +1650,7 @@ def get_oauth_router(
if redirect_url is not None:
authorize_redirect_url = redirect_url
else:
# Use WEB_DOMAIN instead of request.url_for() to prevent host
# header poisoning — request.url_for() trusts the Host header.
callback_path = request.app.url_path_for(callback_route_name)
authorize_redirect_url = f"{WEB_DOMAIN}{callback_path}"
authorize_redirect_url = str(request.url_for(callback_route_name))

next_url = request.query_params.get("next", "/")
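The OAuth hunk above contrasts `request.url_for()` with a redirect URL assembled from a configured `WEB_DOMAIN`; because `url_for()` echoes the incoming Host header, the configured-domain variant avoids host-header poisoning. A small FastAPI sketch of that construction follows; the route names and the `WEB_DOMAIN` default are assumptions for illustration only.

```python
# Build an absolute OAuth callback URL from a trusted configured domain.
import os

from fastapi import FastAPI, Request

WEB_DOMAIN = os.environ.get("WEB_DOMAIN", "http://localhost:3000")  # assumed env var

app = FastAPI()


@app.get("/auth/oauth/callback", name="oauth_callback")
async def oauth_callback() -> dict[str, str]:
    return {"status": "ok"}


@app.get("/auth/oauth/authorize")
async def authorize(request: Request) -> dict[str, str]:
    # url_path_for() yields only the path; joining it to the trusted domain
    # means the client-controlled Host header never reaches the redirect URL.
    callback_path = request.app.url_path_for("oauth_callback")
    return {"redirect_uri": f"{WEB_DOMAIN}{callback_path}"}
```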
@@ -0,0 +1,10 @@
"""Celery tasks for hierarchy fetching."""

from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
check_for_hierarchy_fetching,
)
from onyx.background.celery.tasks.hierarchyfetching.tasks import ( # noqa: F401
connector_hierarchy_fetching_task,
)

__all__ = ["check_for_hierarchy_fetching", "connector_hierarchy_fetching_task"]
@@ -41,14 +41,3 @@ assert (
CHECK_FOR_DOCUMENTS_TASK_LOCK_BLOCKING_TIMEOUT_S = 30 # 30 seconds.

TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE = 15

# WARNING: Do not change these values without knowing what changes also need to
# be made to OpenSearchTenantMigrationRecord.
GET_VESPA_CHUNKS_PAGE_SIZE = 500
GET_VESPA_CHUNKS_SLICE_COUNT = 4

# String used to indicate in the vespa_visit_continuation_token mapping that the
# slice has finished and there is nothing left to visit.
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN = (
"FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
)
@@ -8,12 +8,6 @@ from celery import Task
from redis.lock import Lock as RedisLock

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.tasks.opensearch_migration.constants import (
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
GET_VESPA_CHUNKS_PAGE_SIZE,
)
from onyx.background.celery.tasks.opensearch_migration.constants import (
MIGRATION_TASK_LOCK_BLOCKING_TIMEOUT_S,
)
@@ -53,13 +47,7 @@ from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id

def is_continuation_token_done_for_all_slices(
continuation_token_map: dict[int, str | None],
) -> bool:
return all(
continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
for continuation_token in continuation_token_map.values()
)
GET_VESPA_CHUNKS_PAGE_SIZE = 1000

# shared_task allows this task to be shared across celery app instances.
@@ -88,15 +76,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(

Uses Vespa's Visit API to iterate through ALL chunks in bulk (not
per-document), transform them, and index them into OpenSearch. Progress is
tracked via a continuation token map stored in the
tracked via a continuation token stored in the
OpenSearchTenantMigrationRecord.

The first time we see no continuation token map and non-zero chunks
migrated, we consider the migration complete and all subsequent invocations
are no-ops.

We divide the index into GET_VESPA_CHUNKS_SLICE_COUNT independent slices
where progress is tracked for each slice.
The first time we see no continuation token and non-zero chunks migrated, we
consider the migration complete and all subsequent invocations are no-ops.

Returns:
None if OpenSearch migration is not enabled, or if the lock could not be
@@ -169,28 +153,15 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)

approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
task_logger.debug(
f"Took {time.monotonic() - get_chunk_count_start_time:.3f} seconds to attempt to get "
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)

while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
continuation_token,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
if continuation_token is None and total_chunks_migrated > 0:
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. "
f"Total chunks migrated: {total_chunks_migrated}."
@@ -199,19 +170,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
f"Continuation token: {continuation_token}"
)

get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
raw_vespa_chunks, next_continuation_token = (
vespa_document_index.get_all_raw_document_chunks_paginated(
continuation_token_map=continuation_token_map,
continuation_token=continuation_token,
page_size=GET_VESPA_CHUNKS_PAGE_SIZE,
)
)
task_logger.debug(
f"Read {len(raw_vespa_chunks)} chunks from Vespa in {time.monotonic() - get_vespa_chunks_start_time:.3f} "
f"seconds. Next continuation token map: {next_continuation_token_map}"
f"seconds. Next continuation token: {next_continuation_token}"
)

opensearch_document_chunks, errored_chunks = (
@@ -241,11 +212,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
continuation_token=next_continuation_token,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)

if next_continuation_token is None and len(raw_vespa_chunks) == 0:
task_logger.info("Vespa reported no more chunks to migrate.")
break
except Exception:
traceback.print_exc()
task_logger.exception("Error in the OpenSearch migration task.")
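The task docstring above describes splitting the Vespa index into `GET_VESPA_CHUNKS_SLICE_COUNT` independent slices and tracking a continuation token per slice, with a sentinel marking slices that are fully visited. The sketch below shows that bookkeeping in isolation; `fetch_page` is a placeholder for the Visit API call and the constants are illustrative, not the real Onyx values.

```python
# Per-slice continuation-token bookkeeping for a bulk migration pass.
FINISHED = "FINISHED_VISITING_SLICE_CONTINUATION_TOKEN"
SLICE_COUNT = 4  # illustrative, see GET_VESPA_CHUNKS_SLICE_COUNT


def is_done_for_all_slices(token_map: dict[int, str | None]) -> bool:
    return all(token == FINISHED for token in token_map.values())


def fetch_page(slice_id: int, token: str | None) -> tuple[list[dict], str | None]:
    """Placeholder for one Visit-API page: returns (chunks, next_token).

    None as next_token means this slice has been fully visited.
    """
    return [], None


def migrate_one_pass(token_map: dict[int, str | None]) -> tuple[int, dict[int, str | None]]:
    """Advance every unfinished slice by one page and return the new state."""
    migrated = 0
    next_map: dict[int, str | None] = {}
    for slice_id, token in token_map.items():
        if token == FINISHED:
            next_map[slice_id] = FINISHED
            continue
        chunks, next_token = fetch_page(slice_id, token)
        migrated += len(chunks)  # a real pass would index `chunks` here
        next_map[slice_id] = next_token if next_token is not None else FINISHED
    return migrated, next_map


state: dict[int, str | None] = {slice_id: None for slice_id in range(SLICE_COUNT)}
while not is_done_for_all_slices(state):
    _, state = migrate_one_pass(state)  # persist `state` between invocations
```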
@@ -37,35 +37,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger(__name__)

FIELDS_NEEDED_FOR_TRANSFORMATION: list[str] = [
DOCUMENT_ID,
CHUNK_ID,
TITLE,
TITLE_EMBEDDING,
CONTENT,
EMBEDDINGS,
SOURCE_TYPE,
METADATA_LIST,
DOC_UPDATED_AT,
HIDDEN,
BOOST,
SEMANTIC_IDENTIFIER,
IMAGE_FILE_NAME,
SOURCE_LINKS,
BLURB,
DOC_SUMMARY,
CHUNK_CONTEXT,
METADATA_SUFFIX,
DOCUMENT_SETS,
USER_PROJECT,
PRIMARY_OWNERS,
SECONDARY_OWNERS,
ACCESS_CONTROL_LIST,
]
if MULTI_TENANT:
FIELDS_NEEDED_FOR_TRANSFORMATION.append(TENANT_ID)

def _extract_content_vector(embeddings: Any) -> list[float]:
"""Extracts the full chunk embedding vector from Vespa's embeddings tensor.
@@ -0,0 +1,8 @@
"""Celery tasks for connector pruning."""

from onyx.background.celery.tasks.pruning.tasks import check_for_pruning # noqa: F401
from onyx.background.celery.tasks.pruning.tasks import ( # noqa: F401
connector_pruning_generator_task,
)

__all__ = ["check_for_pruning", "connector_pruning_generator_task"]
@@ -13,7 +13,6 @@ from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.tasks.shared.RetryDocumentIndex import RetryDocumentIndex
from onyx.configs.app_configs import DISABLE_VECTOR_DB
@@ -22,14 +21,12 @@ from onyx.configs.app_configs import VESPA_CLOUD_CERT_PATH
from onyx.configs.app_configs import VESPA_CLOUD_KEY_PATH
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
from onyx.configs.constants import CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import USER_FILE_PROCESSING_MAX_QUEUE_DEPTH
from onyx.connectors.file.connector import LocalFileConnector
from onyx.connectors.models import Document
from onyx.connectors.models import HierarchyNode
@@ -60,17 +57,6 @@ def _user_file_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROCESSING_LOCK_PREFIX}:{user_file_id}"

def _user_file_queued_key(user_file_id: str | UUID) -> str:
"""Key that exists while a process_single_user_file task is sitting in the queue.

The beat generator sets this with a TTL equal to CELERY_USER_FILE_PROCESSING_TASK_EXPIRES
before enqueuing and the worker deletes it as its first action. This prevents
the beat from adding duplicate tasks for files that already have a live task
in flight.
"""
return f"{OnyxRedisLocks.USER_FILE_QUEUED_PREFIX}:{user_file_id}"

def _user_file_project_sync_lock_key(user_file_id: str | UUID) -> str:
return f"{OnyxRedisLocks.USER_FILE_PROJECT_SYNC_LOCK_PREFIX}:{user_file_id}"

@@ -134,24 +120,7 @@ def _get_document_chunk_count(
def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
"""Scan for user files with PROCESSING status and enqueue per-file tasks.

Three mechanisms prevent queue runaway:

1. **Queue depth backpressure** – if the broker queue already has more than
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH items we skip this beat cycle
entirely. Workers are clearly behind; adding more tasks would only make
the backlog worse.

2. **Per-file queued guard** – before enqueuing a task we set a short-lived
Redis key (TTL = CELERY_USER_FILE_PROCESSING_TASK_EXPIRES). If that key
already exists the file already has a live task in the queue, so we skip
it. The worker deletes the key the moment it picks up the task so the
next beat cycle can re-enqueue if the file is still PROCESSING.

3. **Task expiry** – every enqueued task carries an `expires` value equal to
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES. If a task is still sitting in
the queue after that deadline, Celery discards it without touching the DB.
This is a belt-and-suspenders defence: even if the guard key is lost (e.g.
Redis restart), stale tasks evict themselves rather than piling up forever.
Uses direct Redis locks to avoid overlapping runs.
"""
task_logger.info("check_user_file_processing - Starting")

@@ -166,21 +135,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
return None

enqueued = 0
skipped_guard = 0
try:
# --- Protection 1: queue depth backpressure ---
r_celery = self.app.broker_connection().channel().client # type: ignore
queue_len = celery_get_queue_length(
OnyxCeleryQueues.USER_FILE_PROCESSING, r_celery
)
if queue_len > USER_FILE_PROCESSING_MAX_QUEUE_DEPTH:
task_logger.warning(
f"check_user_file_processing - Queue depth {queue_len} exceeds "
f"{USER_FILE_PROCESSING_MAX_QUEUE_DEPTH}, skipping enqueue for "
f"tenant={tenant_id}"
)
return None

with get_session_with_current_tenant() as db_session:
user_file_ids = (
db_session.execute(
@@ -193,35 +148,12 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
)

for user_file_id in user_file_ids:
# --- Protection 2: per-file queued guard ---
queued_key = _user_file_queued_key(user_file_id)
guard_set = redis_client.set(
queued_key,
1,
ex=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
nx=True,
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={"user_file_id": str(user_file_id), "tenant_id": tenant_id},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
)
if not guard_set:
skipped_guard += 1
continue

# --- Protection 3: task expiry ---
# If task submission fails, clear the guard immediately so the
# next beat cycle can retry enqueuing this file.
try:
self.app.send_task(
OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
kwargs={
"user_file_id": str(user_file_id),
"tenant_id": tenant_id,
},
queue=OnyxCeleryQueues.USER_FILE_PROCESSING,
priority=OnyxCeleryPriority.HIGH,
expires=CELERY_USER_FILE_PROCESSING_TASK_EXPIRES,
)
except Exception:
redis_client.delete(queued_key)
raise
enqueued += 1

finally:
@@ -229,8 +161,7 @@ def check_user_file_processing(self: Task, *, tenant_id: str) -> None:
lock.release()

task_logger.info(
f"check_user_file_processing - Enqueued {enqueued} skipped_guard={skipped_guard} "
f"tasks for tenant={tenant_id}"
f"check_user_file_processing - Enqueued {enqueued} tasks for tenant={tenant_id}"
)
return None

@@ -373,12 +304,6 @@ def process_single_user_file(
start = time.monotonic()

redis_client = get_redis_client(tenant_id=tenant_id)

# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))

file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
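The docstring above lists three protections against queue runaway; the second one, the per-file queued guard, hinges on an atomic Redis `SET` with `NX` and `EX`. Below is a standalone sketch of that guard; the key prefix, TTL, and enqueue callable are stand-ins for the real Onyx names, not the actual constants.

```python
# Per-file "queued guard": at most one live task per file between beat cycles.
from collections.abc import Callable

from redis import Redis

TASK_EXPIRES_S = 60  # compare CELERY_USER_FILE_PROCESSING_TASK_EXPIRES


def queued_key(user_file_id: str) -> str:
    return f"user_file_queued:{user_file_id}"  # illustrative prefix


def maybe_enqueue(redis_client: Redis, user_file_id: str, enqueue: Callable[[str], None]) -> bool:
    """Enqueue at most one live task per file; returns True if a task was sent."""
    # SET NX EX is atomic: only the first caller wins until the key expires
    # or the worker deletes it when it picks the task up.
    if not redis_client.set(queued_key(user_file_id), 1, ex=TASK_EXPIRES_S, nx=True):
        return False  # a task for this file is already queued or running
    try:
        enqueue(user_file_id)
    except Exception:
        # If submission fails, clear the guard so the next beat can retry.
        redis_client.delete(queued_key(user_file_id))
        raise
    return True


def on_worker_start(redis_client: Redis, user_file_id: str) -> None:
    # Worker's first action: drop the guard so a still-PROCESSING file can be
    # re-enqueued by a later beat cycle.
    redis_client.delete(queued_key(user_file_id))
```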
@@ -1,4 +1,3 @@
import json
import re
from collections.abc import Callable
from typing import cast
@@ -46,7 +45,6 @@ from onyx.utils.timing import log_function_time

logger = setup_logger()
IMAGE_GENERATION_TOOL_NAME = "generate_image"

def create_chat_session_from_request(
@@ -424,40 +422,6 @@ def convert_chat_history_basic(
return list(reversed(trimmed_reversed))

def _build_tool_call_response_history_message(
tool_name: str,
generated_images: list[dict] | None,
tool_call_response: str | None,
) -> str:
if tool_name != IMAGE_GENERATION_TOOL_NAME:
return TOOL_CALL_RESPONSE_CROSS_MESSAGE

if generated_images:
llm_image_context: list[dict[str, str]] = []
for image in generated_images:
file_id = image.get("file_id")
revised_prompt = image.get("revised_prompt")
if not isinstance(file_id, str):
continue

llm_image_context.append(
{
"file_id": file_id,
"revised_prompt": (
revised_prompt if isinstance(revised_prompt, str) else ""
),
}
)

if llm_image_context:
return json.dumps(llm_image_context)

if tool_call_response:
return tool_call_response

return TOOL_CALL_RESPONSE_CROSS_MESSAGE

def convert_chat_history(
chat_history: list[ChatMessage],
files: list[ChatLoadedFile],
@@ -618,24 +582,10 @@ def convert_chat_history(

# Add TOOL_CALL_RESPONSE messages for each tool call in this turn
for tool_call in turn_tool_calls:
tool_name = tool_id_to_name_map.get(
tool_call.tool_id, "unknown"
)
tool_response_message = (
_build_tool_call_response_history_message(
tool_name=tool_name,
generated_images=tool_call.generated_images,
tool_call_response=tool_call.tool_call_response,
)
)
simple_messages.append(
ChatMessageSimple(
message=tool_response_message,
token_count=(
token_counter(tool_response_message)
if tool_name == IMAGE_GENERATION_TOOL_NAME
else 20
),
message=TOOL_CALL_RESPONSE_CROSS_MESSAGE,
token_count=20, # Tiny overestimate
message_type=MessageType.TOOL_CALL_RESPONSE,
tool_call_id=tool_call.tool_call_id,
image_files=None,
@@ -57,7 +57,6 @@ from onyx.tools.tool_implementations.images.models import (
FinalImageGenerationResponse,
)
from onyx.tools.tool_implementations.memory.models import MemoryToolResponse
from onyx.tools.tool_implementations.python.python_tool import PythonTool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.tool_implementations.web_search.utils import extract_url_snippet_map
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
@@ -652,7 +651,6 @@ def run_llm_loop(
ran_image_gen: bool = False
just_ran_web_search: bool = False
has_called_search_tool: bool = False
code_interpreter_file_generated: bool = False
fallback_extraction_attempted: bool = False
citation_mapping: dict[int, str] = {} # Maps citation_num -> document_id/URL

@@ -763,7 +761,6 @@ def run_llm_loop(
),
include_citation_reminder=should_cite_documents
or always_cite_documents,
include_file_reminder=code_interpreter_file_generated,
is_last_cycle=out_of_cycles,
)

@@ -903,18 +900,6 @@ def run_llm_loop(
if tool_call.tool_name == SearchTool.NAME:
has_called_search_tool = True

# Track if code interpreter generated files with download links
if (
tool_call.tool_name == PythonTool.NAME
and not code_interpreter_file_generated
):
try:
parsed = json.loads(tool_response.llm_facing_response)
if parsed.get("generated_files"):
code_interpreter_file_generated = True
except (json.JSONDecodeError, AttributeError):
pass

# Build a mapping of tool names to tool objects for getting tool_id
tools_by_name = {tool.name: tool for tool in final_tools}
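The hunk above infers whether the code-interpreter tool produced downloadable files by parsing its LLM-facing response as JSON and checking for a `generated_files` key. A tiny hedged sketch of that check, with the JSON shape assumed from the hunk rather than from a documented API:

```python
# Detect "the python tool emitted files" from a tool's JSON response string.
import json


def response_generated_files(llm_facing_response: str) -> bool:
    try:
        parsed = json.loads(llm_facing_response)
        return bool(parsed.get("generated_files"))
    except (json.JSONDecodeError, AttributeError):
        # Non-JSON or non-dict responses simply mean "no files generated".
        return False


assert response_generated_files('{"generated_files": ["report.csv"]}') is True
assert response_generated_files("plain text output") is False
```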
@@ -10,7 +10,6 @@ from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
from onyx.prompts.chat_prompts import CITATION_REMINDER
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import FILE_REMINDER
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.prompt_utils import get_company_context
@@ -126,7 +125,6 @@ def calculate_reserved_tokens(
def build_reminder_message(
reminder_text: str | None,
include_citation_reminder: bool,
include_file_reminder: bool,
is_last_cycle: bool,
) -> str | None:
reminder = reminder_text.strip() if reminder_text else ""
@@ -134,8 +132,6 @@ def build_reminder_message(
reminder += "\n\n" + LAST_CYCLE_CITATION_REMINDER
if include_citation_reminder:
reminder += "\n\n" + CITATION_REMINDER
if include_file_reminder:
reminder += "\n\n" + FILE_REMINDER
reminder = reminder.strip()
return reminder if reminder else None

@@ -190,7 +186,7 @@ def _build_user_information_section(
if not sections:
return ""

return USER_INFORMATION_HEADER + "\n".join(sections)
return USER_INFORMATION_HEADER + "".join(sections)

def build_system_prompt(
@@ -228,21 +224,23 @@ def build_system_prompt(
system_prompt += REQUIRE_CITATION_GUIDANCE

if include_all_guidance:
tool_sections = [
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
INTERNAL_SEARCH_GUIDANCE,
WEB_SEARCH_GUIDANCE.format(
system_prompt += (
TOOL_SECTION_HEADER
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
+ INTERNAL_SEARCH_GUIDANCE
+ WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
),
OPEN_URLS_GUIDANCE,
PYTHON_TOOL_GUIDANCE,
GENERATE_IMAGE_GUIDANCE,
MEMORY_GUIDANCE,
]
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
)
+ OPEN_URLS_GUIDANCE
+ PYTHON_TOOL_GUIDANCE
+ GENERATE_IMAGE_GUIDANCE
+ MEMORY_GUIDANCE
)
return system_prompt

if tools:
system_prompt += TOOL_SECTION_HEADER

has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
@@ -252,14 +250,12 @@ def build_system_prompt(
)
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)

tool_guidance_sections: list[str] = []

if has_web_search or has_internal_search or include_all_guidance:
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE

# These are not included at the Tool level because the ordering may matter.
if has_internal_search or include_all_guidance:
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
system_prompt += INTERNAL_SEARCH_GUIDANCE

if has_web_search or include_all_guidance:
site_disabled_guidance = ""
@@ -269,23 +265,20 @@ def build_system_prompt(
)
if web_search_tool and not web_search_tool.supports_site_filter:
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
tool_guidance_sections.append(
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
system_prompt += WEB_SEARCH_GUIDANCE.format(
site_colon_disabled=site_disabled_guidance
)

if has_open_urls or include_all_guidance:
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
system_prompt += OPEN_URLS_GUIDANCE

if has_python or include_all_guidance:
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
system_prompt += PYTHON_TOOL_GUIDANCE

if has_generate_image or include_all_guidance:
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
system_prompt += GENERATE_IMAGE_GUIDANCE

if has_memory or include_all_guidance:
tool_guidance_sections.append(MEMORY_GUIDANCE)

if tool_guidance_sections:
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
system_prompt += MEMORY_GUIDANCE

return system_prompt
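Both variants in the hunk above assemble tool guidance into the system prompt; one of them collects the applicable sections into a list and joins them under a single header only when at least one applies. A reduced sketch of that list-and-join variant, with placeholder guidance strings rather than the real prompt constants:

```python
# Conditional assembly of a tool-guidance section for a system prompt.
TOOL_SECTION_HEADER = "\n\n# Tools\n"  # placeholder text
SEARCH_GUIDANCE = "Use internal search for workspace questions."
WEB_GUIDANCE = "Use web search for public, time-sensitive questions."
PYTHON_GUIDANCE = "Use the python tool for calculations and file generation."


def build_tool_section(
    has_internal_search: bool,
    has_web_search: bool,
    has_python: bool,
    include_all_guidance: bool = False,
) -> str:
    sections: list[str] = []
    if has_internal_search or include_all_guidance:
        sections.append(SEARCH_GUIDANCE)
    if has_web_search or include_all_guidance:
        sections.append(WEB_GUIDANCE)
    if has_python or include_all_guidance:
        sections.append(PYTHON_GUIDANCE)
    if not sections:
        return ""  # no header when no tool applies
    return TOOL_SECTION_HEADER + "\n".join(sections)


print(build_tool_section(has_internal_search=True, has_web_search=False, has_python=True))
```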
@@ -251,9 +251,7 @@ DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S = int(
os.environ.get("DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S") or 50
)
OPENSEARCH_ADMIN_USERNAME = os.environ.get("OPENSEARCH_ADMIN_USERNAME", "admin")
OPENSEARCH_ADMIN_PASSWORD = os.environ.get(
"OPENSEARCH_ADMIN_PASSWORD", "StrongPassword123!"
)
OPENSEARCH_ADMIN_PASSWORD = os.environ.get("OPENSEARCH_ADMIN_PASSWORD", "")
USING_AWS_MANAGED_OPENSEARCH = (
os.environ.get("USING_AWS_MANAGED_OPENSEARCH", "").lower() == "true"
)
@@ -265,18 +263,6 @@ OPENSEARCH_PROFILING_DISABLED = (
os.environ.get("OPENSEARCH_PROFILING_DISABLED", "").lower() == "true"
)

# When enabled, OpenSearch returns detailed score breakdowns for each hit.
# Useful for debugging and tuning search relevance. Has ~10-30% performance overhead according to documentation.
# Seems for Hybrid Search in practice, the impact is actually more like 1000x slower.
OPENSEARCH_EXPLAIN_ENABLED = (
os.environ.get("OPENSEARCH_EXPLAIN_ENABLED", "").lower() == "true"
)

# Analyzer used for full-text fields (title, content). Use OpenSearch built-in analyzer
# names (e.g. "english", "standard", "german"). Affects stemming and tokenization;
# existing indices need reindexing after a change.
OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "english"

# This is the "base" config for now, the idea is that at least for our dev
# environments we always want to be dual indexing into both OpenSearch and Vespa
# to stress test the new codepaths. Only enable this if there is some instance
@@ -284,9 +270,6 @@ OPENSEARCH_TEXT_ANALYZER = os.environ.get("OPENSEARCH_TEXT_ANALYZER") or "englis
ENABLE_OPENSEARCH_INDEXING_FOR_ONYX = (
os.environ.get("ENABLE_OPENSEARCH_INDEXING_FOR_ONYX", "").lower() == "true"
)
# NOTE: This effectively does nothing anymore, admins can now toggle whether
# retrieval is through OpenSearch. This value is only used as a final fallback
# in case that doesn't work for whatever reason.
# Given that the "base" config above is true, this enables whether we want to
# retrieve from OpenSearch or Vespa. We want to be able to quickly toggle this
# in the event we see issues with OpenSearch retrieval in our dev environments.
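The settings above repeat the same convention for boolean environment variables (the string "true", compared case-insensitively). A small helper of that shape, with an illustrative variable name, would keep the parsing in one place; this is a sketch, not how the Onyx config module is actually organized.

```python
# Shared helper for the `"true"`-string boolean env-var convention.
import os


def env_flag(name: str, default: bool = False) -> bool:
    raw = os.environ.get(name)
    if raw is None:
        return default
    return raw.lower() == "true"


OPENSEARCH_EXPLAIN_ENABLED = env_flag("OPENSEARCH_EXPLAIN_ENABLED")
```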
@@ -157,17 +157,6 @@ CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT = 300 # 5 min

CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT = 30 * 60 # 30 minutes (in seconds)

# How long a queued user-file task is valid before workers discard it.
# Should be longer than the beat interval (20 s) but short enough to prevent
# indefinite queue growth. Workers drop tasks older than this without touching
# the DB, so a shorter value = faster drain of stale duplicates.
CELERY_USER_FILE_PROCESSING_TASK_EXPIRES = 60 # 1 minute (in seconds)

# Maximum number of tasks allowed in the user-file-processing queue before the
# beat generator stops adding more. Prevents unbounded queue growth when workers
# fall behind.
USER_FILE_PROCESSING_MAX_QUEUE_DEPTH = 500

CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)

CELERY_SANDBOX_FILE_SYNC_LOCK_TIMEOUT = 5 * 60 # 5 minutes (in seconds)
@@ -454,9 +443,6 @@ class OnyxRedisLocks:
# User file processing
USER_FILE_PROCESSING_BEAT_LOCK = "da_lock:check_user_file_processing_beat"
USER_FILE_PROCESSING_LOCK_PREFIX = "da_lock:user_file_processing"
# Short-lived key set when a task is enqueued; cleared when the worker picks it up.
# Prevents the beat from re-enqueuing the same file while a task is already queued.
USER_FILE_QUEUED_PREFIX = "da_lock:user_file_queued"
USER_FILE_PROJECT_SYNC_BEAT_LOCK = "da_lock:check_user_file_project_sync_beat"
USER_FILE_PROJECT_SYNC_LOCK_PREFIX = "da_lock:user_file_project_sync"
USER_FILE_DELETE_BEAT_LOCK = "da_lock:check_user_file_delete_beat"
@@ -71,7 +71,6 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()
SLIM_BATCH_SIZE = 1000
_EPOCH = datetime.fromtimestamp(0, tz=timezone.utc)

SHARED_DOCUMENTS_MAP = {
@@ -244,12 +243,6 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
current_drive_name: str | None = None
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
current_drive_web_url: str | None = None
# Resolved drive ID — avoids re-resolving on checkpoint resume
current_drive_id: str | None = None
# Next delta API page URL for per-page checkpointing within a drive.
# When set, Phase 3b fetches one page at a time so progress is persisted
# between pages. None means BFS path or no active delta traversal.
current_drive_delta_next_link: str | None = None

process_site_pages: bool = False

@@ -1273,8 +1266,7 @@ class SharepointConnector(
"""
base = f"{self.graph_api_base}/drives/{drive_id}"
if folder_path:
encoded_path = quote(folder_path, safe="/")
start_url = f"{base}/root:/{encoded_path}:/children"
start_url = f"{base}/root:/{folder_path}:/children"
else:
start_url = f"{base}/root/children"

@@ -1330,12 +1322,13 @@ class SharepointConnector(

Falls back to full enumeration if the API returns 410 Gone (expired token).
"""
use_timestamp_token = start is not None and start > _EPOCH
EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
use_timestamp_token = start is not None and start > EPOCH

initial_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
if use_timestamp_token:
assert start is not None # mypy
token = quote(start.isoformat(timespec="seconds"))
assert start is not None # purely for mypy
token = quote(start.strftime("%Y-%m-%dT%H:%M:%SZ"))
initial_url += f"?token={token}"

yield from self._iter_delta_pages(
@@ -1368,7 +1361,6 @@ class SharepointConnector(
try:
data = self._graph_api_get_json(page_url, params)
except requests.HTTPError as e:
# 410 means the delta token expired, so we need to fall back to full enumeration
if e.response is not None and e.response.status_code == 410:
if not allow_full_resync:
raise
@@ -1405,91 +1397,10 @@ class SharepointConnector(

yield DriveItemData.from_graph_json(item)

page_url = data.get("@odata.nextLink")
if not page_url:
page_url = data.get("@odata.nextLink") or data.get("@odata.deltaLink")
if "@odata.deltaLink" in data and "@odata.nextLink" not in data:
break

def _build_delta_start_url(
self,
drive_id: str,
start: datetime | None = None,
page_size: int = 200,
) -> str:
"""Build the initial delta API URL with query parameters embedded.

Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
URL so that the returned string is fully self-contained and can be
stored in a checkpoint without needing a separate params dict.
"""
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
params = [f"$top={page_size}"]
if start is not None and start > _EPOCH:
token = quote(start.isoformat(timespec="seconds"))
params.append(f"token={token}")
return f"{base_url}?{'&'.join(params)}"

def _fetch_one_delta_page(
self,
page_url: str,
drive_id: str,
start: datetime | None = None,
end: datetime | None = None,
page_size: int = 200,
) -> tuple[list[DriveItemData], str | None]:
"""Fetch a single page of delta API results.

Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
the delta enumeration is complete (deltaLink with no nextLink).

On 410 Gone (expired token) returns ``([], full_resync_url)`` so
the caller can store the resync URL in the checkpoint and retry on
the next cycle.
"""
try:
data = self._graph_api_get_json(page_url)
except requests.HTTPError as e:
if e.response is not None and e.response.status_code == 410:
logger.warning(
"Delta token expired (410 Gone) for drive '%s'. "
"Will restart with full delta enumeration.",
drive_id,
)
full_url = (
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
f"?$top={page_size}"
)
return [], full_url
raise

items: list[DriveItemData] = []
for item in data.get("value", []):
if "folder" in item or "deleted" in item:
continue
if start is not None or end is not None:
raw_ts = item.get("lastModifiedDateTime")
if raw_ts:
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
if start is not None and mod_dt < start:
continue
if end is not None and mod_dt > end:
continue
items.append(DriveItemData.from_graph_json(item))

next_url = data.get("@odata.nextLink")
if next_url:
return items, next_url
return items, None

@staticmethod
def _clear_drive_checkpoint_state(
checkpoint: "SharepointConnectorCheckpoint",
) -> None:
"""Reset all drive-level fields in the checkpoint."""
checkpoint.current_drive_name = None
checkpoint.current_drive_id = None
checkpoint.current_drive_web_url = None
checkpoint.current_drive_delta_next_link = None

def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
site_descriptors = self.site_descriptors or self.fetch_sites()

@@ -1931,13 +1842,14 @@ class SharepointConnector(
# Return checkpoint to allow persistence after drive initialization
return checkpoint

# Phase 3a: Initialize the next drive for processing
# Phase 3: Process documents from current drive
if (
checkpoint.current_site_descriptor
and checkpoint.cached_drive_names
and len(checkpoint.cached_drive_names) > 0
and checkpoint.current_drive_name is None
):

checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()

start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
@@ -1945,8 +1857,7 @@ class SharepointConnector(
site_descriptor = checkpoint.current_site_descriptor

logger.info(
f"Processing drive '{checkpoint.current_drive_name}' "
f"in site: {site_descriptor.url}"
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
)
logger.debug(f"Time range: {start_dt} to {end_dt}")

@@ -1955,35 +1866,35 @@ class SharepointConnector(
logger.warning("Current drive name is None, skipping")
return checkpoint

driveitems: Iterable[DriveItemData] = iter(())
drive_web_url: str | None = None
try:
logger.info(
f"Fetching drive items for drive name: {current_drive_name}"
)
result = self._resolve_drive(site_descriptor, current_drive_name)
if result is None:
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint

drive_id, drive_web_url = result
checkpoint.current_drive_id = drive_id
checkpoint.current_drive_web_url = drive_web_url
if result is not None:
drive_id, drive_web_url = result
driveitems = self._get_drive_items_for_drive_id(
site_descriptor, drive_id, start_dt, end_dt
)
checkpoint.current_drive_web_url = drive_web_url
except Exception as e:
logger.error(
f"Failed to retrieve items from drive '{current_drive_name}' "
f"in site: {site_descriptor.url}: {e}"
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to access drive '{current_drive_name}' "
f"in site '{site_descriptor.url}': {str(e)}",
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None
return checkpoint

display_drive_name = SHARED_DOCUMENTS_MAP.get(
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
current_drive_name = SHARED_DOCUMENTS_MAP.get(
current_drive_name, current_drive_name
)

@@ -1991,74 +1902,10 @@ class SharepointConnector(
yield from self._yield_drive_hierarchy_node(
site_descriptor.url,
drive_web_url,
display_drive_name,
current_drive_name,
checkpoint,
)

# For non-folder-scoped drives, use delta API with per-page
# checkpointing. Build the initial URL and fall through to 3b.
if not site_descriptor.folder_path:
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
drive_id, start_dt
)
# else: BFS path — delta_next_link stays None;
# Phase 3b will use _iter_drive_items_paged.

# Phase 3b: Process items from the current drive
if (
checkpoint.current_site_descriptor
and checkpoint.current_drive_name is not None
and checkpoint.current_drive_id is not None
):
site_descriptor = checkpoint.current_site_descriptor
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
current_drive_name = SHARED_DOCUMENTS_MAP.get(
checkpoint.current_drive_name, checkpoint.current_drive_name
)
drive_web_url = checkpoint.current_drive_web_url

# --- determine item source ---
driveitems: Iterable[DriveItemData]
has_more_delta_pages = False

if checkpoint.current_drive_delta_next_link:
# Delta path: fetch one page at a time for checkpointing
try:
page_items, next_url = self._fetch_one_delta_page(
page_url=checkpoint.current_drive_delta_next_link,
drive_id=checkpoint.current_drive_id,
start=start_dt,
end=end_dt,
)
except Exception as e:
logger.error(
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {e}"
)
yield _create_entity_failure(
f"{site_descriptor.url}|{current_drive_name}",
f"Failed to fetch delta page for drive "
f"'{current_drive_name}': {str(e)}",
(start_dt, end_dt),
e,
)
self._clear_drive_checkpoint_state(checkpoint)
return checkpoint

driveitems = page_items
has_more_delta_pages = next_url is not None
if next_url:
checkpoint.current_drive_delta_next_link = next_url
else:
# BFS path (folder-scoped): process all items at once
driveitems = self._iter_drive_items_paged(
drive_id=checkpoint.current_drive_id,
folder_path=site_descriptor.folder_path,
start=start_dt,
end=end_dt,
)

item_count = 0
for driveitem in driveitems:
item_count += 1
@@ -2100,6 +1947,8 @@ class SharepointConnector(
if include_permissions:
ctx = self._create_rest_client_context(site_descriptor.url)

# Re-acquire token in case it expired during a long traversal
# MSAL has a cache that returns the same token while still valid.
access_token = self._get_graph_access_token()
doc_or_failure = _convert_driveitem_to_document_with_permissions(
driveitem,
@@ -2135,11 +1984,8 @@ class SharepointConnector(
)

logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")

if has_more_delta_pages:
return checkpoint

self._clear_drive_checkpoint_state(checkpoint)
checkpoint.current_drive_name = None
checkpoint.current_drive_web_url = None

# Phase 4: Progression logic - determine next step
# If we have more drives in current site, continue with current site
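The SharePoint hunks above add per-page delta checkpointing: the next delta page URL is stored on the checkpoint after every page, so a resumed run picks up mid-drive instead of re-enumerating. A reduced sketch of that loop follows; the checkpoint dataclass and `fetch_one_delta_page` are stand-ins for the connector's real types and Graph API call.

```python
# Per-page delta checkpointing: persist the next page URL between cycles.
from dataclasses import dataclass


@dataclass
class DriveCheckpoint:
    delta_next_link: str | None = None  # None means no active delta traversal
    items_seen: int = 0


def fetch_one_delta_page(page_url: str) -> tuple[list[dict], str | None]:
    """Placeholder for a Graph delta call: returns (items, next_page_url).

    None as next_page_url means the enumeration finished (deltaLink only).
    """
    return [], None


def process_one_cycle(checkpoint: DriveCheckpoint, start_url: str) -> DriveCheckpoint:
    page_url = checkpoint.delta_next_link or start_url
    items, next_url = fetch_one_delta_page(page_url)
    checkpoint.items_seen += len(items)  # convert items to documents here
    checkpoint.delta_next_link = next_url  # persisted so the next run resumes here
    return checkpoint
```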
@@ -32,7 +32,6 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
@@ -906,15 +905,13 @@ def convert_slack_score(slack_score: float) -> float:
|
||||
def slack_retrieval(
|
||||
query: ChunkIndexRequest,
|
||||
access_token: str,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
|
||||
entities: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
team_id: str | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB query (no session needed)
|
||||
search_settings: SearchSettings | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Main entry point for Slack federated search with entity filtering.
|
||||
@@ -928,7 +925,7 @@ def slack_retrieval(
|
||||
Args:
|
||||
query: Search query object
|
||||
access_token: User OAuth access token
|
||||
db_session: Database session (optional if search_settings provided)
|
||||
db_session: Database session
|
||||
connector: Federated connector detail (unused, kept for backwards compat)
|
||||
entities: Connector-level config (entity filtering configuration)
|
||||
limit: Maximum number of results
|
||||
@@ -1156,10 +1153,7 @@ def slack_retrieval(
|
||||
|
||||
# chunk index docs into doc aware chunks
|
||||
# a single index doc can get split into multiple chunks
|
||||
if search_settings is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or search_settings must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedder = DefaultIndexingEmbedder.from_db_search_settings(
search_settings=search_settings
)
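One side of this hunk lets callers hand slack_retrieval a pre-fetched SearchSettings object so no database session is needed on the hot path. A hedged sketch of that fallback pattern (the helper name resolve_search_settings is illustrative; get_current_search_settings and SearchSettings match the imports shown above):

from sqlalchemy.orm import Session

from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings

def resolve_search_settings(
    db_session: Session | None,
    search_settings: SearchSettings | None,
) -> SearchSettings:
    # Prefer the pre-fetched object; only touch the DB when nothing was supplied.
    if search_settings is not None:
        return search_settings
    if db_session is None:
        raise ValueError("Either db_session or search_settings must be provided")
    return get_current_search_settings(db_session)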
@@ -18,10 +18,8 @@ from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -43,7 +41,7 @@ def _build_index_filters(
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
query: str | None = None,
|
||||
llm: LLM | None = None,
|
||||
@@ -51,8 +49,6 @@ def _build_index_filters(
|
||||
# Assistant knowledge filters
|
||||
attached_document_ids: list[str] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
# Pre-fetched ACL filters (skips DB query when provided)
|
||||
acl_filters: list[str] | None = None,
|
||||
) -> IndexFilters:
|
||||
if auto_detect_filters and (llm is None or query is None):
|
||||
raise RuntimeError("LLM and query are required for auto detect filters")
|
||||
@@ -107,14 +103,9 @@ def _build_index_filters(
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
user_acl_filters = acl_filters
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or acl_filters must be provided")
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
@@ -261,15 +252,11 @@ def search_pipeline(
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session | None = None,
|
||||
db_session: Session,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
@@ -310,7 +297,6 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
acl_filters=acl_filters,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
@@ -329,8 +315,6 @@ def search_pipeline(
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -14,11 +14,9 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -52,14 +50,9 @@ def combine_retrieval_results(
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_query_embedding(
|
||||
query_request.query,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
query_embedding = get_query_embedding(query_request.query, db_session)
|
||||
|
||||
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
|
||||
|
||||
@@ -85,9 +78,7 @@ def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
db_session: Session,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -97,22 +88,14 @@ def search_chunks(
|
||||
else None
|
||||
)
|
||||
|
||||
# Federated retrieval — use pre-fetched if available, otherwise query DB
|
||||
if prefetched_federated_retrieval_infos is not None:
|
||||
federated_retrieval_infos = prefetched_federated_retrieval_infos
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError(
|
||||
"Either db_session or prefetched_federated_retrieval_infos "
|
||||
"must be provided"
|
||||
)
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -131,10 +114,7 @@ def search_chunks(
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
(_embed_and_search, (query_request, document_index, db_session))
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
|
||||
@@ -64,34 +64,23 @@ def inference_section_from_single_chunk(
|
||||
)
|
||||
|
||||
|
||||
def get_query_embeddings(
|
||||
queries: list[str],
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[Embedding]:
|
||||
if embedding_model is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or embedding_model must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_query_embedding(
|
||||
query: str,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> Embedding:
|
||||
return get_query_embeddings(
|
||||
[query], db_session=db_session, embedding_model=embedding_model
|
||||
)[0]
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
return get_query_embeddings([query], db_session)[0]
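The longer of the two get_query_embedding signatures shown above accepts either a live Session or a pre-built EmbeddingModel, so callers on the query path can skip the settings lookup. A usage sketch under that signature (session and cached_search_settings are assumed objects provided by the caller):

# Session-based call: looks up the current search settings internally.
embedding = get_query_embedding("quarterly revenue report", db_session=session)

# Pre-built model: no DB round trip at query time.
model = EmbeddingModel.from_db_model(
    search_settings=cached_search_settings,
    server_host=MODEL_SERVER_HOST,
    server_port=MODEL_SERVER_PORT,
)
embedding = get_query_embedding("quarterly revenue report", embedding_model=model)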
def convert_inference_sections_to_search_docs(
|
||||
|
||||
@@ -4,7 +4,6 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -55,7 +54,6 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -98,11 +97,6 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -116,15 +116,12 @@ def get_connector_credential_pairs_for_user(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""Get connector credential pairs for a user.
|
||||
|
||||
Args:
|
||||
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
|
||||
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
|
||||
defer_connector_config: If True, skips loading Connector.connector_specific_config
|
||||
to avoid fetching large JSONB blobs when they aren't needed.
"""
if eager_load_user:
|
||||
assert (
|
||||
@@ -133,10 +130,7 @@ def get_connector_credential_pairs_for_user(
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
connector_load = selectinload(ConnectorCredentialPair.connector)
|
||||
if defer_connector_config:
|
||||
connector_load = connector_load.defer(Connector.connector_specific_config)
|
||||
stmt = stmt.options(connector_load)
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
@@ -176,7 +170,6 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
@@ -190,7 +183,6 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc=order_by_desc,
|
||||
source=source,
|
||||
processing_mode=processing_mode,
|
||||
defer_connector_config=defer_connector_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -554,19 +554,10 @@ def fetch_all_document_sets_for_user(
|
||||
stmt = (
|
||||
select(DocumentSetDBModel)
|
||||
.distinct()
|
||||
.options(
|
||||
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSetDBModel.users),
|
||||
selectinload(DocumentSetDBModel.groups),
|
||||
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
)
|
||||
.options(selectinload(DocumentSetDBModel.federated_connectors))
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
||||
return db_session.scalars(stmt).unique().all()
return db_session.scalars(stmt).all()
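The longer variant of this query eagerly loads every nested relationship it is about to serialize, trading a handful of extra SELECTs for avoiding an N+1 query per document set. A hedged sketch of the same chaining on the models named here (db_session and the user filtering step are assumed):

from sqlalchemy import select
from sqlalchemy.orm import selectinload

stmt = (
    select(DocumentSetDBModel)
    .distinct()
    .options(
        # One additional SELECT per relationship level instead of one per row.
        selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
            ConnectorCredentialPair.connector
        ),
        selectinload(DocumentSetDBModel.federated_connectors).selectinload(
            FederatedConnector__DocumentSet.federated_connector
        ),
    )
)
document_sets = db_session.scalars(stmt).unique().all()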
def fetch_documents_for_document_set_paginated(
|
||||
|
||||
@@ -232,12 +232,6 @@ class BuildSessionStatus(str, PyEnum):
|
||||
IDLE = "idle"
|
||||
|
||||
|
||||
class SharingScope(str, PyEnum):
|
||||
PRIVATE = "private"
|
||||
PUBLIC_ORG = "public_org"
|
||||
PUBLIC_GLOBAL = "public_global"
|
||||
|
||||
|
||||
class SandboxStatus(str, PyEnum):
|
||||
PROVISIONING = "provisioning"
|
||||
RUNNING = "running"
|
||||
|
||||
@@ -77,7 +77,6 @@ from onyx.db.enums import (
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
SwitchoverType,
|
||||
SharingScope,
|
||||
)
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import SearchFeedbackType
|
||||
@@ -287,7 +286,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user"
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
)
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
@@ -321,6 +320,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
"Memory",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
order_by="desc(Memory.id)",
|
||||
)
|
||||
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
|
||||
@@ -1040,9 +1040,7 @@ class OpenSearchTenantMigrationRecord(Base):
|
||||
nullable=False,
|
||||
)
|
||||
# Opaque continuation token from Vespa's Visit API.
|
||||
# NULL means "not started".
|
||||
# Otherwise contains a serialized mapping between slice ID and continuation
|
||||
# token for that slice.
|
||||
# NULL means "not started" or "visit completed".
|
||||
vespa_visit_continuation_token: Mapped[str | None] = mapped_column(
|
||||
Text, nullable=True
|
||||
)
|
||||
@@ -1066,9 +1064,6 @@ class OpenSearchTenantMigrationRecord(Base):
|
||||
enable_opensearch_retrieval: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
approx_chunk_count_in_vespa: Mapped[int | None] = mapped_column(
|
||||
Integer, nullable=True
|
||||
)
|
||||
|
||||
|
||||
class KGEntityType(Base):
|
||||
@@ -4717,12 +4712,6 @@ class BuildSession(Base):
|
||||
demo_data_enabled: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, server_default=text("true")
|
||||
)
|
||||
sharing_scope: Mapped[SharingScope] = mapped_column(
|
||||
String,
|
||||
nullable=False,
|
||||
default=SharingScope.PRIVATE,
|
||||
server_default="private",
|
||||
)
|
||||
|
||||
# Relationships
|
||||
user: Mapped[User | None] = relationship("User", foreign_keys=[user_id])
|
||||
@@ -4939,7 +4928,6 @@ class ScimUserMapping(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), unique=True, nullable=False
|
||||
)
|
||||
scim_username: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
|
||||
@@ -4,7 +4,6 @@ This module provides functions to track the progress of migrating documents
|
||||
from Vespa to OpenSearch.
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
@@ -13,9 +12,6 @@ from sqlalchemy import text
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
GET_VESPA_CHUNKS_SLICE_COUNT,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
TOTAL_ALLOWABLE_DOC_MIGRATION_ATTEMPTS_BEFORE_PERMANENT_FAILURE,
|
||||
)
|
||||
@@ -247,37 +243,29 @@ def should_document_migration_be_permanently_failed(
|
||||
|
||||
def get_vespa_visit_state(
|
||||
db_session: Session,
|
||||
) -> tuple[dict[int, str | None], int]:
|
||||
) -> tuple[str | None, int]:
|
||||
"""Gets the current Vespa migration state from the tenant migration record.
|
||||
|
||||
Requires the OpenSearchTenantMigrationRecord to exist.
|
||||
|
||||
Returns:
|
||||
Tuple of (continuation_token_map, total_chunks_migrated).
|
||||
Tuple of (continuation_token, total_chunks_migrated). continuation_token
|
||||
is None if not started or completed.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
if record.vespa_visit_continuation_token is None:
|
||||
continuation_token_map: dict[int, str | None] = {
|
||||
slice_id: None for slice_id in range(GET_VESPA_CHUNKS_SLICE_COUNT)
|
||||
}
|
||||
else:
|
||||
json_loaded_continuation_token_map = json.loads(
|
||||
record.vespa_visit_continuation_token
|
||||
)
|
||||
continuation_token_map = {
|
||||
int(key): value for key, value in json_loaded_continuation_token_map.items()
|
||||
}
|
||||
return continuation_token_map, record.total_chunks_migrated
|
||||
return (
|
||||
record.vespa_visit_continuation_token,
|
||||
record.total_chunks_migrated,
|
||||
)
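The map-based variant above persists one continuation token per Vespa visit slice as a single JSON string. JSON object keys are always strings, which is why the loader converts them back to ints; a small standalone sketch of that round trip (the slice count of 4 is illustrative, the real value comes from GET_VESPA_CHUNKS_SLICE_COUNT):

import json

SLICE_COUNT = 4  # illustrative stand-in for GET_VESPA_CHUNKS_SLICE_COUNT

# Fresh state: every slice starts with no continuation token.
token_map: dict[int, str | None] = {
    slice_id: None for slice_id in range(SLICE_COUNT)
}

serialized = json.dumps(token_map)  # keys are stored as "0", "1", ...
restored = {int(key): value for key, value in json.loads(serialized).items()}
assert restored == token_map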
def update_vespa_visit_progress_with_commit(
|
||||
db_session: Session,
|
||||
continuation_token_map: dict[int, str | None],
|
||||
continuation_token: str | None,
|
||||
chunks_processed: int,
|
||||
chunks_errored: int,
|
||||
approx_chunk_count_in_vespa: int | None,
|
||||
) -> None:
|
||||
"""Updates the Vespa migration progress and commits.
|
||||
|
||||
@@ -285,26 +273,19 @@ def update_vespa_visit_progress_with_commit(
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
continuation_token_map: The new continuation token map. None entry means
|
||||
the visit is complete for that slice.
|
||||
continuation_token: The new continuation token. None means the visit
|
||||
is complete.
|
||||
chunks_processed: Number of chunks processed in this batch (added to
|
||||
the running total).
|
||||
chunks_errored: Number of chunks errored in this batch (added to the
|
||||
running errored total).
|
||||
approx_chunk_count_in_vespa: Approximate number of chunks in Vespa. If
|
||||
None, the existing value is used.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
raise RuntimeError("OpenSearchTenantMigrationRecord not found.")
|
||||
record.vespa_visit_continuation_token = json.dumps(continuation_token_map)
|
||||
record.vespa_visit_continuation_token = continuation_token
|
||||
record.total_chunks_migrated += chunks_processed
|
||||
record.total_chunks_errored += chunks_errored
|
||||
record.approx_chunk_count_in_vespa = (
|
||||
approx_chunk_count_in_vespa
|
||||
if approx_chunk_count_in_vespa is not None
|
||||
else record.approx_chunk_count_in_vespa
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -372,27 +353,25 @@ def build_sanitized_to_original_doc_id_mapping(
|
||||
|
||||
def get_opensearch_migration_state(
|
||||
db_session: Session,
|
||||
) -> tuple[int, datetime | None, datetime | None, int | None]:
|
||||
) -> tuple[int, datetime | None, datetime | None]:
|
||||
"""Returns the state of the Vespa to OpenSearch migration.
|
||||
|
||||
If the tenant migration record is not found, returns defaults of 0, None,
|
||||
None, None.
|
||||
None.
|
||||
|
||||
Args:
|
||||
db_session: SQLAlchemy session.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_chunks_migrated, created_at, migration_completed_at,
|
||||
approx_chunk_count_in_vespa).
|
||||
Tuple of (total_chunks_migrated, created_at, migration_completed_at).
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
if record is None:
|
||||
return 0, None, None, None
|
||||
return 0, None, None
|
||||
return (
|
||||
record.total_chunks_migrated,
|
||||
record.created_at,
|
||||
record.migration_completed_at,
|
||||
record.approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -32,59 +31,53 @@ async def fetch_user_for_pat(
|
||||
|
||||
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
|
||||
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
|
||||
|
||||
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
|
||||
are loaded correctly — matching the pattern in fetch_user_for_api_key.
|
||||
"""
|
||||
# Single joined query with all filters pushed to database
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
user = await async_db_session.scalar(
|
||||
select(User)
|
||||
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
|
||||
result = await async_db_session.execute(
|
||||
select(PersonalAccessToken, User)
|
||||
.join(User, PersonalAccessToken.user_id == User.id)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.where(User.is_active) # type: ignore
|
||||
.where(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.options(selectinload(User.memories))
|
||||
.limit(1)
|
||||
)
|
||||
if not user:
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
return None
|
||||
|
||||
_schedule_pat_last_used_update(hashed_token, now)
|
||||
return user
|
||||
pat, user = row
|
||||
|
||||
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
|
||||
# For request-level auditing, use application logs or a dedicated audit table
|
||||
should_update = (
|
||||
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
|
||||
)
|
||||
|
||||
def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
|
||||
"""Fire-and-forget update of last_used_at, throttled to 5-minute granularity."""
|
||||
|
||||
async def _update() -> None:
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async with get_async_session_context_manager(tenant_id) as session:
|
||||
pat = await session.scalar(
|
||||
select(PersonalAccessToken).where(
|
||||
PersonalAccessToken.hashed_token == hashed_token
|
||||
if should_update:
|
||||
# Update in separate session to avoid transaction coupling (fire-and-forget)
|
||||
async def _update_last_used() -> None:
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async with get_async_session_context_manager(
|
||||
tenant_id
|
||||
) as separate_session:
|
||||
await separate_session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
)
|
||||
if not pat:
|
||||
return
|
||||
if (
|
||||
pat.last_used_at is not None
|
||||
and (now - pat.last_used_at).total_seconds() <= 300
|
||||
):
|
||||
return
|
||||
await session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update last_used_at for PAT: {e}")
|
||||
await separate_session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update last_used_at for PAT: {e}")
|
||||
|
||||
asyncio.create_task(_update())
|
||||
asyncio.create_task(_update_last_used())
|
||||
return user
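Both variants above push the last_used_at write into a background task so token authentication never blocks on the UPDATE, and both throttle the write to roughly five-minute granularity. A generic sketch of that pattern (write_last_used is a hypothetical coroutine standing in for the real session update):

import asyncio
from datetime import datetime, timedelta, timezone
from typing import Awaitable, Callable

THROTTLE = timedelta(minutes=5)

def needs_touch(last_used_at: datetime | None, now: datetime) -> bool:
    # Only write when the stored timestamp is missing or stale.
    return last_used_at is None or (now - last_used_at) > THROTTLE

async def touch_token(write_last_used: Callable[[datetime], Awaitable[None]]) -> None:
    try:
        await write_last_used(datetime.now(timezone.utc))
    except Exception:
        # A failed audit write must never break authentication.
        pass

# Scheduled without awaiting so the request continues immediately:
# asyncio.create_task(touch_token(write_last_used))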
def create_pat(
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
@@ -421,16 +420,9 @@ def get_minimal_persona_snapshots_for_user(
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
@@ -461,16 +453,7 @@ def get_persona_snapshots_for_user(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
@@ -567,16 +550,9 @@ def get_minimal_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.user),
|
||||
)
|
||||
|
||||
@@ -635,16 +611,7 @@ def get_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
|
||||
@@ -54,9 +54,6 @@ class SearchHit(BaseModel, Generic[SchemaDocumentModel]):
|
||||
# Maps schema property name to a list of highlighted snippets with match
|
||||
# terms wrapped in tags (e.g. "something <hi>keyword</hi> other thing").
|
||||
match_highlights: dict[str, list[str]] = {}
|
||||
# Score explanation from OpenSearch when "explain": true is set in the query.
|
||||
# Contains detailed breakdown of how the score was calculated.
|
||||
explanation: dict[str, Any] | None = None
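These two fields map directly onto what OpenSearch returns per hit: the "highlight" section and, when the query body sets "explain": true, the "_explanation" section. A hedged sketch of folding a raw hit into the model (the _source payload here is illustrative; the real parsing appears later in this diff):

raw_hit = {
    "_score": 7.42,
    "_source": {"document_id": "doc-1", "content": "something keyword other thing"},
    "highlight": {"content": ["something <hi>keyword</hi> other thing"]},
    "_explanation": {"value": 7.42, "description": "sum of:", "details": []},
}

search_hit = SearchHit[DocumentChunk](
    document_chunk=DocumentChunk.model_validate(raw_hit["_source"]),
    score=raw_hit.get("_score"),
    match_highlights=raw_hit.get("highlight", {}),
    explanation=raw_hit.get("_explanation"),
)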
def get_new_body_without_vectors(body: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -709,12 +706,10 @@ class OpenSearchClient:
|
||||
)
|
||||
document_chunk_score = hit.get("_score", None)
|
||||
match_highlights: dict[str, list[str]] = hit.get("highlight", {})
|
||||
explanation: dict[str, Any] | None = hit.get("_explanation", None)
|
||||
search_hit = SearchHit[DocumentChunk](
|
||||
document_chunk=DocumentChunk.model_validate(document_chunk_source),
|
||||
score=document_chunk_score,
|
||||
match_highlights=match_highlights,
|
||||
explanation=explanation,
|
||||
)
|
||||
search_hits.append(search_hit)
|
||||
logger.debug(
|
||||
|
||||
@@ -10,31 +10,31 @@ EF_CONSTRUCTION = 256
|
||||
# quality but increase memory footprint. Values typically range between 12 - 48.
|
||||
M = 32 # Set relatively high for better accuracy.
|
||||
|
||||
# When performing hybrid search, we need to consider more candidates than the number of results to be returned.
|
||||
# This is because the scoring is hybrid and the results are reordered due to the hybrid scoring.
|
||||
# Higher = more candidates for hybrid fusion = better retrieval accuracy, but results in more computation per query.
|
||||
# Imagine a simple case with a single keyword query and a single vector query and we want 10 final docs.
|
||||
# If we only fetch 10 candidates from each of keyword and vector, they would have to overlap perfectly to produce a good hybrid
# ranking for those 10 results. If we fetch 1000 candidates from each, there is a much higher chance that all 10 of the desired
# docs show up and get scored. In the worst case, some of the desired docs never enter the candidate pool at all (worse than
# just a miss at the reranking step).
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES = 750
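A toy illustration of why the fused top-k needs a deep per-subquery candidate pool: fuse a keyword ranking and a vector ranking and compare shallow versus deep cutoffs (reciprocal-rank fusion is used here purely for illustration, not as the production scoring):

def fuse(keyword_ranking: list[str], vector_ranking: list[str], depth: int, k: int = 10) -> list[str]:
    # Score each doc by its summed reciprocal rank across the two truncated rankings.
    scores: dict[str, float] = {}
    for ranking in (keyword_ranking[:depth], vector_ranking[:depth]):
        for rank, doc in enumerate(ranking):
            scores[doc] = scores.get(doc, 0.0) + 1.0 / (rank + 1)
    return [doc for doc, _ in sorted(scores.items(), key=lambda kv: -kv[1])[:k]]

# With depth=10, a doc ranked 11th by both subqueries never enters the pool, even though
# its combined score might beat a doc that only one subquery ranked highly; with
# depth=750 it is scored and can still reach the final 10.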
# Number of vectors to examine for top k neighbors for the HNSW method.
|
||||
EF_SEARCH = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES
|
||||
# Should be >= DEFAULT_K_NUM_CANDIDATES for good recall; higher = better accuracy, slower search.
|
||||
# Bumping this to 1000 on a dataset in the low 10,000s of docs did not improve recall.
|
||||
EF_SEARCH = 256
|
||||
|
||||
# The default number of neighbors to consider for knn vector similarity search.
|
||||
# We need this higher than the number of results because the scoring is hybrid.
|
||||
# If there is only 1 query, setting k equal to the number of results is enough,
|
||||
# but since there is heavy reordering due to hybrid scoring, we need to set k higher.
|
||||
# Higher = more candidates for hybrid fusion = better retrieval accuracy, more query cost.
|
||||
DEFAULT_K_NUM_CANDIDATES = 50 # TODO likely need to bump this way higher
|
||||
|
||||
# Since the titles are included in the contents, they are heavily downweighted as they act as a boost
|
||||
# rather than an independent scoring component.
|
||||
SEARCH_TITLE_VECTOR_WEIGHT = 0.1
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT = 0.45
|
||||
# Single keyword weight for both title and content (merged from former title keyword + content keyword).
|
||||
SEARCH_KEYWORD_WEIGHT = 0.45
|
||||
SEARCH_TITLE_KEYWORD_WEIGHT = 0.1
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT = 0.4
|
||||
SEARCH_CONTENT_KEYWORD_WEIGHT = 0.4
|
||||
|
||||
# NOTE: it is critical that the order of these weights matches the order of the sub-queries in the hybrid search.
|
||||
HYBRID_SEARCH_NORMALIZATION_WEIGHTS = [
|
||||
SEARCH_TITLE_VECTOR_WEIGHT,
|
||||
SEARCH_TITLE_KEYWORD_WEIGHT,
|
||||
SEARCH_CONTENT_VECTOR_WEIGHT,
|
||||
SEARCH_KEYWORD_WEIGHT,
|
||||
SEARCH_CONTENT_KEYWORD_WEIGHT,
|
||||
]
assert sum(HYBRID_SEARCH_NORMALIZATION_WEIGHTS) == 1.0
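Because the normalization pipeline applies these weights positionally, reordering the subqueries without reordering this list silently mis-weights the fusion. A toy sketch of the positional combination step (a plain weighted sum over already-normalized scores; the deployed pipeline performs the normalization itself, e.g. z-score):

def combine(per_subquery_scores: list[list[float]], weights: list[float]) -> list[float]:
    # per_subquery_scores[i][d] = normalized score of doc d from subquery i.
    assert len(per_subquery_scores) == len(weights)
    num_docs = len(per_subquery_scores[0])
    return [
        sum(weights[i] * per_subquery_scores[i][d] for i in range(len(weights)))
        for d in range(num_docs)
    ]

# Only meaningful if per_subquery_scores[0] came from the title-vector subquery,
# per_subquery_scores[1] from the next subquery in the hybrid query, and so on.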
@@ -842,8 +842,6 @@ class OpenSearchDocumentIndex(DocumentIndex):
|
||||
body=query_body,
|
||||
search_pipeline_id=ZSCORE_NORMALIZATION_PIPELINE_NAME,
|
||||
)
|
||||
|
||||
# Good place for a breakpoint to inspect the search hits if you have "explain" enabled.
|
||||
inference_chunks_uncleaned: list[InferenceChunkUncleaned] = [
|
||||
_convert_retrieved_opensearch_chunk_to_inference_chunk_uncleaned(
|
||||
search_hit.document_chunk, search_hit.score, search_hit.match_highlights
|
||||
|
||||
@@ -11,7 +11,6 @@ from pydantic import model_serializer
|
||||
from pydantic import model_validator
|
||||
from pydantic import SerializerFunctionWrapHandler
|
||||
|
||||
from onyx.configs.app_configs import OPENSEARCH_TEXT_ANALYZER
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
@@ -55,11 +54,6 @@ SECONDARY_OWNERS_FIELD_NAME = "secondary_owners"
|
||||
ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME = "ancestor_hierarchy_node_ids"
|
||||
|
||||
|
||||
# Faiss was also tried but it didn't have any benefits
|
||||
# NMSLIB is deprecated, not recommended
|
||||
OPENSEARCH_KNN_ENGINE = "lucene"
|
||||
|
||||
|
||||
def get_opensearch_doc_chunk_id(
|
||||
tenant_state: TenantState,
|
||||
document_id: str,
|
||||
@@ -349,9 +343,6 @@ class DocumentSchema:
|
||||
"properties": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"type": "text",
|
||||
# Language analyzer (e.g. english) stems at index and search time for variant matching.
|
||||
# Configure via OPENSEARCH_TEXT_ANALYZER. Existing indices need reindexing after a change.
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
"fields": {
|
||||
# Subfield accessed as title.keyword. Not indexed for
|
||||
# values longer than 256 chars.
|
||||
@@ -366,7 +357,9 @@ class DocumentSchema:
|
||||
CONTENT_FIELD_NAME: {
|
||||
"type": "text",
|
||||
"store": True,
|
||||
"analyzer": OPENSEARCH_TEXT_ANALYZER,
|
||||
# This makes highlighting text during queries more efficient
|
||||
# at the cost of disk space. See
|
||||
# https://docs.opensearch.org/latest/search-plugins/searching-data/highlight/#methods-of-obtaining-offsets
|
||||
"index_options": "offsets",
|
||||
},
TITLE_VECTOR_FIELD_NAME: {
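Put together, the content mapping pairs a language analyzer with stored offsets so the highlighter can work from the postings instead of re-analyzing stored text. A standalone sketch of an equivalent mapping, assuming the opensearch-py client and an illustrative index name:

from opensearchpy import OpenSearch

client = OpenSearch(hosts=[{"host": "localhost", "port": 9200}])
client.indices.create(
    index="example_chunks",
    body={
        "mappings": {
            "properties": {
                "content": {
                    "type": "text",
                    "analyzer": "english",
                    # Cheaper highlighting at the cost of extra disk.
                    "index_options": "offsets",
                }
            }
        }
    },
)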
@@ -375,7 +368,7 @@ class DocumentSchema:
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": OPENSEARCH_KNN_ENGINE,
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
@@ -387,7 +380,7 @@ class DocumentSchema:
|
||||
"method": {
|
||||
"name": "hnsw",
|
||||
"space_type": "cosinesimil",
|
||||
"engine": OPENSEARCH_KNN_ENGINE,
|
||||
"engine": "lucene",
|
||||
"parameters": {"ef_construction": EF_CONSTRUCTION, "m": M},
|
||||
},
|
||||
},
|
||||
|
||||
@@ -6,16 +6,13 @@ from typing import Any
|
||||
from uuid import UUID
|
||||
|
||||
from onyx.configs.app_configs import DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S
|
||||
from onyx.configs.app_configs import OPENSEARCH_EXPLAIN_ENABLED
|
||||
from onyx.configs.app_configs import OPENSEARCH_PROFILING_DISABLED
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import INDEX_SEPARATOR
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import Tag
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import (
|
||||
DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
)
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_K_NUM_CANDIDATES
|
||||
from onyx.document_index.opensearch.constants import HYBRID_SEARCH_NORMALIZATION_WEIGHTS
|
||||
from onyx.document_index.opensearch.schema import ACCESS_CONTROL_LIST_FIELD_NAME
|
||||
from onyx.document_index.opensearch.schema import ANCESTOR_HIERARCHY_NODE_IDS_FIELD_NAME
|
||||
@@ -243,9 +240,6 @@ class DocumentQuery:
|
||||
Returns:
|
||||
A dictionary representing the final hybrid search query.
|
||||
"""
|
||||
# WARNING: Profiling does not work with hybrid search; do not add it at
|
||||
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
|
||||
|
||||
if num_hits > DEFAULT_OPENSEARCH_MAX_RESULT_WINDOW:
|
||||
raise ValueError(
|
||||
f"Bug: num_hits ({num_hits}) is greater than the current maximum allowed "
|
||||
@@ -253,7 +247,7 @@ class DocumentQuery:
|
||||
)
|
||||
|
||||
hybrid_search_subqueries = DocumentQuery._get_hybrid_search_subqueries(
|
||||
query_text, query_vector
|
||||
query_text, query_vector, num_candidates=DEFAULT_K_NUM_CANDIDATES
|
||||
)
|
||||
hybrid_search_filters = DocumentQuery._get_search_filters(
|
||||
tenant_state=tenant_state,
|
||||
@@ -281,31 +275,25 @@ class DocumentQuery:
|
||||
hybrid_search_query: dict[str, Any] = {
|
||||
"hybrid": {
|
||||
"queries": hybrid_search_subqueries,
|
||||
# Max results per subquery per shard before aggregation. Ensures keyword and vector
|
||||
# subqueries contribute equally to the candidate pool for hybrid fusion.
|
||||
# Sources:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
# https://opensearch.org/blog/navigating-pagination-in-hybrid-queries-with-the-pagination_depth-parameter/
|
||||
"pagination_depth": DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
# Applied to all the sub-queries independently (this avoids subqueries having a lot of their results thrown out).
|
||||
# Sources:
|
||||
# Applied to all the sub-queries. Source:
|
||||
# https://docs.opensearch.org/latest/query-dsl/compound/hybrid/
|
||||
# https://opensearch.org/blog/introducing-common-filter-support-for-hybrid-search-queries
|
||||
# Does AND for each filter in the list.
|
||||
"filter": {"bool": {"filter": hybrid_search_filters}},
|
||||
}
|
||||
}
|
||||
|
||||
# NOTE: By default, hybrid search retrieves "size"-many results from
|
||||
# each OpenSearch shard before aggregation. Source:
|
||||
# https://docs.opensearch.org/latest/vector-search/ai-search/hybrid-search/pagination/
|
||||
|
||||
final_hybrid_search_body: dict[str, Any] = {
|
||||
"query": hybrid_search_query,
|
||||
"size": num_hits,
|
||||
"highlight": match_highlights_configuration,
|
||||
"timeout": f"{DEFAULT_OPENSEARCH_QUERY_TIMEOUT_S}s",
|
||||
}
|
||||
|
||||
# Explain is for scoring breakdowns.
|
||||
if OPENSEARCH_EXPLAIN_ENABLED:
|
||||
final_hybrid_search_body["explain"] = True
|
||||
# WARNING: Profiling does not work with hybrid search; do not add it at
|
||||
# this level. See https://github.com/opensearch-project/neural-search/issues/1255
|
||||
|
||||
return final_hybrid_search_body
|
||||
|
||||
@@ -367,12 +355,7 @@ class DocumentQuery:
|
||||
|
||||
@staticmethod
|
||||
def _get_hybrid_search_subqueries(
|
||||
query_text: str,
|
||||
query_vector: list[float],
|
||||
# The default number of neighbors to consider for knn vector similarity search.
|
||||
# This is higher than the number of results because the scoring is hybrid.
|
||||
# for a detailed breakdown, see where the default value is set.
|
||||
vector_candidates: int = DEFAULT_NUM_HYBRID_SEARCH_CANDIDATES,
|
||||
query_text: str, query_vector: list[float], num_candidates: int
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Returns subqueries for hybrid search.
|
||||
|
||||
@@ -384,8 +367,9 @@ class DocumentQuery:
|
||||
|
||||
Matches:
|
||||
- Title vector
|
||||
- Title keyword
|
||||
- Content vector
|
||||
- Keyword (title + content, match and phrase)
|
||||
- Content keyword + phrase
|
||||
|
||||
Normalization is not performed here.
|
||||
The weights of each of these subqueries should be configured in a search
|
||||
@@ -406,9 +390,9 @@ class DocumentQuery:
|
||||
NOTE: Options considered and rejected:
|
||||
- minimum_should_match: Since this is hybrid search and users often write semantic queries, queries tend to contain many terms
but very few meaningful keywords (a low keyword ratio).
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). It mostly covers typos, since the analyzer ("english" by
default) already does stemming and tokenization. On test datasets it makes recall slightly worse, and it is also less
performant, so there is no real reason to enable it.
- fuzziness AUTO: typo tolerance (0/1/2 edit distance by term length). This is reasonable, but observed usage patterns show that
typos are not very common and users tend not to be confused when a miss happens for this reason. On test datasets it makes
recall slightly worse.
|
||||
|
||||
Args:
|
||||
query_text: The text of the query to search for.
|
||||
@@ -417,27 +401,19 @@ class DocumentQuery:
|
||||
similarity search.
|
||||
"""
|
||||
# Build sub-queries for hybrid search. Order must match normalization
|
||||
# pipeline weights: title vector, content vector, keyword (title + content).
|
||||
# pipeline weights: title vector, title keyword, content vector,
|
||||
# content keyword.
|
||||
hybrid_search_queries: list[dict[str, Any]] = [
|
||||
# 1. Title vector search
|
||||
{
|
||||
"knn": {
|
||||
TITLE_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
"k": num_candidates,
|
||||
}
|
||||
}
|
||||
},
|
||||
# 2. Content vector search
|
||||
{
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": vector_candidates,
|
||||
}
|
||||
}
|
||||
},
|
||||
# 3. Keyword (title + content) match and phrase search.
|
||||
# 2. Title keyword + phrase search.
|
||||
{
|
||||
"bool": {
|
||||
"should": [
|
||||
@@ -445,10 +421,8 @@ class DocumentQuery:
|
||||
"match": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# operator "or" = match doc if any query term matches (default, explicit for clarity).
|
||||
"operator": "or",
|
||||
# The title fields are strongly discounted because the title is already included in the content;
# the title match just acts as a minor boost.
|
||||
"boost": 0.1,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -456,17 +430,35 @@ class DocumentQuery:
|
||||
"match_phrase": {
|
||||
TITLE_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# Slop = 1 allows one extra word or transposition in phrase match.
|
||||
"slop": 1,
|
||||
"boost": 0.2,
|
||||
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
},
|
||||
]
|
||||
}
|
||||
},
|
||||
# 3. Content vector search
|
||||
{
|
||||
"knn": {
|
||||
CONTENT_VECTOR_FIELD_NAME: {
|
||||
"vector": query_vector,
|
||||
"k": num_candidates,
|
||||
}
|
||||
}
|
||||
},
|
||||
# 4. Content keyword + phrase search.
|
||||
{
|
||||
"bool": {
|
||||
"should": [
|
||||
{
|
||||
"match": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# operator "or" = match doc if any query term matches (default, explicit for clarity).
|
||||
"operator": "or",
|
||||
"boost": 1.0,
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -474,7 +466,9 @@ class DocumentQuery:
|
||||
"match_phrase": {
|
||||
CONTENT_FIELD_NAME: {
|
||||
"query": query_text,
|
||||
# Slop = 1 allows one extra word or transposition in phrase match.
|
||||
"slop": 1,
|
||||
# Boost phrase over bag-of-words; exact phrase is a stronger signal.
|
||||
"boost": 1.5,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,12 +10,6 @@ from typing import cast
|
||||
import httpx
|
||||
from retry import retry
|
||||
|
||||
from onyx.background.celery.tasks.opensearch_migration.constants import (
|
||||
FINISHED_VISITING_SLICE_CONTINUATION_TOKEN,
|
||||
)
|
||||
from onyx.background.celery.tasks.opensearch_migration.transformer import (
|
||||
FIELDS_NEEDED_FOR_TRANSFORMATION,
|
||||
)
|
||||
from onyx.configs.app_configs import LOG_VESPA_TIMING_INFORMATION
|
||||
from onyx.configs.app_configs import VESPA_LANGUAGE_OVERRIDE
|
||||
from onyx.context.search.models import IndexFilters
|
||||
@@ -283,139 +277,54 @@ def get_chunks_via_visit_api(
|
||||
def get_all_chunks_paginated(
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
continuation_token_map: dict[int, str | None],
|
||||
page_size: int,
|
||||
) -> tuple[list[dict], dict[int, str | None]]:
|
||||
continuation_token: str | None = None,
|
||||
page_size: int = 1_000,
|
||||
) -> tuple[list[dict], str | None]:
|
||||
"""Gets all chunks in Vespa matching the filters, paginated.
|
||||
|
||||
Uses the Visit API with slicing. Each continuation token map entry is for a
|
||||
different slice. The number of entries determines the number of slices.
|
||||
|
||||
Args:
|
||||
index_name: The name of the Vespa index to visit.
|
||||
tenant_state: The tenant state to filter by.
|
||||
continuation_token_map: Map of slice ID to a token returned by Vespa
|
||||
representing a page offset. None to start from the beginning of the
|
||||
slice.
|
||||
continuation_token: Token returned by Vespa representing a page offset.
|
||||
None to start from the beginning. Defaults to None.
|
||||
page_size: Best-effort batch size for the visit. Defaults to 1,000.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of chunk dicts, next continuation token or None). The
|
||||
continuation token is None when the visit is complete.
|
||||
"""
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
def _get_all_chunks_paginated_for_slice(
|
||||
index_name: str,
|
||||
tenant_state: TenantState,
|
||||
slice_id: int,
|
||||
total_slices: int,
|
||||
continuation_token: str | None,
|
||||
page_size: int,
|
||||
) -> tuple[list[dict], str | None]:
|
||||
if continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
|
||||
logger.debug(
|
||||
f"Slice {slice_id} has finished visiting. Returning empty list and {FINISHED_VISITING_SLICE_CONTINUATION_TOKEN}."
|
||||
)
|
||||
return [], FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
|
||||
selection: str = f"{index_name}.large_chunk_reference_ids == null"
|
||||
if MULTI_TENANT:
|
||||
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
|
||||
|
||||
url = DOCUMENT_ID_ENDPOINT.format(index_name=index_name)
|
||||
|
||||
selection: str = f"{index_name}.large_chunk_reference_ids == null"
|
||||
if MULTI_TENANT:
|
||||
selection += f" and {index_name}.tenant_id=='{tenant_state.tenant_id}'"
|
||||
|
||||
field_set = f"{index_name}:" + ",".join(FIELDS_NEEDED_FOR_TRANSFORMATION)
|
||||
|
||||
params: dict[str, str | int | None] = {
|
||||
"selection": selection,
|
||||
"fieldSet": field_set,
|
||||
"wantedDocumentCount": page_size,
|
||||
"format.tensors": "short-value",
|
||||
"slices": total_slices,
|
||||
"sliceId": slice_id,
|
||||
}
|
||||
if continuation_token is not None:
|
||||
params["continuation"] = continuation_token
|
||||
|
||||
response: httpx.Response | None = None
|
||||
try:
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = f"Failed to get chunks from Vespa slice {slice_id} with continuation token {continuation_token}."
|
||||
logger.exception(
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
)
|
||||
error_message = (
|
||||
response.json().get("message") if response else "No response"
|
||||
)
|
||||
logger.error("Error message from response: %s", error_message)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
# NOTE: If we see a falsey value for "continuation" in the response we
|
||||
# assume we are done and return
|
||||
# FINISHED_VISITING_SLICE_CONTINUATION_TOKEN instead.
|
||||
next_continuation_token = (
|
||||
response_data.get("continuation")
|
||||
or FINISHED_VISITING_SLICE_CONTINUATION_TOKEN
|
||||
)
|
||||
chunks = [chunk["fields"] for chunk in response_data.get("documents", [])]
|
||||
if next_continuation_token == FINISHED_VISITING_SLICE_CONTINUATION_TOKEN:
|
||||
logger.debug(
|
||||
f"Slice {slice_id} has finished visiting. Returning {len(chunks)} chunks and {next_continuation_token}."
|
||||
)
|
||||
return chunks, next_continuation_token
|
||||
|
||||
total_slices = len(continuation_token_map)
|
||||
if total_slices < 1:
|
||||
raise ValueError("continuation_token_map must have at least one entry.")
|
||||
# We want to guarantee that these invocations are ordered by slice_id,
|
||||
# because we read in the same order below when parsing parallel_results.
|
||||
functions_with_args: list[tuple[Callable, tuple]] = [
|
||||
(
|
||||
_get_all_chunks_paginated_for_slice,
|
||||
(
|
||||
index_name,
|
||||
tenant_state,
|
||||
slice_id,
|
||||
total_slices,
|
||||
continuation_token,
|
||||
page_size,
|
||||
),
|
||||
)
|
||||
for slice_id, continuation_token in sorted(continuation_token_map.items())
|
||||
]
|
||||
|
||||
parallel_results = run_functions_tuples_in_parallel(
|
||||
functions_with_args, allow_failures=True
|
||||
)
|
||||
if len(parallel_results) != total_slices:
|
||||
raise RuntimeError(
|
||||
f"Expected {total_slices} parallel results, but got {len(parallel_results)}."
|
||||
)
|
||||
|
||||
chunks: list[dict] = []
|
||||
next_continuation_token_map: dict[int, str | None] = {
|
||||
key: value for key, value in continuation_token_map.items()
|
||||
params: dict[str, str | int | None] = {
|
||||
"selection": selection,
|
||||
"wantedDocumentCount": page_size,
|
||||
"format.tensors": "short-value",
|
||||
}
|
||||
for i, parallel_result in enumerate(parallel_results):
|
||||
if i not in next_continuation_token_map:
|
||||
raise RuntimeError(f"Slice {i} is not in the continuation token map.")
|
||||
if parallel_result is None:
|
||||
logger.error(
|
||||
f"Failed to get chunks for slice {i} of {total_slices}. "
|
||||
"The continuation token for this slice will not be updated."
|
||||
)
|
||||
continue
|
||||
chunks.extend(parallel_result[0])
|
||||
next_continuation_token_map[i] = parallel_result[1]
|
||||
if continuation_token is not None:
|
||||
params["continuation"] = continuation_token
|
||||
|
||||
return chunks, next_continuation_token_map
|
||||
try:
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.get(url, params=params)
|
||||
response.raise_for_status()
|
||||
except httpx.HTTPError as e:
|
||||
error_base = "Failed to get chunks in Vespa."
|
||||
logger.exception(
|
||||
f"Request URL: {e.request.url}\n"
|
||||
f"Request Headers: {e.request.headers}\n"
|
||||
f"Request Payload: {params}\n"
|
||||
)
|
||||
raise httpx.HTTPError(error_base) from e
|
||||
|
||||
response_data = response.json()
|
||||
|
||||
return [
|
||||
chunk["fields"] for chunk in response_data.get("documents", [])
|
||||
], response_data.get("continuation") or None
|
||||
|
||||
|
||||
# TODO(rkuo): candidate for removal if not being used
|
||||
|
||||
@@ -56,7 +56,6 @@ from onyx.document_index.vespa_constants import CONTENT_SUMMARY
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID
|
||||
from onyx.document_index.vespa_constants import DOCUMENT_ID_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import NUM_THREADS
|
||||
from onyx.document_index.vespa_constants import SEARCH_ENDPOINT
|
||||
from onyx.document_index.vespa_constants import VESPA_TIMEOUT
|
||||
from onyx.document_index.vespa_constants import YQL_BASE
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
@@ -653,9 +652,9 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
|
||||
def get_all_raw_document_chunks_paginated(
|
||||
self,
|
||||
continuation_token_map: dict[int, str | None],
|
||||
continuation_token: str | None,
|
||||
page_size: int,
|
||||
) -> tuple[list[dict[str, Any]], dict[int, str | None]]:
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Gets all the chunks in Vespa, paginated.
|
||||
|
||||
Used in the chunk-level Vespa-to-OpenSearch migration task.
|
||||
@@ -663,21 +662,21 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
Args:
|
||||
continuation_token: Token returned by Vespa representing a page
|
||||
offset. None to start from the beginning. Defaults to None.
|
||||
page_size: Best-effort batch size for the visit.
|
||||
page_size: Best-effort batch size for the visit. Defaults to 1,000.
|
||||
|
||||
Returns:
|
||||
Tuple of (list of chunk dicts, next continuation token or None). The
|
||||
continuation token is None when the visit is complete.
|
||||
"""
|
||||
raw_chunks, next_continuation_token_map = get_all_chunks_paginated(
|
||||
raw_chunks, next_continuation_token = get_all_chunks_paginated(
|
||||
index_name=self._index_name,
|
||||
tenant_state=TenantState(
|
||||
tenant_id=self._tenant_id, multitenant=MULTI_TENANT
|
||||
),
|
||||
continuation_token_map=continuation_token_map,
|
||||
continuation_token=continuation_token,
|
||||
page_size=page_size,
|
||||
)
|
||||
return raw_chunks, next_continuation_token_map
|
||||
return raw_chunks, next_continuation_token
|
||||
|
||||
def index_raw_chunks(self, chunks: list[dict[str, Any]]) -> None:
|
||||
"""Indexes raw document chunks into Vespa.
|
||||
@@ -703,32 +702,3 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
json={"fields": chunk},
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
def get_chunk_count(self) -> int:
|
||||
"""Returns the exact number of document chunks in Vespa for this tenant.
|
||||
|
||||
Uses the Vespa Search API with `limit 0` and `ranking.profile=unranked`
|
||||
to get an exact count without fetching any document data.
|
||||
|
||||
Includes large chunks. There is no way to filter these out using the
|
||||
Search API.
|
||||
"""
|
||||
where_clause = (
|
||||
f'tenant_id contains "{self._tenant_id}"' if self._multitenant else "true"
|
||||
)
|
||||
yql = (
|
||||
f"select documentid from {self._index_name} "
|
||||
f"where {where_clause} "
|
||||
f"limit 0"
|
||||
)
|
||||
params: dict[str, str | int] = {
|
||||
"yql": yql,
|
||||
"ranking.profile": "unranked",
|
||||
"timeout": VESPA_TIMEOUT,
|
||||
}
|
||||
|
||||
with get_vespa_http_client() as http_client:
|
||||
response = http_client.post(SEARCH_ENDPOINT, json=params)
|
||||
response.raise_for_status()
|
||||
response_data = response.json()
|
||||
return response_data["root"]["fields"]["totalCount"]
|
||||
|
||||
@@ -20,20 +20,7 @@ class ImageGenerationProviderCredentials(BaseModel):
custom_config: dict[str, str] | None = None


class ReferenceImage(BaseModel):
data: bytes
mime_type: str


class ImageGenerationProvider(abc.ABC):
@property
def supports_reference_images(self) -> bool:
return False

@property
def max_reference_images(self) -> int:
return 0

@classmethod
@abc.abstractmethod
def validate_credentials(
@@ -76,7 +63,6 @@ class ImageGenerationProvider(abc.ABC):
size: str,
n: int,
quality: str | None = None,
reference_images: list[ReferenceImage] | None = None,
**kwargs: Any,
) -> ImageGenerationResponse:
"""Generates an image based on a prompt."""
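A rough usage sketch for the ReferenceImage support added above; it is not code from this diff, `provider` is assumed to be any concrete ImageGenerationProvider, and the model id is a placeholder.

from onyx.image_gen.interfaces import ReferenceImage

def edit_previous_image(provider, prompt: str, image_bytes: bytes):
    refs: list[ReferenceImage] | None = [
        ReferenceImage(data=image_bytes, mime_type="image/png")
    ]
    # Providers that keep the default supports_reference_images=False ignore the
    # argument (OpenAI/Azure in this diff), so drop it for them.
    if not provider.supports_reference_images:
        refs = None
    elif refs is not None and len(refs) > provider.max_reference_images:
        refs = refs[: provider.max_reference_images]
    return provider.generate_image(
        prompt=prompt,
        model="placeholder-model-id",  # assumption: caller supplies a real model name
        size="1024x1024",
        n=1,
        reference_images=refs,
    )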
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -60,7 +59,6 @@ class AzureImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
@@ -5,7 +5,6 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -46,7 +45,6 @@ class OpenAIImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None, # noqa: ARG002
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
from litellm import image_generation
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
@@ -11,7 +9,6 @@ from pydantic import BaseModel
|
||||
from onyx.image_gen.exceptions import ImageProviderCredentialsError
|
||||
from onyx.image_gen.interfaces import ImageGenerationProvider
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.image_gen.interfaces import ImageGenerationResponse
|
||||
@@ -54,15 +51,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
vertex_credentials=vertex_credentials,
|
||||
)
|
||||
|
||||
@property
|
||||
def supports_reference_images(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def max_reference_images(self) -> int:
|
||||
# Gemini image editing supports up to 14 input images.
|
||||
return 14
|
||||
|
||||
def generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
@@ -70,18 +58,8 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
size: str,
|
||||
n: int,
|
||||
quality: str | None = None,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ImageGenerationResponse:
|
||||
if reference_images:
|
||||
return self._generate_image_with_reference_images(
|
||||
prompt=prompt,
|
||||
model=model,
|
||||
size=size,
|
||||
n=n,
|
||||
reference_images=reference_images,
|
||||
)
|
||||
|
||||
from litellm import image_generation
|
||||
|
||||
return image_generation(
|
||||
@@ -96,99 +74,6 @@ class VertexImageGenerationProvider(ImageGenerationProvider):
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def _generate_image_with_reference_images(
|
||||
self,
|
||||
prompt: str,
|
||||
model: str,
|
||||
size: str,
|
||||
n: int,
|
||||
reference_images: list[ReferenceImage],
|
||||
) -> ImageGenerationResponse:
|
||||
from google import genai
|
||||
from google.genai import types as genai_types
|
||||
from google.oauth2 import service_account
|
||||
from litellm.types.utils import ImageObject
|
||||
from litellm.types.utils import ImageResponse
|
||||
|
||||
service_account_info = json.loads(self._vertex_credentials)
|
||||
credentials = service_account.Credentials.from_service_account_info(
|
||||
service_account_info,
|
||||
scopes=["https://www.googleapis.com/auth/cloud-platform"],
|
||||
)
|
||||
|
||||
client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self._vertex_project,
|
||||
location=self._vertex_location,
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
parts: list[genai_types.Part] = [
|
||||
genai_types.Part.from_bytes(data=image.data, mime_type=image.mime_type)
|
||||
for image in reference_images
|
||||
]
|
||||
parts.append(genai_types.Part.from_text(text=prompt))
|
||||
|
||||
config = genai_types.GenerateContentConfig(
|
||||
response_modalities=["TEXT", "IMAGE"],
|
||||
candidate_count=max(1, n),
|
||||
image_config=genai_types.ImageConfig(
|
||||
aspect_ratio=_map_size_to_aspect_ratio(size)
|
||||
),
|
||||
)
|
||||
model_name = model.replace("vertex_ai/", "")
|
||||
response = client.models.generate_content(
|
||||
model=model_name,
|
||||
contents=genai_types.Content(
|
||||
role="user",
|
||||
parts=parts,
|
||||
),
|
||||
config=config,
|
||||
)
|
||||
|
||||
generated_data: list[ImageObject] = []
|
||||
for candidate in response.candidates or []:
|
||||
candidate_content = candidate.content
|
||||
if not candidate_content:
|
||||
continue
|
||||
|
||||
for part in candidate_content.parts or []:
|
||||
inline_data = part.inline_data
|
||||
if not inline_data or inline_data.data is None:
|
||||
continue
|
||||
|
||||
if isinstance(inline_data.data, bytes):
|
||||
b64_json = base64.b64encode(inline_data.data).decode("utf-8")
|
||||
elif isinstance(inline_data.data, str):
|
||||
b64_json = inline_data.data
|
||||
else:
|
||||
continue
|
||||
|
||||
generated_data.append(
|
||||
ImageObject(
|
||||
b64_json=b64_json,
|
||||
revised_prompt=prompt,
|
||||
)
|
||||
)
|
||||
|
||||
if not generated_data:
|
||||
raise RuntimeError("No image data returned from Vertex AI.")
|
||||
|
||||
return ImageResponse(
|
||||
created=int(datetime.now().timestamp()),
|
||||
data=generated_data,
|
||||
)
|
||||
|
||||
|
||||
def _map_size_to_aspect_ratio(size: str) -> str:
|
||||
return {
|
||||
"1024x1024": "1:1",
|
||||
"1792x1024": "16:9",
|
||||
"1024x1792": "9:16",
|
||||
"1536x1024": "3:2",
|
||||
"1024x1536": "2:3",
|
||||
}.get(size, "1:1")
|
||||
|
||||
|
||||
def _parse_to_vertex_credentials(
|
||||
credentials: ImageGenerationProviderCredentials,
|
||||
|
||||
@@ -64,6 +64,21 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20240620-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -144,6 +159,11 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"apac.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"apac.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1300,6 +1320,11 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-east-1/anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-east-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1340,6 +1365,16 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"bedrock/us-gov-west-1/claude-sonnet-4-5-20250929-v1:0": {
|
||||
"display_name": "Claude Sonnet 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1470,6 +1505,26 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-7-sonnet-20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"claude-3-7-sonnet-latest": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"claude-3-7-sonnet@20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"claude-3-haiku-20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"claude-4-opus-20250514": {
|
||||
"display_name": "Claude Opus 4",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -1650,6 +1705,16 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022-v2:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219-v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307-v1:0"
|
||||
},
|
||||
"eu.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3161,6 +3226,15 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-haiku-20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"openrouter/anthropic/claude-3-sonnet": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic"
|
||||
@@ -3175,6 +3249,16 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3.7-sonnet": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-3.7-sonnet:beta": {
|
||||
"display_name": "Claude Sonnet 3.7:beta",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "latest"
|
||||
},
|
||||
"openrouter/anthropic/claude-haiku-4.5": {
|
||||
"display_name": "Claude Haiku 4.5",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3666,6 +3750,16 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20241022"
|
||||
},
|
||||
"us.anthropic.claude-3-7-sonnet-20250219-v1:0": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"us.anthropic.claude-3-haiku-20240307-v1:0": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"us.anthropic.claude-3-sonnet-20240229-v1:0": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic",
|
||||
@@ -3785,6 +3879,20 @@
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240620"
|
||||
},
|
||||
"vertex_ai/claude-3-7-sonnet@20250219": {
|
||||
"display_name": "Claude Sonnet 3.7",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20250219"
|
||||
},
|
||||
"vertex_ai/claude-3-haiku": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic"
|
||||
},
|
||||
"vertex_ai/claude-3-haiku@20240307": {
|
||||
"display_name": "Claude Haiku 3",
|
||||
"model_vendor": "anthropic",
|
||||
"model_version": "20240307"
|
||||
},
|
||||
"vertex_ai/claude-3-sonnet": {
|
||||
"display_name": "Claude Sonnet 3",
|
||||
"model_vendor": "anthropic"
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import json
|
||||
import pathlib
|
||||
import threading
|
||||
import time
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -25,11 +23,6 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
|
||||
_recommendations_cache_lock = threading.Lock()
|
||||
_cached_recommendations: LLMRecommendations | None = None
|
||||
_cached_recommendations_time: float = 0.0
|
||||
|
||||
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level.
|
||||
@@ -48,40 +41,19 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
}
|
||||
|
||||
|
||||
def _load_bundled_recommendations() -> LLMRecommendations:
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations from the GitHub config."""
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
if recommendations_from_github:
|
||||
return recommendations_from_github
|
||||
|
||||
# Fall back to json bundled with code
|
||||
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
|
||||
with open(json_path, "r") as f:
|
||||
json_config = json.load(f)
|
||||
return LLMRecommendations.model_validate(json_config)
|
||||
|
||||
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations, with an in-memory cache to avoid
|
||||
hitting GitHub on every API request."""
|
||||
global _cached_recommendations, _cached_recommendations_time
|
||||
|
||||
now = time.monotonic()
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
with _recommendations_cache_lock:
|
||||
# Double-check after acquiring lock
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (time.monotonic() - _cached_recommendations_time)
|
||||
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
result = recommendations_from_github or _load_bundled_recommendations()
|
||||
|
||||
_cached_recommendations = result
|
||||
_cached_recommendations_time = time.monotonic()
|
||||
return result
|
||||
recommendations_from_json = LLMRecommendations.model_validate(json_config)
|
||||
return recommendations_from_json
|
||||
|
||||
|
||||
def is_obsolete_model(model_name: str, provider: str) -> bool:
|
||||
@@ -243,23 +215,6 @@ def model_configurations_for_provider(
|
||||
) -> list[ModelConfigurationView]:
|
||||
recommended_visible_models = llm_recommendations.get_visible_models(provider_name)
|
||||
recommended_visible_models_names = [m.name for m in recommended_visible_models]
|
||||
|
||||
# Preserve provider-defined ordering while de-duplicating.
|
||||
model_names: list[str] = []
|
||||
seen_model_names: set[str] = set()
|
||||
for model_name in (
|
||||
fetch_models_for_provider(provider_name) + recommended_visible_models_names
|
||||
):
|
||||
if model_name in seen_model_names:
|
||||
continue
|
||||
seen_model_names.add(model_name)
|
||||
model_names.append(model_name)
|
||||
|
||||
# Vertex model list can be large and mixed-vendor; alphabetical ordering
|
||||
# makes model discovery easier in admin selection UIs.
|
||||
if provider_name == VERTEXAI_PROVIDER_NAME:
|
||||
model_names = sorted(model_names, key=str.lower)
|
||||
|
||||
return [
|
||||
ModelConfigurationView(
|
||||
name=model_name,
|
||||
@@ -267,7 +222,8 @@ def model_configurations_for_provider(
|
||||
max_input_tokens=get_max_input_tokens(model_name, provider_name),
|
||||
supports_image_input=model_supports_image_input(model_name, provider_name),
|
||||
)
|
||||
for model_name in model_names
|
||||
for model_name in set(fetch_models_for_provider(provider_name))
|
||||
| set(recommended_visible_models_names)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -52,7 +52,6 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
|
||||
from onyx.db.engine.async_sql_engine import get_sqlalchemy_async_engine
|
||||
from onyx.db.engine.connection_warmup import warm_up_connections
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
@@ -64,7 +63,7 @@ from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as standard_oauth_router
|
||||
from onyx.server.features.build.api.api import public_build_router
|
||||
from onyx.server.features.build.api.api import nextjs_assets_router
|
||||
from onyx.server.features.build.api.api import router as build_router
|
||||
from onyx.server.features.default_assistant.api import (
|
||||
router as default_assistant_router,
|
||||
@@ -115,16 +114,13 @@ from onyx.server.manage.users import router as user_router
|
||||
from onyx.server.manage.web_search.api import (
|
||||
admin_router as web_search_admin_router,
|
||||
)
|
||||
from onyx.server.metrics.postgres_connection_pool import (
|
||||
setup_postgres_connection_pool_metrics,
|
||||
)
|
||||
from onyx.server.metrics.prometheus_setup import setup_prometheus_metrics
|
||||
from onyx.server.middleware.latency_logging import add_latency_logging_middleware
|
||||
from onyx.server.middleware.rate_limiting import close_auth_limiter
|
||||
from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_auth_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.pat.api import router as pat_router
|
||||
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
|
||||
from onyx.server.query_and_chat.chat_backend import router as chat_router
|
||||
from onyx.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@@ -142,7 +138,6 @@ from onyx.setup import setup_onyx
|
||||
from onyx.tracing.setup import setup_tracing
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import setup_uvicorn_logger
|
||||
from onyx.utils.middleware import add_endpoint_context_middleware
|
||||
from onyx.utils.middleware import add_onyx_request_id_middleware
|
||||
from onyx.utils.telemetry import get_or_generate_uuid
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
@@ -271,17 +266,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
|
||||
max_overflow=POSTGRES_API_SERVER_READ_ONLY_POOL_OVERFLOW,
|
||||
)
|
||||
|
||||
# Register pool metrics now that engines are created.
|
||||
# HTTP instrumentation is set up earlier in get_application() since it
|
||||
# adds middleware (which Starlette forbids after the app has started).
|
||||
setup_postgres_connection_pool_metrics(
|
||||
engines={
|
||||
"sync": SqlEngine.get_engine(),
|
||||
"async": get_sqlalchemy_async_engine(),
|
||||
"readonly": SqlEngine.get_readonly_engine(),
|
||||
},
|
||||
)
|
||||
|
||||
verify_auth = fetch_versioned_implementation(
|
||||
"onyx.auth.users", "verify_auth_setting"
|
||||
)
|
||||
@@ -394,8 +378,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, admin_input_prompt_router)
|
||||
include_router_with_global_prefix_prepended(application, cc_pair_router)
|
||||
include_router_with_global_prefix_prepended(application, projects_router)
|
||||
include_router_with_global_prefix_prepended(application, public_build_router)
|
||||
include_router_with_global_prefix_prepended(application, build_router)
|
||||
include_router_with_global_prefix_prepended(application, nextjs_assets_router)
|
||||
include_router_with_global_prefix_prepended(application, document_set_router)
|
||||
include_router_with_global_prefix_prepended(application, hierarchy_router)
|
||||
include_router_with_global_prefix_prepended(application, search_settings_router)
|
||||
@@ -576,18 +560,12 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
|
||||
add_onyx_request_id_middleware(application, "API", logger)
|
||||
|
||||
# Set endpoint context for per-endpoint DB pool attribution metrics.
|
||||
# Must be registered after all routes are added.
|
||||
add_endpoint_context_middleware(application)
|
||||
|
||||
# HTTP request metrics (latency histograms, in-progress gauge, slow request
|
||||
# counter). Must be called here — before the app starts — because the
|
||||
# instrumentator adds middleware via app.add_middleware().
|
||||
setup_prometheus_metrics(application)
|
||||
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_router_auth(application)
|
||||
|
||||
# Initialize and instrument the app with production Prometheus config
|
||||
setup_prometheus_metrics(application)
|
||||
|
||||
use_route_function_names_as_operation_ids(application)
|
||||
|
||||
return application
|
||||
|
||||
@@ -69,12 +69,6 @@ Very briefly describe the image(s) generated. Do not include any links or attach
""".strip()


FILE_REMINDER = """
Your code execution generated file(s) with download links.
If you reference or share these files, use the exact markdown format [filename](file_link) with the file_link from the execution result.
""".strip()


# Specifically for OpenAI models, this prefix needs to be in place for the model to output markdown and correct styling
CODE_BLOCK_MARKDOWN = "Formatting re-enabled. "
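How a prefix like this gets applied is not shown in this hunk; a minimal sketch, assuming a hypothetical prompt-assembly helper:

def build_system_prompt(base_prompt: str, is_openai_model: bool) -> str:
    # Per the comment above, OpenAI models need the prefix to emit proper markdown.
    return (CODE_BLOCK_MARKDOWN + base_prompt) if is_openai_model else base_prompt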
@@ -1,6 +1,6 @@
|
||||
# ruff: noqa: E501, W605 start
# If there are any tools, this section is included; the sections below are for the available tools
TOOL_SECTION_HEADER = "\n# Tools\n\n"
TOOL_SECTION_HEADER = "\n\n# Tools\n"
|
||||
|
||||
|
||||
# This section is included if there are search type tools, currently internal_search and web_search
|
||||
@@ -16,10 +16,11 @@ When searching for information, if the initial results cannot fully answer the u
|
||||
Do not repeat the same or very similar queries if it already has been run in the chat history.
|
||||
|
||||
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
|
||||
INTERNAL_SEARCH_GUIDANCE = """
|
||||
|
||||
## internal_search
|
||||
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
|
||||
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
|
||||
@@ -27,31 +28,34 @@ Use the `internal_search` tool to search connected applications for information.
|
||||
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
|
||||
- Ambiguity: questions about something that is not widely known or understood.
|
||||
Never provide more than 3 queries at once to `internal_search`.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
|
||||
WEB_SEARCH_GUIDANCE = """
|
||||
|
||||
## web_search
|
||||
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
|
||||
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
|
||||
- Accuracy: if the cost of outdated/inaccurate information is high.
|
||||
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
|
||||
Do not use the "site:" operator in your web search queries.
|
||||
""".lstrip()
|
||||
""".rstrip()
|
||||
|
||||
|
||||
OPEN_URLS_GUIDANCE = """
|
||||
|
||||
## open_url
|
||||
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
|
||||
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
|
||||
Do not open URLs that are image files like .png, .jpg, etc.
|
||||
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
PYTHON_TOOL_GUIDANCE = """
|
||||
|
||||
## python
|
||||
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
|
||||
Any files uploaded to the chat will automatically be available in the execution environment's current directory. \
|
||||
@@ -60,21 +64,21 @@ Use this to give the user a way to download the file OR to display generated ima
|
||||
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
|
||||
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
|
||||
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
## add_memory
|
||||
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
|
||||
Only add memories that are specific, likely to remain true, and likely to be useful later. \
|
||||
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
|
||||
""".lstrip()
|
||||
"""
|
||||
|
||||
TOOL_CALL_FAILURE_PROMPT = """
|
||||
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
|
||||
|
||||
@@ -1,36 +1,40 @@
# ruff: noqa: E501, W605 start
USER_INFORMATION_HEADER = "\n# User Information\n\n"
USER_INFORMATION_HEADER = "\n\n# User Information\n"

BASIC_INFORMATION_PROMPT = """

## Basic Information
User name: {user_name}
User email: {user_email}{user_role}
""".lstrip()
"""

# This line only shows up if the user has configured their role.
USER_ROLE_PROMPT = """
User role: {user_role}
""".lstrip()
"""

# Team information should be a paragraph style description of the user's team.
TEAM_INFORMATION_PROMPT = """

## Team Information
{team_information}
""".lstrip()
"""

# User preferences should be a paragraph style description of the user's preferences.
USER_PREFERENCES_PROMPT = """

## User Preferences
{user_preferences}
""".lstrip()
"""

# User memories should look something like:
# - Memory 1
# - Memory 2
# - Memory 3
USER_MEMORIES_PROMPT = """

## User Memories
{user_memories}
""".lstrip()
"""

# ruff: noqa: E501, W605 end
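A sketch of how these templates compose; the real prompt builder lives elsewhere in the repo, so the helper below is illustrative only.

def render_user_information(user_name: str, user_email: str, user_role: str | None) -> str:
    # USER_ROLE_PROMPT only appears when the user has configured a role.
    role_section = USER_ROLE_PROMPT.format(user_role=user_role) if user_role else ""
    return USER_INFORMATION_HEADER + BASIC_INFORMATION_PROMPT.format(
        user_name=user_name,
        user_email=user_email,
        user_role=role_section,
    )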
@@ -109,7 +109,6 @@ class TenantRedis(redis.Redis):
|
||||
"unlock",
|
||||
"get",
|
||||
"set",
|
||||
"setex",
|
||||
"delete",
|
||||
"exists",
|
||||
"incrby",
|
||||
|
||||
@@ -59,9 +59,6 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
# anonymous user on cloud
|
||||
("/tenants/anonymous-user", {"POST"}),
|
||||
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
|
||||
# craft webapp proxy — access enforced per-session via sharing_scope in handler
|
||||
("/build/sessions/{session_id}/webapp", {"GET"}),
|
||||
("/build/sessions/{session_id}/webapp/{path:path}", {"GET"}),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -103,7 +103,6 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingMode
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.federated import fetch_all_federated_connectors_parallel
|
||||
from onyx.db.index_attempt import get_index_attempts_for_cc_pair
|
||||
from onyx.db.index_attempt import get_latest_index_attempts_by_status
|
||||
@@ -988,7 +987,6 @@ def get_connector_status(
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
eager_load_user=True,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
@@ -1002,23 +1000,11 @@ def get_connector_status(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
|
||||
connector_to_credential_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
|
||||
cc_pair.credential_id
|
||||
)
|
||||
|
||||
return [
|
||||
ConnectorStatus(
|
||||
cc_pair_id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector,
|
||||
credential_ids=connector_to_credential_ids.get(
|
||||
cc_pair.connector_id, []
|
||||
),
|
||||
),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
|
||||
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
|
||||
access_type=cc_pair.access_type,
|
||||
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
|
||||
@@ -1073,27 +1059,15 @@ def get_connector_indexing_status(
|
||||
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
|
||||
# Get editable connector/credential pairs
|
||||
(
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, True, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, True, None, True, True, True, True, request.source),
|
||||
),
|
||||
# Get federated connectors
|
||||
(fetch_all_federated_connectors_parallel, ()),
|
||||
# Get most recent index attempts
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, False
|
||||
),
|
||||
(),
|
||||
),
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
|
||||
# Get most recent finished index attempts
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, True
|
||||
),
|
||||
(),
|
||||
),
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
|
||||
]
|
||||
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
@@ -1110,10 +1084,8 @@ def get_connector_indexing_status(
|
||||
parallel_functions.append(
|
||||
# Get non-editable connector/credential pairs
|
||||
(
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, False, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, False, None, True, True, True, True, request.source),
|
||||
),
|
||||
)
|
||||
|
||||
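The hunk above swaps inline lambdas for plain (function, args) tuples. The helper that consumes those tuples is assumed from context rather than shown here; a minimal illustrative equivalent, not the repo's actual implementation, would be:

from concurrent.futures import ThreadPoolExecutor

def run_in_parallel(functions_with_args):
    # Each entry is (callable, args); results come back in submission order.
    with ThreadPoolExecutor(max_workers=max(len(functions_with_args), 1)) as pool:
        futures = [pool.submit(fn, *args) for fn, args in functions_with_args]
        return [future.result() for future in futures]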
@@ -1939,7 +1911,6 @@ Tenant ID: {tenant_id}
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
status: ConnectorCredentialPairStatus
|
||||
|
||||
|
||||
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
|
||||
@@ -1953,17 +1924,13 @@ def get_basic_connector_indexing_status(
|
||||
get_editable=False,
|
||||
user=user,
|
||||
)
|
||||
|
||||
# NOTE: This endpoint excludes Craft connectors
|
||||
return [
|
||||
BasicCCPairInfo(
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
source=cc_pair.connector.source,
|
||||
status=cc_pair.status,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.connector.source != DocumentSource.INGESTION_API
|
||||
and cc_pair.processing_mode == ProcessingMode.REGULAR
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -365,8 +365,7 @@ class CCPairFullInfo(BaseModel):
|
||||
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
|
||||
num_docs_indexed=num_docs_indexed,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_model.connector,
|
||||
credential_ids=[cc_pair_model.credential_id],
|
||||
cc_pair_model.connector
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_model.credential
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
@@ -8,19 +7,16 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi.responses import RedirectResponse
|
||||
from fastapi.responses import StreamingResponse
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs_for_user
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import IndexingStatus
|
||||
from onyx.db.enums import ProcessingMode
|
||||
from onyx.db.enums import SharingScope
|
||||
from onyx.db.index_attempt import get_latest_index_attempt_for_cc_pair_id
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -221,15 +217,12 @@ def get_build_connectors(
|
||||
return BuildConnectorListResponse(connectors=connectors)
|
||||
|
||||
|
||||
# Headers to skip when proxying.
|
||||
# Hop-by-hop headers must not be forwarded, and set-cookie is stripped to
|
||||
# prevent LLM-generated apps from setting cookies on the parent Onyx domain.
|
||||
# Headers to skip when proxying (hop-by-hop headers)
|
||||
EXCLUDED_HEADERS = {
|
||||
"content-encoding",
|
||||
"content-length",
|
||||
"transfer-encoding",
|
||||
"connection",
|
||||
"set-cookie",
|
||||
}
|
||||
|
||||
|
||||
@@ -287,7 +280,7 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
|
||||
db_session: Database session
|
||||
|
||||
Returns:
|
||||
Internal URL to proxy requests to
|
||||
The internal URL to proxy requests to
|
||||
|
||||
Raises:
|
||||
HTTPException: If session not found, port not allocated, or sandbox not found
|
||||
@@ -301,10 +294,12 @@ def _get_sandbox_url(session_id: UUID, db_session: Session) -> str:
|
||||
if session.user_id is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
# Get the user's sandbox to get the sandbox_id
|
||||
sandbox = get_sandbox_by_user_id(db_session, session.user_id)
|
||||
if sandbox is None:
|
||||
raise HTTPException(status_code=404, detail="Sandbox not found")
|
||||
|
||||
# Use sandbox manager to get the correct internal URL
|
||||
sandbox_manager = get_sandbox_manager()
|
||||
return sandbox_manager.get_webapp_url(sandbox.id, session.nextjs_port)
|
||||
|
||||
@@ -370,73 +365,71 @@ def _proxy_request(
|
||||
raise HTTPException(status_code=502, detail="Bad gateway")
|
||||
|
||||
|
||||
def _check_webapp_access(
|
||||
session_id: UUID, user: User | None, db_session: Session
|
||||
) -> BuildSession:
|
||||
"""Check if user can access a session's webapp.
|
||||
|
||||
- public_global: accessible by anyone (no auth required)
|
||||
- public_org: accessible by any authenticated user
|
||||
- private: only accessible by the session owner
|
||||
"""
|
||||
session = db_session.get(BuildSession, session_id)
|
||||
if not session:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
if session.sharing_scope == SharingScope.PUBLIC_GLOBAL:
|
||||
return session
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="Authentication required")
|
||||
if session.sharing_scope == SharingScope.PRIVATE and session.user_id != user.id:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return session
|
||||
|
||||
|
||||
_OFFLINE_HTML_PATH = Path(__file__).parent / "templates" / "webapp_offline.html"
|
||||
|
||||
|
||||
def _offline_html_response() -> Response:
|
||||
"""Return a branded Craft HTML page when the sandbox is not reachable.
|
||||
|
||||
Design mirrors the default Craft web template (outputs/web/app/page.tsx):
|
||||
terminal window aesthetic with Minecraft-themed typing animation.
|
||||
"""
|
||||
html = _OFFLINE_HTML_PATH.read_text()
|
||||
return Response(content=html, status_code=503, media_type="text/html")
|
||||
|
||||
|
||||
# Public router for webapp proxy — no authentication required
|
||||
# (access controlled per-session via sharing_scope)
|
||||
public_build_router = APIRouter(prefix="/build")
|
||||
|
||||
|
||||
@public_build_router.get("/sessions/{session_id}/webapp", response_model=None)
|
||||
@public_build_router.get(
|
||||
"/sessions/{session_id}/webapp/{path:path}", response_model=None
|
||||
)
|
||||
def get_webapp(
|
||||
@router.get("/sessions/{session_id}/webapp", response_model=None)
|
||||
def get_webapp_root(
|
||||
session_id: UUID,
|
||||
request: Request,
|
||||
path: str = "",
|
||||
user: User | None = Depends(optional_user),
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy the webapp for a specific session (root and subpaths).
|
||||
"""Proxy the root path of the webapp for a specific session."""
|
||||
return _proxy_request("", request, session_id, db_session)
|
||||
|
||||
Accessible without authentication when sharing_scope is public_global.
|
||||
Returns a friendly offline page when the sandbox is not running.
|
||||
|
||||
@router.get("/sessions/{session_id}/webapp/{path:path}", response_model=None)
|
||||
def get_webapp_path(
|
||||
session_id: UUID,
|
||||
path: str,
|
||||
request: Request,
|
||||
_: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse | Response:
|
||||
"""Proxy any subpath of the webapp (static assets, etc.) for a specific session."""
|
||||
return _proxy_request(path, request, session_id, db_session)
|
||||
|
||||
|
||||
# Separate router for Next.js static assets at /_next/*
|
||||
# This is needed because Next.js apps may reference assets with root-relative paths
|
||||
# that don't get rewritten. The session_id is extracted from the Referer header.
|
||||
nextjs_assets_router = APIRouter()
|
||||
|
||||
|
||||
def _extract_session_from_referer(request: Request) -> UUID | None:
|
||||
"""Extract session_id from the Referer header.
|
||||
|
||||
Expects Referer to contain /api/build/sessions/{session_id}/webapp
|
||||
"""
|
||||
try:
|
||||
_check_webapp_access(session_id, user, db_session)
|
||||
except HTTPException as e:
|
||||
if e.status_code == 401:
|
||||
return RedirectResponse(url="/auth/login", status_code=302)
|
||||
raise
|
||||
try:
|
||||
return _proxy_request(path, request, session_id, db_session)
|
||||
except HTTPException as e:
|
||||
if e.status_code in (502, 503, 504):
|
||||
return _offline_html_response()
|
||||
raise
|
||||
import re
|
||||
|
||||
referer = request.headers.get("referer", "")
|
||||
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
|
||||
if match:
|
||||
try:
|
||||
return UUID(match.group(1))
|
||||
except ValueError:
|
||||
return None
|
||||
return None
|
||||
|
||||
|
||||
@nextjs_assets_router.get("/_next/{path:path}", response_model=None)
def get_nextjs_assets(
path: str,
request: Request,
_: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse | Response:
"""Proxy Next.js static assets requested at root /_next/ path.

The session_id is extracted from the Referer header since these requests
come from within the iframe context.
"""
session_id = _extract_session_from_referer(request)
if not session_id:
raise HTTPException(
status_code=400,
detail="Could not determine session from request context",
)
return _proxy_request(f"_next/{path}", request, session_id, db_session)


# =============================================================================
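A worked example of the Referer parsing used above (the URL and UUID are made up for illustration):

import re
from uuid import UUID

referer = "https://onyx.example.com/api/build/sessions/0f8fad5b-d9cb-469f-a165-70867728950e/webapp/dashboard"
match = re.search(r"/api/build/sessions/([a-f0-9-]+)/webapp", referer)
assert match is not None
session_id = UUID(match.group(1))  # -> UUID('0f8fad5b-d9cb-469f-a165-70867728950e')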
@@ -10,7 +10,6 @@ from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import ArtifactType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.enums import SharingScope
|
||||
from onyx.server.features.build.sandbox.models import (
|
||||
FilesystemEntry as FileSystemEntry,
|
||||
)
|
||||
@@ -108,7 +107,6 @@ class SessionResponse(BaseModel):
|
||||
nextjs_port: int | None
|
||||
sandbox: SandboxResponse | None
|
||||
artifacts: list[ArtifactResponse]
|
||||
sharing_scope: SharingScope
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -131,7 +129,6 @@ class SessionResponse(BaseModel):
|
||||
nextjs_port=session.nextjs_port,
|
||||
sandbox=(SandboxResponse.from_model(sandbox) if sandbox else None),
|
||||
artifacts=[ArtifactResponse.from_model(a) for a in session.artifacts],
|
||||
sharing_scope=session.sharing_scope,
|
||||
)
|
||||
|
||||
|
||||
@@ -162,19 +159,6 @@ class SessionListResponse(BaseModel):
|
||||
sessions: list[SessionResponse]
|
||||
|
||||
|
||||
class SetSessionSharingRequest(BaseModel):
|
||||
"""Request to set the sharing scope of a session."""
|
||||
|
||||
sharing_scope: SharingScope
|
||||
|
||||
|
||||
class SetSessionSharingResponse(BaseModel):
|
||||
"""Response after setting session sharing scope."""
|
||||
|
||||
session_id: str
|
||||
sharing_scope: SharingScope
|
||||
|
||||
|
||||
# ===== Message Models =====
|
||||
class MessageRequest(BaseModel):
|
||||
"""Request to send a message to the CLI agent."""
|
||||
@@ -260,7 +244,6 @@ class WebappInfo(BaseModel):
|
||||
webapp_url: str | None # URL to access the webapp (e.g., http://localhost:3015)
|
||||
status: str # Sandbox status (running, terminated, etc.)
|
||||
ready: bool # Whether the NextJS dev server is actually responding
|
||||
sharing_scope: SharingScope
|
||||
|
||||
|
||||
# ===== File Upload Models =====
|
||||
|
||||
@@ -30,8 +30,6 @@ from onyx.server.features.build.api.models import SessionListResponse
|
||||
from onyx.server.features.build.api.models import SessionNameGenerateResponse
|
||||
from onyx.server.features.build.api.models import SessionResponse
|
||||
from onyx.server.features.build.api.models import SessionUpdateRequest
|
||||
from onyx.server.features.build.api.models import SetSessionSharingRequest
|
||||
from onyx.server.features.build.api.models import SetSessionSharingResponse
|
||||
from onyx.server.features.build.api.models import SuggestionBubble
|
||||
from onyx.server.features.build.api.models import SuggestionTheme
|
||||
from onyx.server.features.build.api.models import UploadResponse
|
||||
@@ -40,7 +38,6 @@ from onyx.server.features.build.configs import SANDBOX_BACKEND
|
||||
from onyx.server.features.build.configs import SandboxBackend
|
||||
from onyx.server.features.build.db.build_session import allocate_nextjs_port
|
||||
from onyx.server.features.build.db.build_session import get_build_session
|
||||
from onyx.server.features.build.db.build_session import set_build_session_sharing_scope
|
||||
from onyx.server.features.build.db.sandbox import get_latest_snapshot_for_session
|
||||
from onyx.server.features.build.db.sandbox import get_sandbox_by_user_id
|
||||
from onyx.server.features.build.db.sandbox import update_sandbox_heartbeat
|
||||
@@ -297,25 +294,6 @@ def update_session_name(
|
||||
return SessionResponse.from_model(session, sandbox)
|
||||
|
||||
|
||||
@router.patch("/{session_id}/public")
|
||||
def set_session_public(
|
||||
session_id: UUID,
|
||||
request: SetSessionSharingRequest,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> SetSessionSharingResponse:
|
||||
"""Set the sharing scope of a build session's webapp."""
|
||||
updated = set_build_session_sharing_scope(
|
||||
session_id, user.id, request.sharing_scope, db_session
|
||||
)
|
||||
if not updated:
|
||||
raise HTTPException(status_code=404, detail="Session not found")
|
||||
return SetSessionSharingResponse(
|
||||
session_id=str(session_id),
|
||||
sharing_scope=updated.sharing_scope,
|
||||
)
|
||||
|
||||
|
||||
@router.delete("/{session_id}", response_model=None)
|
||||
def delete_session(
|
||||
session_id: UUID,
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
<!doctype html>
|
||||
<html lang="en">
|
||||
<head>
|
||||
<meta charset="UTF-8" />
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0" />
|
||||
<meta http-equiv="refresh" content="15" />
|
||||
<title>Craft — Starting up</title>
|
||||
<style>
|
||||
*,
|
||||
*::before,
|
||||
*::after {
|
||||
box-sizing: border-box;
|
||||
margin: 0;
|
||||
padding: 0;
|
||||
}
|
||||
|
||||
body {
|
||||
font-family: ui-monospace, SFMono-Regular, "SF Mono", Menlo, Consolas,
|
||||
monospace;
|
||||
background: linear-gradient(to bottom right, #030712, #111827, #030712);
|
||||
min-height: 100vh;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 1.5rem;
|
||||
padding: 2rem;
|
||||
}
|
||||
|
||||
.terminal {
|
||||
width: 100%;
|
||||
max-width: 580px;
|
||||
border: 2px solid #374151;
|
||||
border-radius: 2px;
|
||||
}
|
||||
|
||||
.titlebar {
|
||||
background: #1f2937;
|
||||
padding: 0.5rem 0.75rem;
|
||||
display: flex;
|
||||
align-items: center;
|
||||
gap: 0.5rem;
|
||||
border-bottom: 1px solid #374151;
|
||||
}
|
||||
|
||||
.btn {
|
||||
width: 12px;
|
||||
height: 12px;
|
||||
border-radius: 2px;
|
||||
flex-shrink: 0;
|
||||
}
|
||||
.btn-red {
|
||||
background: #ef4444;
|
||||
}
|
||||
.btn-yellow {
|
||||
background: #eab308;
|
||||
}
|
||||
.btn-green {
|
||||
background: #22c55e;
|
||||
}
|
||||
|
||||
.title-label {
|
||||
flex: 1;
|
||||
text-align: center;
|
||||
font-size: 0.75rem;
|
||||
color: #6b7280;
|
||||
margin-right: 36px;
|
||||
}
|
||||
|
||||
.body {
|
||||
background: #111827;
|
||||
padding: 1.5rem;
|
||||
min-height: 200px;
|
||||
font-size: 0.875rem;
|
||||
color: #d1d5db;
|
||||
display: flex;
|
||||
align-items: flex-start;
|
||||
gap: 0.375rem;
|
||||
}
|
||||
|
||||
.prompt {
|
||||
color: #10b981;
|
||||
user-select: none;
|
||||
}
|
||||
|
||||
.tagline {
|
||||
font-size: 0.8125rem;
|
||||
color: #4b5563;
|
||||
text-align: center;
|
||||
}
|
||||
</style>
|
||||
</head>
|
||||
<body>
|
||||
<div class="terminal">
|
||||
<div class="titlebar">
|
||||
<div class="btn btn-red"></div>
|
||||
<div class="btn btn-yellow"></div>
|
||||
<div class="btn btn-green"></div>
|
||||
<span class="title-label">crafting_table</span>
|
||||
</div>
|
||||
<div class="body">
|
||||
<span class="prompt">/></span>
|
||||
<span>Sandbox is asleep...</span>
|
||||
</div>
|
||||
</div>
|
||||
<p class="tagline">
|
||||
Ask the owner to open their Craft session to wake it up.
|
||||
</p>
|
||||
</body>
|
||||
</html>
|
||||
@@ -13,7 +13,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.db.enums import SharingScope
|
||||
from onyx.db.models import Artifact
|
||||
from onyx.db.models import BuildMessage
|
||||
from onyx.db.models import BuildSession
|
||||
@@ -160,26 +159,6 @@ def update_session_status(
|
||||
logger.info(f"Updated build session {session_id} status to {status}")
|
||||
|
||||
|
||||
def set_build_session_sharing_scope(
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
sharing_scope: SharingScope,
|
||||
db_session: Session,
|
||||
) -> BuildSession | None:
|
||||
"""Set the sharing scope of a build session.
|
||||
|
||||
Only the session owner can change this setting.
|
||||
Returns the updated session, or None if not found/unauthorized.
|
||||
"""
|
||||
session = get_build_session(session_id, user_id, db_session)
|
||||
if not session:
|
||||
return None
|
||||
session.sharing_scope = sharing_scope
|
||||
db_session.commit()
|
||||
logger.info(f"Set build session {session_id} sharing_scope={sharing_scope}")
|
||||
return session
|
||||
|
||||
|
||||
def delete_build_session__no_commit(
|
||||
session_id: UUID,
|
||||
user_id: UUID,
|
||||
|
||||
@@ -474,23 +474,6 @@ class SandboxManager(ABC):
|
||||
"""
|
||||
...
|
||||
|
||||
def ensure_nextjs_running(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
nextjs_port: int,
|
||||
) -> None:
|
||||
"""Ensure the Next.js server is running for a session.
|
||||
|
||||
Default is a no-op — only meaningful for local backends that manage
|
||||
process lifecycles directly (e.g., LocalSandboxManager).
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
nextjs_port: The port the Next.js server should be listening on
|
||||
"""
|
||||
|
||||
|
||||
# Singleton instance cache for the factory
|
||||
_sandbox_manager_instance: SandboxManager | None = None
|
||||
|
||||
@@ -15,8 +15,6 @@ from collections.abc import Generator
|
||||
from pathlib import Path
|
||||
from uuid import UUID
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.db.enums import SandboxStatus
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.server.features.build.configs import DEMO_DATA_PATH
|
||||
@@ -37,7 +35,6 @@ from onyx.server.features.build.sandbox.models import LLMProviderConfig
|
||||
from onyx.server.features.build.sandbox.models import SandboxInfo
|
||||
from onyx.server.features.build.sandbox.models import SnapshotResult
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import ThreadSafeSet
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -92,17 +89,9 @@ class LocalSandboxManager(SandboxManager):
|
||||
self._acp_clients: dict[tuple[UUID, UUID], ACPAgentClient] = {}
|
||||
|
||||
# Track Next.js processes - keyed by (sandbox_id, session_id) tuple
|
||||
# Used for clean shutdown when sessions are deleted.
|
||||
# Mutated from background threads; all access must hold _nextjs_lock.
|
||||
# Used for clean shutdown when sessions are deleted
|
||||
self._nextjs_processes: dict[tuple[UUID, UUID], subprocess.Popen[bytes]] = {}
|
||||
|
||||
# Track sessions currently being (re)started - prevents concurrent restarts.
|
||||
# ThreadSafeSet allows atomic check-and-add without holding _nextjs_lock.
|
||||
self._nextjs_starting: ThreadSafeSet[tuple[UUID, UUID]] = ThreadSafeSet()
|
||||
|
||||
# Lock guarding _nextjs_processes (shared across sessions; hold briefly only)
|
||||
self._nextjs_lock = threading.Lock()
|
||||
|
||||
# Validate templates exist (raises RuntimeError if missing)
|
||||
self._validate_templates()
|
||||
|
||||
@@ -337,18 +326,16 @@ class LocalSandboxManager(SandboxManager):
|
||||
RuntimeError: If termination fails
|
||||
"""
|
||||
# Stop all Next.js processes for this sandbox (keyed by (sandbox_id, session_id))
|
||||
with self._nextjs_lock:
|
||||
processes_to_stop = [
|
||||
(key, process)
|
||||
for key, process in self._nextjs_processes.items()
|
||||
if key[0] == sandbox_id
|
||||
]
|
||||
processes_to_stop = [
|
||||
(key, process)
|
||||
for key, process in self._nextjs_processes.items()
|
||||
if key[0] == sandbox_id
|
||||
]
|
||||
for key, process in processes_to_stop:
|
||||
session_id = key[1]
|
||||
try:
|
||||
self._stop_nextjs_process(process, session_id)
|
||||
with self._nextjs_lock:
|
||||
self._nextjs_processes.pop(key, None)
|
||||
del self._nextjs_processes[key]
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to stop Next.js for sandbox {sandbox_id}, "
|
||||
@@ -529,8 +516,7 @@ class LocalSandboxManager(SandboxManager):
|
||||
web_dir, nextjs_port
|
||||
)
|
||||
# Store process for clean shutdown on session delete
|
||||
with self._nextjs_lock:
|
||||
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
|
||||
self._nextjs_processes[(sandbox_id, session_id)] = nextjs_process
|
||||
logger.info("Next.js server started successfully")
|
||||
|
||||
# Setup venv and AGENTS.md
|
||||
@@ -589,8 +575,7 @@ class LocalSandboxManager(SandboxManager):
|
||||
"""
|
||||
# Stop Next.js dev server - try stored process first, then fallback to port lookup
|
||||
process_key = (sandbox_id, session_id)
|
||||
with self._nextjs_lock:
|
||||
nextjs_process = self._nextjs_processes.pop(process_key, None)
|
||||
nextjs_process = self._nextjs_processes.pop(process_key, None)
|
||||
if nextjs_process is not None:
|
||||
self._stop_nextjs_process(nextjs_process, session_id)
|
||||
elif nextjs_port is not None:
|
||||
@@ -781,85 +766,6 @@ class LocalSandboxManager(SandboxManager):
|
||||
outputs_path = session_path / "outputs"
|
||||
return outputs_path.exists()
|
||||
|
||||
def ensure_nextjs_running(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
session_id: UUID,
|
||||
nextjs_port: int,
|
||||
) -> None:
|
||||
"""Start Next.js server for a session if not already running.
|
||||
|
||||
Called when the server is detected as unreachable (e.g., after API server restart).
|
||||
Returns immediately — the actual startup runs in a background daemon thread.
|
||||
A per-session guard prevents concurrent restarts from racing.
|
||||
|
||||
Lock design: _nextjs_lock is shared across ALL sessions. Holding it during
|
||||
httpx (1s) or start_nextjs_server (several seconds) would block every other
|
||||
session's status checks and restarts. We only hold the lock for fast
|
||||
in-memory ops (dict get, check_and_add). The slow I/O runs in the background
|
||||
thread without holding any lock.
|
||||
|
||||
Args:
|
||||
sandbox_id: The sandbox ID
|
||||
session_id: The session ID
|
||||
nextjs_port: The port number for the Next.js server
|
||||
"""
|
||||
process_key = (sandbox_id, session_id)
|
||||
|
||||
with self._nextjs_lock:
|
||||
existing = self._nextjs_processes.get(process_key)
|
||||
if existing is not None and existing.poll() is None:
|
||||
return
|
||||
|
||||
# Atomic check-and-add: returns True if already in set (another thread is starting)
|
||||
if self._nextjs_starting.check_and_add(process_key):
|
||||
return
|
||||
|
||||
def _start_in_background() -> None:
|
||||
try:
|
||||
# Port check in background to avoid blocking the main thread
|
||||
try:
|
||||
with httpx.Client(timeout=1.0) as client:
|
||||
client.get(f"http://localhost:{nextjs_port}")
|
||||
logger.info(
|
||||
f"Port {nextjs_port} already alive for session {session_id} "
|
||||
"(orphan process) — skipping restart"
|
||||
)
|
||||
return
|
||||
except Exception:
|
||||
pass # Port is dead; proceed with restart
|
||||
|
||||
logger.info(
|
||||
f"Starting Next.js for session {session_id} on port {nextjs_port}"
|
||||
)
|
||||
sandbox_path = self._get_sandbox_path(sandbox_id)
|
||||
web_dir = self._directory_manager.get_web_path(
|
||||
sandbox_path, str(session_id)
|
||||
)
|
||||
if not web_dir.exists():
|
||||
logger.warning(
|
||||
f"Web dir missing for session {session_id}: {web_dir} — "
|
||||
"cannot restart Next.js"
|
||||
)
|
||||
return
|
||||
process = self._process_manager.start_nextjs_server(
|
||||
web_dir, nextjs_port
|
||||
)
|
||||
with self._nextjs_lock:
|
||||
self._nextjs_processes[process_key] = process
|
||||
logger.info(
|
||||
f"Auto-restarted Next.js for session {session_id} "
|
||||
f"on port {nextjs_port}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to auto-restart Next.js for session {session_id}: {e}"
|
||||
)
|
||||
finally:
|
||||
self._nextjs_starting.discard(process_key)
|
||||
|
||||
threading.Thread(target=_start_in_background, daemon=True).start()
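The restart path above hinges on _nextjs_starting.check_and_add being atomic, so two concurrent status checks cannot both spawn a background start for the same session, and on discard() always running in the finally block. That helper is not part of this diff; a minimal sketch of a thread-safe set with those semantics (an assumption, not the actual Onyx implementation) could look like:

import threading
from typing import Hashable


class ThreadSafeStartingSet:
    """Tracks keys that are currently being started; check_and_add is atomic."""

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._items: set[Hashable] = set()

    def check_and_add(self, key: Hashable) -> bool:
        # Returns True if the key was already present (another thread is starting it);
        # otherwise records the key and returns False so the caller proceeds.
        with self._lock:
            if key in self._items:
                return True
            self._items.add(key)
            return False

    def discard(self, key: Hashable) -> None:
        with self._lock:
            self._items.discard(key)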
|
||||
|
||||
def restore_snapshot(
|
||||
self,
|
||||
sandbox_id: UUID,
|
||||
|
||||
@@ -0,0 +1,10 @@
|
||||
"""Celery tasks for sandbox management."""
|
||||
|
||||
from onyx.server.features.build.sandbox.tasks.tasks import (
|
||||
cleanup_idle_sandboxes_task,
|
||||
) # noqa: F401
|
||||
from onyx.server.features.build.sandbox.tasks.tasks import (
|
||||
sync_sandbox_files,
|
||||
) # noqa: F401
|
||||
|
||||
__all__ = ["cleanup_idle_sandboxes_task", "sync_sandbox_files"]
|
||||
|
||||
@@ -1765,7 +1765,6 @@ class SessionManager:
|
||||
"webapp_url": None,
|
||||
"status": "no_sandbox",
|
||||
"ready": False,
|
||||
"sharing_scope": session.sharing_scope,
|
||||
}
|
||||
|
||||
# Return the proxy URL - the proxy handles routing to the correct sandbox
|
||||
@@ -1778,21 +1777,11 @@ class SessionManager:
|
||||
# Quick health check: can the API server reach the NextJS dev server?
|
||||
ready = self._check_nextjs_ready(sandbox.id, session.nextjs_port)
|
||||
|
||||
# If not ready, ask the sandbox manager to ensure Next.js is running.
|
||||
# For the local backend this triggers a background restart so that the
|
||||
# frontend poll loop eventually sees ready=True without the user having
|
||||
# to manually recreate the session.
|
||||
if not ready:
|
||||
self._sandbox_manager.ensure_nextjs_running(
|
||||
sandbox.id, session_id, session.nextjs_port
|
||||
)
|
||||
|
||||
return {
|
||||
"has_webapp": session.nextjs_port is not None,
|
||||
"webapp_url": webapp_url,
|
||||
"status": sandbox.status.value,
|
||||
"ready": ready,
|
||||
"sharing_scope": session.sharing_scope,
|
||||
}
|
||||
|
||||
def _check_nextjs_ready(self, sandbox_id: UUID, port: int) -> bool:
|
||||
|
||||
@@ -111,8 +111,7 @@ class DocumentSet(BaseModel):
|
||||
id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector,
|
||||
credential_ids=[cc_pair.credential_id],
|
||||
cc_pair.connector
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair.credential
|
||||
|
||||
@@ -26,17 +26,13 @@ def get_opensearch_migration_status(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> OpenSearchMigrationStatusResponse:
|
||||
(
|
||||
total_chunks_migrated,
|
||||
created_at,
|
||||
migration_completed_at,
|
||||
approx_chunk_count_in_vespa,
|
||||
) = get_opensearch_migration_state(db_session)
|
||||
total_chunks_migrated, created_at, migration_completed_at = (
|
||||
get_opensearch_migration_state(db_session)
|
||||
)
|
||||
return OpenSearchMigrationStatusResponse(
|
||||
total_chunks_migrated=total_chunks_migrated,
|
||||
created_at=created_at,
|
||||
migration_completed_at=migration_completed_at,
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ class OpenSearchMigrationStatusResponse(BaseModel):
|
||||
total_chunks_migrated: int
|
||||
created_at: datetime | None
|
||||
migration_completed_at: datetime | None
|
||||
approx_chunk_count_in_vespa: int | None
|
||||
|
||||
|
||||
class OpenSearchRetrievalStatusRequest(BaseModel):
|
||||
|
||||
@@ -608,8 +608,7 @@ def list_all_users_basic_info(
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.role != UserRole.SLACK_USER
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
if include_api_keys or not is_api_key_email_address(user.email)
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -1,241 +0,0 @@
|
||||
"""SQLAlchemy connection pool Prometheus metrics.
|
||||
|
||||
Provides production-grade visibility into database connection pool state:
|
||||
|
||||
- Pool state gauges (checked-out, idle, overflow, configured size)
|
||||
- Pool lifecycle counters (checkouts, checkins, creates, invalidations, timeouts)
|
||||
- Per-endpoint connection attribution (which endpoints hold connections, for how long)
|
||||
|
||||
Metrics are collected via two mechanisms:
|
||||
1. A custom Prometheus Collector that reads pool snapshots on each /metrics scrape
|
||||
2. SQLAlchemy pool event listeners (checkout, checkin, connect, invalidate) for
|
||||
counters, histograms, and attribution
|
||||
"""
|
||||
|
||||
import time
|
||||
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from sqlalchemy import event
|
||||
from sqlalchemy.engine import Engine
|
||||
from sqlalchemy.engine.interfaces import DBAPIConnection
|
||||
from sqlalchemy.ext.asyncio import AsyncEngine
|
||||
from sqlalchemy.pool import ConnectionPoolEntry
|
||||
from sqlalchemy.pool import PoolProxiedConnection
|
||||
from sqlalchemy.pool import QueuePool
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_ENDPOINT_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# --- Pool lifecycle counters (event-driven) ---
|
||||
|
||||
_checkout_total = Counter(
|
||||
"onyx_db_pool_checkout_total",
|
||||
"Total connection checkouts from the pool",
|
||||
["engine"],
|
||||
)
|
||||
|
||||
_checkin_total = Counter(
|
||||
"onyx_db_pool_checkin_total",
|
||||
"Total connection checkins to the pool",
|
||||
["engine"],
|
||||
)
|
||||
|
||||
_connections_created_total = Counter(
|
||||
"onyx_db_pool_connections_created_total",
|
||||
"Total new database connections created",
|
||||
["engine"],
|
||||
)
|
||||
|
||||
_invalidations_total = Counter(
|
||||
"onyx_db_pool_invalidations_total",
|
||||
"Total connection invalidations",
|
||||
["engine"],
|
||||
)
|
||||
|
||||
_checkout_timeout_total = Counter(
|
||||
"onyx_db_pool_checkout_timeout_total",
|
||||
"Total connection checkout timeouts",
|
||||
["engine"],
|
||||
)
|
||||
|
||||
# --- Per-endpoint attribution (event-driven) ---
|
||||
|
||||
_connections_held = Gauge(
|
||||
"onyx_db_connections_held_by_endpoint",
|
||||
"Number of DB connections currently held, by endpoint and engine",
|
||||
["handler", "engine"],
|
||||
)
|
||||
|
||||
_hold_seconds = Histogram(
|
||||
"onyx_db_connection_hold_seconds",
|
||||
"Duration a DB connection is held by an endpoint",
|
||||
["handler", "engine"],
|
||||
)
|
||||
|
||||
|
||||
def pool_timeout_handler(
|
||||
request: Request, # noqa: ARG001
|
||||
exc: Exception,
|
||||
) -> JSONResponse:
|
||||
"""Increment the checkout timeout counter and return 503."""
|
||||
_checkout_timeout_total.labels(engine="unknown").inc()
|
||||
return JSONResponse(
|
||||
status_code=503,
|
||||
content={
|
||||
"detail": "Database connection pool timeout",
|
||||
"error": str(exc),
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
class PoolStateCollector(Collector):
|
||||
"""Custom Prometheus collector that reads QueuePool state on each scrape.
|
||||
|
||||
Uses pool.checkedout(), pool.checkedin(), pool.overflow(), and pool.size()
|
||||
for an atomic snapshot of pool state. Registered engines are stored as
|
||||
(label, pool) tuples to avoid holding references to the full Engine.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pools: list[tuple[str, QueuePool]] = []
|
||||
|
||||
def add_pool(self, label: str, pool: QueuePool) -> None:
|
||||
self._pools.append((label, pool))
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
checked_out = GaugeMetricFamily(
|
||||
"onyx_db_pool_checked_out",
|
||||
"Currently checked-out connections",
|
||||
labels=["engine"],
|
||||
)
|
||||
checked_in = GaugeMetricFamily(
|
||||
"onyx_db_pool_checked_in",
|
||||
"Idle connections available in the pool",
|
||||
labels=["engine"],
|
||||
)
|
||||
overflow = GaugeMetricFamily(
|
||||
"onyx_db_pool_overflow",
|
||||
"Current overflow connections beyond pool_size",
|
||||
labels=["engine"],
|
||||
)
|
||||
size = GaugeMetricFamily(
|
||||
"onyx_db_pool_size",
|
||||
"Configured pool size",
|
||||
labels=["engine"],
|
||||
)
|
||||
|
||||
for label, pool in self._pools:
|
||||
checked_out.add_metric([label], pool.checkedout())
|
||||
checked_in.add_metric([label], pool.checkedin())
|
||||
overflow.add_metric([label], pool.overflow())
|
||||
size.add_metric([label], pool.size())
|
||||
|
||||
return [checked_out, checked_in, overflow, size]
|
||||
|
||||
def describe(self) -> list[GaugeMetricFamily]:
|
||||
# Return empty to mark this as an "unchecked" collector. Prometheus
|
||||
# skips upfront descriptor validation and just calls collect() at
|
||||
# scrape time. Required because our metrics are dynamic (engine
|
||||
# labels depend on which engines are registered at runtime).
|
||||
return []
|
||||
|
||||
|
||||
def _register_pool_events(engine: Engine, label: str) -> None:
|
||||
"""Attach pool event listeners for metrics collection.
|
||||
|
||||
Listens to checkout, checkin, connect, and invalidate events.
|
||||
Stores per-connection metadata on connection_record.info for attribution.
|
||||
"""
|
||||
|
||||
@event.listens_for(engine, "checkout")
|
||||
def on_checkout(
|
||||
dbapi_conn: DBAPIConnection, # noqa: ARG001
|
||||
conn_record: ConnectionPoolEntry,
|
||||
conn_proxy: PoolProxiedConnection, # noqa: ARG001
|
||||
) -> None:
|
||||
handler = CURRENT_ENDPOINT_CONTEXTVAR.get() or "unknown"
|
||||
conn_record.info["_metrics_endpoint"] = handler
|
||||
conn_record.info["_metrics_checkout_time"] = time.monotonic()
|
||||
_checkout_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).inc()
|
||||
|
||||
@event.listens_for(engine, "checkin")
|
||||
def on_checkin(
|
||||
dbapi_conn: DBAPIConnection, # noqa: ARG001
|
||||
conn_record: ConnectionPoolEntry,
|
||||
) -> None:
|
||||
handler = conn_record.info.pop("_metrics_endpoint", "unknown")
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
_checkin_total.labels(engine=label).inc()
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler, engine=label).observe(
|
||||
time.monotonic() - start
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def on_connect(
|
||||
dbapi_conn: DBAPIConnection, # noqa: ARG001
|
||||
conn_record: ConnectionPoolEntry, # noqa: ARG001
|
||||
) -> None:
|
||||
_connections_created_total.labels(engine=label).inc()
|
||||
|
||||
@event.listens_for(engine, "invalidate")
|
||||
def on_invalidate(
|
||||
dbapi_conn: DBAPIConnection, # noqa: ARG001
|
||||
conn_record: ConnectionPoolEntry,
|
||||
exception: BaseException | None, # noqa: ARG001
|
||||
) -> None:
|
||||
_invalidations_total.labels(engine=label).inc()
|
||||
# Defensively clean up the held-connections gauge in case checkin
|
||||
# doesn't fire after invalidation (e.g. hard pool shutdown).
|
||||
handler = conn_record.info.pop("_metrics_endpoint", None)
|
||||
start = conn_record.info.pop("_metrics_checkout_time", None)
|
||||
if handler:
|
||||
_connections_held.labels(handler=handler, engine=label).dec()
|
||||
if start is not None:
|
||||
_hold_seconds.labels(handler=handler or "unknown", engine=label).observe(
|
||||
time.monotonic() - start
|
||||
)
|
||||
|
||||
|
||||
def setup_postgres_connection_pool_metrics(
|
||||
engines: dict[str, Engine | AsyncEngine],
|
||||
) -> None:
|
||||
"""Register pool metrics for all provided engines.
|
||||
|
||||
Args:
|
||||
engines: Mapping of engine label to Engine or AsyncEngine.
|
||||
Example: {"sync": sync_engine, "async": async_engine, "readonly": ro_engine}
|
||||
|
||||
Engines using NullPool are skipped (no pool state to monitor).
|
||||
For AsyncEngine, events are registered on the underlying sync_engine.
|
||||
"""
|
||||
collector = PoolStateCollector()
|
||||
|
||||
for label, engine in engines.items():
|
||||
# Resolve async engines to their underlying sync engine
|
||||
sync_engine = engine.sync_engine if isinstance(engine, AsyncEngine) else engine
|
||||
|
||||
pool = sync_engine.pool
|
||||
if not isinstance(pool, QueuePool):
|
||||
logger.info(
|
||||
f"Skipping pool metrics for engine '{label}' "
|
||||
f"({type(pool).__name__} — no pool state)"
|
||||
)
|
||||
continue
|
||||
|
||||
collector.add_pool(label, pool)
|
||||
_register_pool_events(sync_engine, label)
|
||||
logger.info(f"Registered pool metrics for engine '{label}'")
|
||||
|
||||
REGISTRY.register(collector)
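The describe() comment above captures the key trick of this (now removed) module: returning an empty list registers the collector as "unchecked", so prometheus_client skips upfront descriptor validation and simply calls collect() on every /metrics scrape. A minimal, self-contained sketch of that same pattern (illustrative only, not the deleted Onyx code):

from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from prometheus_client.registry import REGISTRY


class LazyPoolCollector(Collector):
    """Reads a value lazily at scrape time instead of updating a Gauge on a timer."""

    def collect(self):
        gauge = GaugeMetricFamily(
            "example_pool_checked_out",
            "Currently checked-out connections",
            labels=["engine"],
        )
        # In the real module this would read pool.checkedout(); a constant stands in here.
        gauge.add_metric(["sync"], 3)
        yield gauge

    def describe(self):
        # Empty list => "unchecked" collector; collect() runs only when /metrics is scraped.
        return []


REGISTRY.register(LazyPoolCollector())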
|
||||
@@ -1,64 +0,0 @@
|
||||
"""Prometheus metrics setup for the Onyx API server.
|
||||
|
||||
Orchestrates HTTP request instrumentation via ``prometheus-fastapi-instrumentator``:
|
||||
- Request count, latency histograms, in-progress gauges
|
||||
- Pool checkout timeout exception handler
|
||||
- Custom metric callbacks (e.g. slow request counting)
|
||||
|
||||
SQLAlchemy connection pool metrics are registered separately via
|
||||
``setup_postgres_connection_pool_metrics`` during application lifespan
|
||||
(after engines are created).
|
||||
"""
|
||||
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sqlalchemy.exc import TimeoutError as SATimeoutError
|
||||
from starlette.applications import Starlette
|
||||
|
||||
from onyx.server.metrics.postgres_connection_pool import pool_timeout_handler
|
||||
from onyx.server.metrics.slow_requests import slow_request_callback
|
||||
|
||||
_EXCLUDED_HANDLERS = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/openapi.json",
|
||||
]
|
||||
|
||||
# Denser buckets for per-handler latency histograms. The instrumentator's
|
||||
# default (0.1, 0.5, 1) is too coarse for meaningful P95/P99 computation.
|
||||
_LATENCY_BUCKETS = (
|
||||
0.01,
|
||||
0.025,
|
||||
0.05,
|
||||
0.1,
|
||||
0.25,
|
||||
0.5,
|
||||
1.0,
|
||||
2.5,
|
||||
5.0,
|
||||
10.0,
|
||||
)
|
||||
|
||||
|
||||
def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
"""Initialize HTTP request metrics for the Onyx API server.
|
||||
|
||||
Must be called in ``get_application()`` BEFORE the app starts, because
|
||||
the instrumentator adds middleware via ``app.add_middleware()``.
|
||||
|
||||
Args:
|
||||
app: The FastAPI/Starlette application to instrument.
|
||||
"""
|
||||
app.add_exception_handler(SATimeoutError, pool_timeout_handler)
|
||||
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
should_ignore_untemplated=False,
|
||||
should_group_untemplated=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
instrumentator.add(slow_request_callback)
|
||||
|
||||
instrumentator.instrument(app, latency_lowr_buckets=_LATENCY_BUCKETS).expose(app)
|
||||
@@ -1,31 +0,0 @@
|
||||
"""Slow request counter metric.
|
||||
|
||||
Increments a counter whenever a request exceeds a configurable duration
|
||||
threshold. Useful for identifying endpoints that regularly take too long.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
|
||||
SLOW_REQUEST_THRESHOLD_SECONDS: float = max(
|
||||
0.0,
|
||||
float(os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")),
|
||||
)
|
||||
|
||||
_slow_requests = Counter(
|
||||
"onyx_api_slow_requests_total",
|
||||
"Total requests exceeding the slow request threshold",
|
||||
["method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def slow_request_callback(info: Info) -> None:
|
||||
"""Increment slow request counter when duration exceeds threshold."""
|
||||
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
|
||||
_slow_requests.labels(
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
backend/onyx/server/prometheus_instrumentation.py (new file, 63 lines)
@@ -0,0 +1,63 @@
|
||||
"""Prometheus instrumentation for the Onyx API server.
|
||||
|
||||
Provides a production-grade metrics configuration with:
|
||||
|
||||
- Exact HTTP status codes (no grouping into 2xx/3xx)
|
||||
- In-progress request gauge broken down by handler and method
|
||||
- Custom latency histogram buckets tuned for API workloads
|
||||
- Request/response size tracking
|
||||
- Slow request counter with configurable threshold
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
from starlette.applications import Starlette
|
||||
|
||||
SLOW_REQUEST_THRESHOLD_SECONDS: float = float(
|
||||
os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")
|
||||
)
|
||||
|
||||
_EXCLUDED_HANDLERS = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/openapi.json",
|
||||
]
|
||||
|
||||
_slow_requests = Counter(
|
||||
"onyx_api_slow_requests_total",
|
||||
"Total requests exceeding the slow request threshold",
|
||||
["method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def _slow_request_callback(info: Info) -> None:
|
||||
"""Increment slow request counter when duration exceeds threshold."""
|
||||
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
|
||||
_slow_requests.labels(
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
|
||||
|
||||
def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
"""Configure and attach Prometheus instrumentation to the FastAPI app.
|
||||
|
||||
Records exact status codes, tracks in-progress requests per handler,
|
||||
and counts slow requests exceeding a configurable threshold.
|
||||
"""
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
should_ignore_untemplated=False,
|
||||
should_group_untemplated=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
instrumentator.add(_slow_request_callback)
|
||||
|
||||
instrumentator.instrument(app).expose(app)
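For context, prometheus_fastapi_instrumentator installs its middleware at instrument() time and mounts the /metrics route at expose() time, so setup_prometheus_metrics has to run while the application object is being built, before it starts serving. A minimal wiring sketch (the surrounding app factory is an assumption, not code from this diff):

from fastapi import FastAPI

from onyx.server.prometheus_instrumentation import setup_prometheus_metrics


def get_application() -> FastAPI:
    app = FastAPI()
    # Must run before the app starts serving requests: this adds the
    # instrumentation middleware and the /metrics endpoint to the app.
    setup_prometheus_metrics(app)
    return app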
|
||||
@@ -36,8 +36,6 @@ from onyx.server.query_and_chat.streaming_models import OpenUrlStart
|
||||
from onyx.server.query_and_chat.streaming_models import OpenUrlUrls
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolDelta
|
||||
from onyx.server.query_and_chat.streaming_models import PythonToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningDelta
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.server.query_and_chat.streaming_models import ResearchAgentStart
|
||||
@@ -52,7 +50,6 @@ from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.memory.memory_tool import MemoryTool
|
||||
from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool
|
||||
from onyx.tools.tool_implementations.python.python_tool import PythonTool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.tools.tool_implementations.web_search.web_search_tool import WebSearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -380,37 +377,6 @@ def create_memory_packets(
|
||||
return packets
|
||||
|
||||
|
||||
def create_python_tool_packets(
|
||||
code: str,
|
||||
stdout: str,
|
||||
stderr: str,
|
||||
file_ids: list[str],
|
||||
turn_index: int,
|
||||
tab_index: int = 0,
|
||||
) -> list[Packet]:
|
||||
"""Recreate PythonToolStart + PythonToolDelta + SectionEnd from the stored
|
||||
tool call data so the frontend can display both the code and its output
|
||||
on page reload."""
|
||||
packets: list[Packet] = []
|
||||
placement = Placement(turn_index=turn_index, tab_index=tab_index)
|
||||
|
||||
packets.append(Packet(placement=placement, obj=PythonToolStart(code=code)))
|
||||
|
||||
packets.append(
|
||||
Packet(
|
||||
placement=placement,
|
||||
obj=PythonToolDelta(
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
file_ids=file_ids,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
packets.append(Packet(placement=placement, obj=SectionEnd()))
|
||||
return packets
|
||||
|
||||
|
||||
def create_search_packets(
|
||||
search_queries: list[str],
|
||||
search_docs: list[SavedSearchDoc],
|
||||
@@ -620,41 +586,6 @@ def translate_assistant_message_to_packets(
|
||||
)
|
||||
)
|
||||
|
||||
elif tool.in_code_tool_id == PythonTool.__name__:
|
||||
code = cast(
|
||||
str,
|
||||
tool_call.tool_call_arguments.get("code", ""),
|
||||
)
|
||||
stdout = ""
|
||||
stderr = ""
|
||||
file_ids: list[str] = []
|
||||
if tool_call.tool_call_response:
|
||||
try:
|
||||
response_data = json.loads(tool_call.tool_call_response)
|
||||
stdout = response_data.get("stdout", "")
|
||||
stderr = response_data.get("stderr", "")
|
||||
generated_files = response_data.get(
|
||||
"generated_files", []
|
||||
)
|
||||
file_ids = [
|
||||
f.get("file_link", "").split("/")[-1]
|
||||
for f in generated_files
|
||||
if f.get("file_link")
|
||||
]
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# Fall back to raw response as stdout
|
||||
stdout = tool_call.tool_call_response
|
||||
turn_tool_packets.extend(
|
||||
create_python_tool_packets(
|
||||
code=code,
|
||||
stdout=stdout,
|
||||
stderr=stderr,
|
||||
file_ids=file_ids,
|
||||
turn_index=turn_num,
|
||||
tab_index=tool_call.tab_index,
|
||||
)
|
||||
)
|
||||
|
||||
else:
|
||||
# Custom tool or unknown tool
|
||||
turn_tool_packets.extend(
|
||||
|
||||
@@ -24,7 +24,6 @@ from onyx.auth.users import get_user_manager
|
||||
from onyx.auth.users import UserManager
|
||||
from onyx.configs.app_configs import REQUIRE_EMAIL_VERIFICATION
|
||||
from onyx.configs.app_configs import SAML_CONF_DIR
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.auth import get_user_db
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
@@ -124,12 +123,9 @@ async def prepare_from_fastapi_request(request: Request) -> dict[str, Any]:
|
||||
if request.client is None:
|
||||
raise ValueError("Invalid request for SAML")
|
||||
|
||||
# Derive http_host and server_port from WEB_DOMAIN (a trusted env var)
|
||||
# instead of X-Forwarded-* headers, which can be spoofed by an attacker
|
||||
# to poison SAML redirect URLs (host header poisoning).
|
||||
parsed_domain = urlparse(WEB_DOMAIN)
|
||||
http_host = parsed_domain.hostname or request.client.host
|
||||
server_port = parsed_domain.port or (443 if parsed_domain.scheme == "https" else 80)
|
||||
# Use X-Forwarded headers if available
|
||||
http_host = request.headers.get("X-Forwarded-Host") or request.client.host
|
||||
server_port = request.headers.get("X-Forwarded-Port") or request.url.port
|
||||
|
||||
rv: dict[str, Any] = {
|
||||
"http_host": http_host,
|
||||
|
||||
@@ -55,9 +55,7 @@ class Settings(BaseModel):
|
||||
gpu_enabled: bool | None = None
|
||||
application_status: ApplicationStatus = ApplicationStatus.ACTIVE
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
|
||||
@@ -199,12 +199,6 @@ class PythonToolOverrideKwargs(BaseModel):
|
||||
chat_files: list[ChatFile] = []
|
||||
|
||||
|
||||
class ImageGenerationToolOverrideKwargs(BaseModel):
|
||||
"""Override kwargs for image generation tool calls."""
|
||||
|
||||
recent_generated_image_file_ids: list[str] = []
|
||||
|
||||
|
||||
class SearchToolRunContext(BaseModel):
|
||||
emitter: Emitter
|
||||
|
||||
|
||||
@@ -171,8 +171,10 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
# TODO concerning passing the db_session here.
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
@@ -420,6 +422,7 @@ def construct_tools(
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
|
||||
@@ -11,14 +11,11 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_NAME
|
||||
from onyx.configs.app_configs import IMAGE_MODEL_PROVIDER
|
||||
from onyx.db.image_generation import get_default_image_generation_config
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.utils import build_frontend_file_url
|
||||
from onyx.file_store.utils import load_chat_file_by_id
|
||||
from onyx.file_store.utils import save_files
|
||||
from onyx.image_gen.factory import get_image_generation_provider
|
||||
from onyx.image_gen.factory import validate_credentials
|
||||
from onyx.image_gen.interfaces import ImageGenerationProviderCredentials
|
||||
from onyx.image_gen.interfaces import ReferenceImage
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationFinal
|
||||
@@ -26,7 +23,6 @@ from onyx.server.query_and_chat.streaming_models import ImageGenerationToolHeart
|
||||
from onyx.server.query_and_chat.streaming_models import ImageGenerationToolStart
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.models import ImageGenerationToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallException
|
||||
from onyx.tools.models import ToolExecutionException
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -35,7 +31,6 @@ from onyx.tools.tool_implementations.images.models import (
|
||||
)
|
||||
from onyx.tools.tool_implementations.images.models import ImageGenerationResponse
|
||||
from onyx.tools.tool_implementations.images.models import ImageShape
|
||||
from onyx.utils.b64 import get_image_type_from_bytes
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -45,10 +40,10 @@ logger = setup_logger()
|
||||
HEARTBEAT_INTERVAL = 5.0
|
||||
|
||||
PROMPT_FIELD = "prompt"
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD = "reference_image_file_ids"
|
||||
|
||||
|
||||
class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
# override_kwargs is not supported for image generation tools
|
||||
class ImageGenerationTool(Tool[None]):
|
||||
NAME = "generate_image"
|
||||
DESCRIPTION = "Generate an image based on a prompt. Do not use unless the user specifically requests an image."
|
||||
DISPLAY_NAME = "Image Generation"
|
||||
@@ -64,7 +59,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
) -> None:
|
||||
super().__init__(emitter=emitter)
|
||||
self.model = model
|
||||
self.provider = provider
|
||||
self.num_imgs = num_imgs
|
||||
|
||||
self.img_provider = get_image_generation_provider(
|
||||
@@ -139,16 +133,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
),
|
||||
"enum": [shape.value for shape in ImageShape],
|
||||
},
|
||||
REFERENCE_IMAGE_FILE_IDS_FIELD: {
|
||||
"type": "array",
|
||||
"description": (
|
||||
"Optional image file IDs to use as reference context for edits/variations. "
|
||||
"Use the file_id values returned by previous generate_image calls."
|
||||
),
|
||||
"items": {
|
||||
"type": "string",
|
||||
},
|
||||
},
|
||||
},
|
||||
"required": [PROMPT_FIELD],
|
||||
},
|
||||
@@ -164,10 +148,7 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
)
|
||||
|
||||
def _generate_image(
|
||||
self,
|
||||
prompt: str,
|
||||
shape: ImageShape,
|
||||
reference_images: list[ReferenceImage] | None = None,
|
||||
self, prompt: str, shape: ImageShape
|
||||
) -> tuple[ImageGenerationResponse, Any]:
|
||||
if shape == ImageShape.LANDSCAPE:
|
||||
if "gpt-image-1" in self.model:
|
||||
@@ -188,7 +169,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
model=self.model,
|
||||
size=size,
|
||||
n=1,
|
||||
reference_images=reference_images,
|
||||
# response_format parameter is not supported for gpt-image-1
|
||||
response_format=None if "gpt-image-1" in self.model else "b64_json",
|
||||
)
|
||||
@@ -251,117 +231,10 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
emit_error_packet=True,
|
||||
)
|
||||
|
||||
def _resolve_reference_image_file_ids(
|
||||
self,
|
||||
llm_kwargs: dict[str, Any],
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None,
|
||||
) -> list[str]:
|
||||
raw_reference_ids = llm_kwargs.get(REFERENCE_IMAGE_FILE_IDS_FIELD)
|
||||
if raw_reference_ids is not None:
|
||||
if not isinstance(raw_reference_ids, list) or not all(
|
||||
isinstance(file_id, str) for file_id in raw_reference_ids
|
||||
):
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Invalid {REFERENCE_IMAGE_FILE_IDS_FIELD}: expected array of strings, "
|
||||
f"got {type(raw_reference_ids)}"
|
||||
),
|
||||
llm_facing_message=(
|
||||
f"The '{REFERENCE_IMAGE_FILE_IDS_FIELD}' field must be an array of file_id strings."
|
||||
),
|
||||
)
|
||||
reference_image_file_ids = [
|
||||
file_id.strip() for file_id in raw_reference_ids if file_id.strip()
|
||||
]
|
||||
elif (
|
||||
override_kwargs
|
||||
and override_kwargs.recent_generated_image_file_ids
|
||||
and self.img_provider.supports_reference_images
|
||||
):
|
||||
# If no explicit reference was provided, default to the most recently generated image.
|
||||
reference_image_file_ids = [
|
||||
override_kwargs.recent_generated_image_file_ids[-1]
|
||||
]
|
||||
else:
|
||||
reference_image_file_ids = []
|
||||
|
||||
# Deduplicate while preserving order.
|
||||
deduped_reference_image_ids: list[str] = []
|
||||
seen_ids: set[str] = set()
|
||||
for file_id in reference_image_file_ids:
|
||||
if file_id in seen_ids:
|
||||
continue
|
||||
seen_ids.add(file_id)
|
||||
deduped_reference_image_ids.append(file_id)
|
||||
|
||||
if not deduped_reference_image_ids:
|
||||
return []
|
||||
|
||||
if not self.img_provider.supports_reference_images:
|
||||
raise ToolCallException(
|
||||
message=(
|
||||
f"Reference images requested but provider '{self.provider}' "
|
||||
"does not support image-editing context."
|
||||
),
|
||||
llm_facing_message=(
|
||||
"This image provider does not support editing from previous image context. "
|
||||
"Try text-only generation, or switch to a provider/model that supports image edits."
|
||||
),
|
||||
)
|
||||
|
||||
max_reference_images = self.img_provider.max_reference_images
|
||||
if max_reference_images > 0:
|
||||
return deduped_reference_image_ids[-max_reference_images:]
|
||||
return deduped_reference_image_ids
|
||||
|
||||
def _load_reference_images(
|
||||
self,
|
||||
reference_image_file_ids: list[str],
|
||||
) -> list[ReferenceImage]:
|
||||
reference_images: list[ReferenceImage] = []
|
||||
|
||||
for file_id in reference_image_file_ids:
|
||||
try:
|
||||
loaded_file = load_chat_file_by_id(file_id)
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
message=f"Could not load reference image file '{file_id}': {e}",
|
||||
llm_facing_message=(
|
||||
f"Reference image file '{file_id}' could not be loaded. "
|
||||
"Use file_id values returned by previous generate_image calls."
|
||||
),
|
||||
)
|
||||
|
||||
if loaded_file.file_type != ChatFileType.IMAGE:
|
||||
raise ToolCallException(
|
||||
message=f"Reference file '{file_id}' is not an image",
|
||||
llm_facing_message=f"Reference file '{file_id}' is not an image.",
|
||||
)
|
||||
|
||||
try:
|
||||
mime_type = get_image_type_from_bytes(loaded_file.content)
|
||||
except Exception as e:
|
||||
raise ToolCallException(
|
||||
message=f"Unsupported reference image format for '{file_id}': {e}",
|
||||
llm_facing_message=(
|
||||
f"Reference image '{file_id}' has an unsupported format. "
|
||||
"Only PNG, JPEG, GIF, and WEBP are supported."
|
||||
),
|
||||
)
|
||||
|
||||
reference_images.append(
|
||||
ReferenceImage(
|
||||
data=loaded_file.content,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
)
|
||||
|
||||
return reference_images
|
||||
|
||||
def run(
|
||||
self,
|
||||
placement: Placement,
|
||||
override_kwargs: ImageGenerationToolOverrideKwargs | None = None,
|
||||
override_kwargs: None = None, # noqa: ARG002
|
||||
**llm_kwargs: Any,
|
||||
) -> ToolResponse:
|
||||
if PROMPT_FIELD not in llm_kwargs:
|
||||
@@ -374,11 +247,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
)
|
||||
prompt = cast(str, llm_kwargs[PROMPT_FIELD])
|
||||
shape = ImageShape(llm_kwargs.get("shape", ImageShape.SQUARE.value))
|
||||
reference_image_file_ids = self._resolve_reference_image_file_ids(
|
||||
llm_kwargs=llm_kwargs,
|
||||
override_kwargs=override_kwargs,
|
||||
)
|
||||
reference_images = self._load_reference_images(reference_image_file_ids)
|
||||
|
||||
# Use threading to generate images in parallel while emitting heartbeats
|
||||
results: list[tuple[ImageGenerationResponse, Any] | None] = [
|
||||
@@ -399,7 +267,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
(
|
||||
prompt,
|
||||
shape,
|
||||
reference_images or None,
|
||||
),
|
||||
)
|
||||
for _ in range(self.num_imgs)
|
||||
@@ -480,7 +347,6 @@ class ImageGenerationTool(Tool[ImageGenerationToolOverrideKwargs | None]):
|
||||
llm_facing_response = json.dumps(
|
||||
[
|
||||
{
|
||||
"file_id": img.file_id,
|
||||
"revised_prompt": img.revised_prompt,
|
||||
}
|
||||
for img in generated_images_metadata
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.file_processing.html_utils import ParsedHTML
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -22,22 +21,10 @@ from onyx.utils.web_content import title_from_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DEFAULT_READ_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
|
||||
DEFAULT_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
|
||||
DEFAULT_MAX_WORKERS = 5
|
||||
|
||||
|
||||
def _failed_result(url: str) -> WebContent:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
|
||||
class OnyxWebCrawler(WebContentProvider):
|
||||
@@ -50,14 +37,12 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
max_pdf_size_bytes: int | None = None,
|
||||
max_html_size_bytes: int | None = None,
|
||||
) -> None:
|
||||
self._read_timeout_seconds = timeout_seconds
|
||||
self._connect_timeout_seconds = connect_timeout_seconds
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._max_pdf_size_bytes = max_pdf_size_bytes
|
||||
self._max_html_size_bytes = max_html_size_bytes
|
||||
self._headers = {
|
||||
@@ -66,68 +51,75 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
return list(executor.map(self._fetch_url_safe, urls))
|
||||
|
||||
def _fetch_url_safe(self, url: str) -> WebContent:
|
||||
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
|
||||
try:
|
||||
return self._fetch_url(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler unexpected error for %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
results: list[WebContent] = []
|
||||
for url in urls:
|
||||
results.append(self._fetch_url(url))
|
||||
return results
|
||||
|
||||
def _fetch_url(self, url: str) -> WebContent:
|
||||
try:
|
||||
# Use SSRF-safe request to prevent DNS rebinding attacks
|
||||
response = ssrf_safe_get(
|
||||
url,
|
||||
headers=self._headers,
|
||||
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
|
||||
url, headers=self._headers, timeout=self._timeout_seconds
|
||||
)
|
||||
except SSRFException as exc:
|
||||
logger.error(
|
||||
"SSRF protection blocked request to %s (%s)",
|
||||
"SSRF protection blocked request to %s: %s",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
str(exc),
|
||||
)
|
||||
return _failed_result(url)
|
||||
except Exception as exc:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - network failures vary
|
||||
logger.warning(
|
||||
"Onyx crawler failed to fetch %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content = response.content
|
||||
|
||||
content_sniff = content[:1024] if content else None
|
||||
content_sniff = response.content[:1024] if response.content else None
|
||||
if is_pdf_resource(url, content_type, content_sniff):
|
||||
if (
|
||||
self._max_pdf_size_bytes is not None
|
||||
and len(content) > self._max_pdf_size_bytes
|
||||
and len(response.content) > self._max_pdf_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"PDF content too large (%d bytes) for %s, max is %d",
|
||||
len(content),
|
||||
len(response.content),
|
||||
url,
|
||||
self._max_pdf_size_bytes,
|
||||
)
|
||||
return _failed_result(url)
|
||||
text_content, metadata = extract_pdf_text(content)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
text_content, metadata = extract_pdf_text(response.content)
|
||||
title = title_from_pdf_metadata(metadata) or title_from_url(url)
|
||||
return WebContent(
|
||||
title=title,
|
||||
@@ -139,19 +131,25 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
|
||||
if (
|
||||
self._max_html_size_bytes is not None
|
||||
and len(content) > self._max_html_size_bytes
|
||||
and len(response.content) > self._max_html_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"HTML content too large (%d bytes) for %s, max is %d",
|
||||
len(content),
|
||||
len(response.content),
|
||||
url,
|
||||
self._max_html_size_bytes,
|
||||
)
|
||||
return _failed_result(url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
try:
|
||||
decoded_html = decode_html_bytes(
|
||||
content,
|
||||
response.content,
|
||||
content_type=content_type,
|
||||
fallback_encoding=response.apparent_encoding or response.encoding,
|
||||
)
|
||||
|
||||
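Taken together, the crawler changes above bound fan-out with a small thread pool and turn every per-URL failure (SSRF block, network error, HTTP status >= 400, oversized content) into a WebContent with scrape_successful=False instead of a raised exception. A hedged usage sketch (URLs are illustrative; the import path for OnyxWebCrawler is not shown in this diff):

crawler = OnyxWebCrawler(max_pdf_size_bytes=50 * 1024 * 1024)
results = crawler.contents(["https://example.com", "https://example.com/missing"])
for page in results:
    if page.scrape_successful:
        print(page.title, len(page.full_content))
    else:
        # Failed fetches come back as empty WebContent objects, not exceptions.
        print("failed:", page.link)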
Some files were not shown because too many files have changed in this diff.