Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-16 23:35:46 +00:00)

Compare commits: 41 commits
fix/chat-d...nightly-la
| Author | SHA1 | Date |
|---|---|---|
|  | 4bb3ee03a0 |  |
|  | 1bb23d6837 |  |
|  | f447359815 |  |
|  | 851e0b05f2 |  |
|  | 094cc940a4 |  |
|  | 51be9000bb |  |
|  | 80ecdb711d |  |
|  | a599176bbf |  |
|  | e0341b4c8a |  |
|  | 4c93fd448f |  |
|  | 84d916e210 |  |
|  | f57ed2a8dd |  |
|  | 713889babf |  |
|  | 58c641d8ec |  |
|  | 94985e24c6 |  |
|  | 4c71a5f5ff |  |
|  | b19e3a500b |  |
|  | 267fe027f5 |  |
|  | 0d4d8c0d64 |  |
|  | 6f9d8c0cff |  |
|  | 5031096a2b |  |
|  | 797e113000 |  |
|  | edc2892785 |  |
|  | ef4d5dcec3 |  |
|  | 0b5e3e5ee4 |  |
|  | f5afb3621e |  |
|  | 9f72826143 |  |
|  | ab7a4184df |  |
|  | 16a14bac89 |  |
|  | baaf31513c |  |
|  | 0b01d7f848 |  |
|  | 23ff3476bc |  |
|  | 0c7ba8e2ac |  |
|  | dad99cbec7 |  |
|  | 3e78c2f087 |  |
|  | e822afdcfa |  |
|  | b824951c89 |  |
|  | ca20e527fc |  |
|  | c8e65cce1e |  |
|  | 6c349687da |  |
|  | 3b64793d4b |  |
@@ -64,7 +64,7 @@ jobs:

      - name: Backend Image Docker Build and Push
        id: build
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile

@@ -54,7 +54,7 @@ jobs:

      - name: Build and push by digest
        id: build
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile

@@ -80,7 +80,7 @@ jobs:
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and Push AMD64
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server

@@ -126,7 +126,7 @@ jobs:
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build and Push ARM64
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        with:
          context: ./backend
          file: ./backend/Dockerfile.model_server

@@ -70,7 +70,7 @@ jobs:

      - name: Build and push by digest
        id: build
-       uses: docker/build-push-action@v5
+       uses: docker/build-push-action@v6
        with:
          context: ./web
          file: ./web/Dockerfile
.vscode/launch.template.jsonc (vendored, 23 changed lines)
@@ -428,6 +428,29 @@
                "--filename",
                "generated/openapi.json",
            ]
        },
        {
            // script to debug multi tenant db issues
            "name": "Onyx DB Manager (Top Chunks)",
            "type": "debugpy",
            "request": "launch",
            "program": "scripts/debugging/onyx_db.py",
            "cwd": "${workspaceFolder}/backend",
            "envFile": "${workspaceFolder}/.env",
            "env": {
                "PYTHONUNBUFFERED": "1",
                "PYTHONPATH": "."
            },
            "args": [
                "--password",
                "your_password_here",
                "--port",
                "5433",
                "--report",
                "top-chunks",
                "--filename",
                "generated/tenants_by_num_docs.csv"
            ]
        },
        {
            "name": "Debug React Web App in Chrome",
@@ -6,11 +6,8 @@ Create Date: 2024-04-15 01:36:02.952809

"""

-import json
-from typing import cast
from alembic import op
import sqlalchemy as sa
-from onyx.key_value_store.factory import get_kv_store

# revision identifiers, used by Alembic.
revision = "703313b75876"

@@ -54,27 +51,10 @@ def upgrade() -> None:
        sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
    )

-   try:
-       settings_json = cast(str, get_kv_store().load("token_budget_settings"))
-       settings = json.loads(settings_json)
-
-       is_enabled = settings.get("enable_token_budget", False)
-       token_budget = settings.get("token_budget", -1)
-       period_hours = settings.get("period_hours", -1)
-
-       if is_enabled and token_budget > 0 and period_hours > 0:
-           op.execute(
-               f"INSERT INTO token_rate_limit \
-               (enabled, token_budget, period_hours, scope) VALUES \
-               ({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
-           )
-
-       # Delete the dynamic config
-       get_kv_store().delete("token_budget_settings")
-
-   except Exception:
-       # Ignore if the dynamic config is not found
-       pass
+   # NOTE: rate limit settings used to be stored in the "token_budget_settings" key in the
+   # KeyValueStore. This will now be lost. The KV store works differently than it used to
+   # so the migration is fairly complicated and likely not worth it to support (pretty much
+   # nobody will have it set)


def downgrade() -> None:
@@ -0,0 +1,128 @@
"""add_cascade_deletes_to_agent_tables

Revision ID: ca04500b9ee8
Revises: 238b84885828
Create Date: 2025-05-30 16:03:51.112263

"""

from alembic import op


# revision identifiers, used by Alembic.
revision = "ca04500b9ee8"
down_revision = "238b84885828"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Drop existing foreign key constraints
    op.drop_constraint(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        type_="foreignkey",
    )
    op.drop_constraint(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Recreate foreign key constraints with CASCADE delete
    op.create_foreign_key(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        "chat_message",
        ["primary_question_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        "agent__sub_question",
        ["parent_question_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        "chat_message",
        ["chat_message_id"],
        ["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Drop CASCADE foreign key constraints
    op.drop_constraint(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        type_="foreignkey",
    )
    op.drop_constraint(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        type_="foreignkey",
    )
    op.drop_constraint(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        type_="foreignkey",
    )

    # Recreate foreign key constraints without CASCADE delete
    op.create_foreign_key(
        "agent__sub_question_primary_question_id_fkey",
        "agent__sub_question",
        "chat_message",
        ["primary_question_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query_parent_question_id_fkey",
        "agent__sub_query",
        "agent__sub_question",
        ["parent_question_id"],
        ["id"],
    )
    op.create_foreign_key(
        "chat_message__standard_answer_chat_message_id_fkey",
        "chat_message__standard_answer",
        "chat_message",
        ["chat_message_id"],
        ["id"],
    )
    op.create_foreign_key(
        "agent__sub_query__search_doc_sub_query_id_fkey",
        "agent__sub_query__search_doc",
        "agent__sub_query",
        ["sub_query_id"],
        ["id"],
    )
@@ -1,8 +1,12 @@
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any

from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from ee.onyx.external_permissions.google_drive.permission_retrieval import (

@@ -13,6 +17,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.resources import get_drive_service
+from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair

@@ -41,6 +46,20 @@ def _get_slim_doc_generator(
    )


def _drive_connector_creds_getter(
    google_drive_connector: GoogleDriveConnector,
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
    def inner() -> ServiceAccountCredentials | OAuthCredentials:
        if not google_drive_connector._creds_dict:
            raise ValueError(
                "Creds dict not found, load_credentials must be called first"
            )
        google_drive_connector.load_credentials(google_drive_connector._creds_dict)
        return google_drive_connector.creds

    return inner


def _fetch_permissions_for_permission_ids(
    google_drive_connector: GoogleDriveConnector,
    permission_info: dict[str, Any],

@@ -54,13 +73,22 @@ def _fetch_permissions_for_permission_ids(
    if not permission_ids:
        return []

    drive_service = get_drive_service(
    if not owner_email:
        logger.warning(
            f"No owner email found for document {doc_id}. Permission info: {permission_info}"
        )

    refreshable_drive_service = RefreshableDriveObject(
        call_stack=lambda creds: get_drive_service(
            creds=creds,
            user_email=(owner_email or google_drive_connector.primary_admin_email),
        ),
        creds=google_drive_connector.creds,
        user_email=(owner_email or google_drive_connector.primary_admin_email),
        creds_getter=_drive_connector_creds_getter(google_drive_connector),
    )

    return get_permissions_by_ids(
-       drive_service=drive_service,
+       drive_service=refreshable_drive_service,
        doc_id=doc_id,
        permission_ids=permission_ids,
    )

@@ -172,7 +200,9 @@ def gdrive_doc_sync(

    slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

+   total_processed = 0
    for slim_doc_batch in slim_doc_generator:
+       logger.info(f"Drive perm sync: Processing {len(slim_doc_batch)} documents")
        for slim_doc in slim_doc_batch:
            if callback:
                if callback.should_stop():

@@ -188,3 +218,5 @@ def gdrive_doc_sync(
                external_access=ext_access,
                doc_id=slim_doc.id,
            )
+       total_processed += len(slim_doc_batch)
+   logger.info(f"Drive perm sync: Processed {total_processed} total documents")
@@ -1,14 +1,16 @@
from googleapiclient.discovery import Resource  # type: ignore
from retry import retry

from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.utils.logger import setup_logger

logger = setup_logger()


@retry(tries=3, delay=2, backoff=2)
def get_permissions_by_ids(
-   drive_service: Resource,
+   drive_service: RefreshableDriveObject,
    doc_id: str,
    permission_ids: list[str],
) -> list[GoogleDrivePermission]:
@@ -8,7 +8,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.slack.connector import get_channels
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.connector import make_paginated_slack_api_call
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

@@ -64,7 +64,7 @@ def _fetch_channel_permissions(
    for channel_id in private_channel_ids:
        # Collect all member ids for the channel pagination calls
        member_ids = []
-       for result in make_paginated_slack_api_call_w_retries(
+       for result in make_paginated_slack_api_call(
            slack_client.conversations_members,
            channel=channel_id,
        ):

@@ -92,7 +92,7 @@ def _fetch_channel_permissions(
            external_user_emails=member_emails,
            # No group<->document mapping for slack
            external_user_group_ids=set(),
-           # No way to determine if slack is invite only without enterprise liscense
+           # No way to determine if slack is invite only without enterprise license
            is_public=False,
        )
@@ -10,8 +10,8 @@ from slack_sdk import WebClient
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackConnector
+from onyx.connectors.slack.utils import make_paginated_slack_api_call
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger

@@ -23,7 +23,7 @@ def _get_slack_group_ids(
    slack_client: WebClient,
) -> list[str]:
    group_ids = []
-   for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
+   for result in make_paginated_slack_api_call(slack_client.usergroups_list):
        for group in result.get("usergroups", []):
            group_ids.append(group.get("id"))
    return group_ids

@@ -35,7 +35,7 @@ def _get_slack_group_members_email(
    user_id_to_email_map: dict[str, str],
) -> list[str]:
    group_member_emails = []
-   for result in make_paginated_slack_api_call_w_retries(
+   for result in make_paginated_slack_api_call(
        slack_client.usergroups_users_list, usergroup=group_name
    ):
        for member_id in result.get("users", []):
@@ -1,13 +1,13 @@
from slack_sdk import WebClient

-from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call


def fetch_user_id_to_email_map(
    slack_client: WebClient,
) -> dict[str, str]:
    user_id_to_email_map = {}
-   for user_info in make_paginated_slack_api_call_w_retries(
+   for user_info in make_paginated_slack_api_call(
        slack_client.users_list,
    ):
        for user in user_info.get("members", []):
@@ -2,6 +2,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
+from fastapi_users import exceptions

from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest

@@ -24,14 +25,24 @@ async def impersonate_user(
    _: User = Depends(current_cloud_superuser),
) -> Response:
    """Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
-   tenant_id = get_tenant_id_for_email(impersonate_request.email)
+   try:
+       tenant_id = get_tenant_id_for_email(impersonate_request.email)
+   except exceptions.UserNotExists:
+       detail = f"User has no tenant mapping: {impersonate_request.email=}"
+       logger.warning(detail)
+       raise HTTPException(status_code=422, detail=detail)

    with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
        user_to_impersonate = get_user_by_email(
            impersonate_request.email, tenant_session
        )
        if user_to_impersonate is None:
-           raise HTTPException(status_code=404, detail="User not found")
+           detail = (
+               f"User not found in tenant: {impersonate_request.email=} {tenant_id=}"
+           )
+           logger.warning(detail)
+           raise HTTPException(status_code=422, detail=detail)

    token = await get_redis_strategy().write_token(user_to_impersonate)

    response = await auth_backend.transport.get_login_response(token)
@@ -47,10 +47,10 @@ def get_tenant_id_for_email(email: str) -> str:
            mapping.active = True
            db_session.commit()
            tenant_id = mapping.tenant_id

    except Exception as e:
        logger.exception(f"Error getting tenant id for email {email}: {e}")
        raise exceptions.UserNotExists()

    if tenant_id is None:
        raise exceptions.UserNotExists()
    return tenant_id
@@ -92,6 +92,7 @@ def format_embedding_error(
    service_name: str,
    model: str | None,
    provider: EmbeddingProvider,
+   sanitized_api_key: str | None = None,
    status_code: int | None = None,
) -> str:
    """

@@ -103,6 +104,7 @@ def format_embedding_error(
        f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
        f"Model: {model} "
        f"Provider: {provider} "
+       f"API Key: {sanitized_api_key} "
        f"Exception: {error}"
    )

@@ -133,6 +135,7 @@ class CloudEmbedding:
        self.timeout = timeout
        self.http_client = httpx.AsyncClient(timeout=timeout)
        self._closed = False
+       self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]

    async def _embed_openai(
        self, texts: list[str], model: str | None, reduced_dimension: int | None

@@ -306,6 +309,7 @@ class CloudEmbedding:
                str(self.provider),
                model_name or deployment_name,
                self.provider,
+               sanitized_api_key=self.sanitized_api_key,
                status_code=e.response.status_code,
            )
            logger.error(error_string)

@@ -317,7 +321,11 @@ class CloudEmbedding:
                raise AuthenticationError(provider=str(self.provider))

            error_string = format_embedding_error(
-               e, str(self.provider), model_name or deployment_name, self.provider
+               e,
+               str(self.provider),
+               model_name or deployment_name,
+               self.provider,
+               sanitized_api_key=self.sanitized_api_key,
            )
            logger.error(error_string)
            logger.debug(f"Exception texts: {texts}")
@@ -11,7 +11,7 @@ class ExternalAccess:

    # arbitrary limit to prevent excessively large permissions sets
    # not internally enforced ... the caller can check this before using the instance
-   MAX_NUM_ENTRIES = 1000
+   MAX_NUM_ENTRIES = 5000

    # Emails of external users with access to the doc externally
    external_user_emails: set[str]
@@ -1,4 +1,3 @@
from datetime import datetime
from typing import cast

from langchain_core.messages import HumanMessage

@@ -12,6 +11,7 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
    trim_prompt_piece,
)
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT

@@ -113,42 +113,20 @@ def consolidate_research(
        )
    ]

    dispatch_timings: list[float] = []

    primary_model = graph_config.tooling.primary_llm

    def stream_initial_answer() -> list[str]:
        response: list[str] = []
        for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()

            write_custom_event(
                "initial_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            response.append(content)
        return response

    try:
        _ = run_with_timeout(
            60,
            stream_initial_answer,
            lambda: stream_llm_answer(
                llm=graph_config.tooling.primary_llm,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=30,
                max_tokens=None,
            ),
        )

    except Exception as e:
@@ -30,6 +30,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
    LLM_ANSWER_ERROR_MESSAGE,
)
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids

@@ -112,44 +113,23 @@ def generate_sub_answer(
        config=fast_llm.config,
    )

    dispatch_timings: list[float] = []
    agent_error: AgentErrorLog | None = None
    response: list[str] = []

    def stream_sub_answer() -> list[str]:
        for message in fast_llm.stream(
            prompt=msg,
            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
            max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()
            write_custom_event(
                "sub_answers",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=level,
                    level_question_num=question_num,
                    answer_type="agent_sub_answer",
                ),
                writer,
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            response.append(content)
        return response

    try:
-       response = run_with_timeout(
+       response, _ = run_with_timeout(
            AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
            stream_sub_answer,
            lambda: stream_llm_answer(
                llm=fast_llm,
                prompt=msg,
                event_name="sub_answers",
                writer=writer,
                agent_answer_level=level,
                agent_answer_question_num=question_num,
                agent_answer_type="agent_sub_answer",
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
                max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
            ),
        )

    except (LLMTimeoutError, TimeoutError):
@@ -37,6 +37,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings

@@ -275,46 +276,24 @@ def generate_initial_answer(

    agent_error: AgentErrorLog | None = None

    def stream_initial_answer() -> list[str]:
        response: list[str] = []
        for message in model.stream(
            msg,
            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
            max_tokens=(
                AGENT_MAX_TOKENS_ANSWER_GENERATION
                if _should_restrict_tokens(model.config)
                else None
            ),
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )
            start_stream_token = datetime.now()

            write_custom_event(
                "initial_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=0,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            response.append(content)
        return response

    try:
-       streamed_tokens = run_with_timeout(
+       streamed_tokens, dispatch_timings = run_with_timeout(
            AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
            stream_initial_answer,
            lambda: stream_llm_answer(
                llm=model,
                prompt=msg,
                event_name="initial_agent_answer",
                writer=writer,
                agent_answer_level=0,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
                max_tokens=(
                    AGENT_MAX_TOKENS_ANSWER_GENERATION
                    if _should_restrict_tokens(model.config)
                    else None
                ),
            ),
        )

    except (LLMTimeoutError, TimeoutError):
@@ -40,6 +40,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
    AgentLLMErrorType,
)
+from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats

@@ -63,7 +64,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
    remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM

@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
    dispatch_timings: list[float] = []
    agent_error: AgentErrorLog | None = None

    def stream_refined_answer() -> list[str]:
        for message in model.stream(
            msg,
            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
            max_tokens=(
                AGENT_MAX_TOKENS_ANSWER_GENERATION
                if _should_restrict_tokens(model.config)
                else None
            ),
        ):
            # TODO: in principle, the answer here COULD contain images, but we don't support that yet
            content = message.content
            if not isinstance(content, str):
                raise ValueError(
                    f"Expected content to be a string, but got {type(content)}"
                )

            start_stream_token = datetime.now()
            write_custom_event(
                "refined_agent_answer",
                AgentAnswerPiece(
                    answer_piece=content,
                    level=1,
                    level_question_num=0,
                    answer_type="agent_level_answer",
                ),
                writer,
            )
            end_stream_token = datetime.now()
            dispatch_timings.append(
                (end_stream_token - start_stream_token).microseconds
            )
            streamed_tokens.append(content)
        return streamed_tokens

    try:
-       streamed_tokens = run_with_timeout(
+       streamed_tokens, dispatch_timings = run_with_timeout(
            AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
            stream_refined_answer,
            lambda: stream_llm_answer(
                llm=model,
                prompt=msg,
                event_name="refined_agent_answer",
                writer=writer,
                agent_answer_level=1,
                agent_answer_question_num=0,
                agent_answer_type="agent_level_answer",
                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
                max_tokens=(
                    AGENT_MAX_TOKENS_ANSWER_GENERATION
                    if _should_restrict_tokens(model.config)
                    else None
                ),
            ),
        )

    except (LLMTimeoutError, TimeoutError):
backend/onyx/agents/agent_search/shared_graph_utils/llm.py (new file, 68 lines)
@@ -0,0 +1,68 @@
from datetime import datetime
from typing import Literal

from langchain.schema.language_model import LanguageModelInput
from langgraph.types import StreamWriter

from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.llm.interfaces import LLM


def stream_llm_answer(
    llm: LLM,
    prompt: LanguageModelInput,
    event_name: str,
    writer: StreamWriter,
    agent_answer_level: int,
    agent_answer_question_num: int,
    agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
    timeout_override: int | None = None,
    max_tokens: int | None = None,
) -> tuple[list[str], list[float]]:
    """Stream the initial answer from the LLM.

    Args:
        llm: The LLM to use.
        prompt: The prompt to use.
        event_name: The name of the event to write.
        writer: The writer to write to.
        agent_answer_level: The level of the agent answer.
        agent_answer_question_num: The question number within the level.
        agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
        timeout_override: The LLM timeout to use.
        max_tokens: The LLM max tokens to use.

    Returns:
        A tuple of the response and the dispatch timings.
    """
    response: list[str] = []
    dispatch_timings: list[float] = []

    for message in llm.stream(
        prompt, timeout_override=timeout_override, max_tokens=max_tokens
    ):
        # TODO: in principle, the answer here COULD contain images, but we don't support that yet
        content = message.content
        if not isinstance(content, str):
            raise ValueError(
                f"Expected content to be a string, but got {type(content)}"
            )

        start_stream_token = datetime.now()
        write_custom_event(
            event_name,
            AgentAnswerPiece(
                answer_piece=content,
                level=agent_answer_level,
                level_question_num=agent_answer_question_num,
                answer_type=agent_answer_type,
            ),
            writer,
        )
        end_stream_token = datetime.now()

        dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
        response.append(content)

    return response, dispatch_timings
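For orientation, a minimal usage sketch (not part of the diff): the call sites changed above wrap this helper in run_with_timeout via a lambda and unpack the (response, dispatch_timings) tuple it returns. The import path for run_with_timeout and the surrounding variables (model, msg, writer) are assumptions for illustration only.

    # Hypothetical illustration; mirrors how the call sites above use the new helper.
    from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
    from onyx.utils.threadpool_concurrency import run_with_timeout  # assumed import path

    streamed_tokens, dispatch_timings = run_with_timeout(
        60,  # overall timeout in seconds, as in the consolidate_research call site
        lambda: stream_llm_answer(
            llm=model,    # assumed: an LLM instance already configured by the caller
            prompt=msg,   # assumed: the prompt built earlier in the node
            event_name="initial_agent_answer",
            writer=writer,  # assumed: the LangGraph StreamWriter passed into the node
            agent_answer_level=0,
            agent_answer_question_num=0,
            agent_answer_type="agent_level_answer",
            timeout_override=30,
            max_tokens=None,
        ),
    )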
@@ -76,10 +76,11 @@ def hash_api_key(api_key: str) -> str:
    # and overlaps are impossible
    if api_key.startswith(_API_KEY_PREFIX):
        return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
-   elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
+
+   if api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
        return _deprecated_hash_api_key(api_key)
-   else:
-       raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
+
+   raise ValueError(f"Invalid API key prefix: {api_key[:3]}")


def build_displayable_api_key(api_key: str) -> str:
@@ -6,6 +6,7 @@ from typing import Any
from typing import cast

import sentry_sdk
+from celery import bootsteps  # type: ignore
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown

@@ -22,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
+from onyx.background.celery.celery_utils import make_probe_path
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine

@@ -340,10 +342,23 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
    task_logger.info("worker_ready signal received.")

+   # file based way to do readiness/liveness probes
+   # https://medium.com/ambient-innovation/health-checks-for-celery-in-kubernetes-cf3274a3e106
+   # https://github.com/celery/celery/issues/4079#issuecomment-1270085680

+   hostname: str = cast(str, sender.hostname)
+   path = make_probe_path("readiness", hostname)
+   path.touch()
+   logger.info(f"Readiness signal touched at {path}.")


def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
    HttpxPool.close_all()

+   hostname: str = cast(str, sender.hostname)
+   path = make_probe_path("readiness", hostname)
+   path.unlink(missing_ok=True)

    if not celery_is_worker_primary(sender):
        return

@@ -483,3 +498,34 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
        msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
        logger.error(msg)
        raise WorkerShutdown(msg)


# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):
    requires = {"celery.worker.components:Timer"}

    def __init__(self, worker: Any, **kwargs: Any) -> None:
        super().__init__(worker, **kwargs)
        self.requests: list[Any] = []
        self.task_tref = None
        self.path = make_probe_path("liveness", worker.hostname)

    def start(self, worker: Any) -> None:
        self.task_tref = worker.timer.call_repeatedly(
            15.0,
            self.update_liveness_file,
            (worker,),
            priority=10,
        )

    def stop(self, worker: Any) -> None:
        self.path.unlink(missing_ok=True)
        if self.task_tref:
            self.task_tref.cancel()

    def update_liveness_file(self, worker: Any) -> None:
        self.path.touch()


def get_bootsteps() -> list[type]:
    return [LivenessProbe]
@@ -8,6 +8,7 @@ from celery.signals import beat_init
from celery.utils.log import get_task_logger

import onyx.background.celery.apps.app_base as app_base
+from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids

@@ -45,6 +46,8 @@ class DynamicTenantScheduler(PersistentScheduler):
            f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
        )

+       self._liveness_probe_path = make_probe_path("liveness", "beat@hostname")

        # do not set the initial schedule here because we don't have db access yet.
        # do it in beat_init after the db engine is initialized

@@ -62,6 +65,8 @@ class DynamicTenantScheduler(PersistentScheduler):
            or (now - self._last_reload) > self._reload_interval
        ):
            task_logger.debug("Reload interval reached, initiating task update")
+           self._liveness_probe_path.touch()

            try:
                self._try_updating_schedule()
            except (AttributeError, KeyError):

@@ -241,6 +246,9 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
    SqlEngine.init_engine(pool_size=2, max_overflow=0)

    app_base.wait_for_redis(sender, **kwargs)
+   path = make_probe_path("readiness", "beat@hostname")
+   path.touch()
+   task_logger.info(f"Readiness signal touched at {path}.")

    # first time init of the scheduler after db has been init'ed
    scheduler: DynamicTenantScheduler = sender.scheduler
@@ -91,6 +91,10 @@ def on_setup_logging(
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


+base_bootsteps = app_base.get_bootsteps()
+for bootstep in base_bootsteps:
+   celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.pruning",

@@ -102,6 +102,10 @@ def on_setup_logging(
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


+base_bootsteps = app_base.get_bootsteps()
+for bootstep in base_bootsteps:
+   celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.indexing",

@@ -105,6 +105,10 @@ def on_setup_logging(
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


+base_bootsteps = app_base.get_bootsteps()
+for bootstep in base_bootsteps:
+   celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.shared",

@@ -89,6 +89,10 @@ def on_setup_logging(
    app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)


+base_bootsteps = app_base.get_bootsteps()
+for bootstep in base_bootsteps:
+   celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.monitoring",

@@ -284,6 +284,10 @@ class HubPeriodicTask(bootsteps.StartStopStep):

celery_app.steps["worker"].add(HubPeriodicTask)

+base_bootsteps = app_base.get_bootsteps()
+for bootstep in base_bootsteps:
+   celery_app.steps["worker"].add(bootstep)

celery_app.autodiscover_tasks(
    [
        "onyx.background.celery.tasks.connector_deletion",
backend/onyx/background/celery/celery_k8s_probe.py (new file, 55 lines)
@@ -0,0 +1,55 @@
# script to use as a kubernetes readiness / liveness probe

import argparse
import sys
import time
from pathlib import Path


def main_readiness(filename: str) -> int:
    """Checks if the file exists."""
    path = Path(filename)
    if not path.is_file():
        return 1

    return 0


def main_liveness(filename: str) -> int:
    """Checks if the file exists AND was recently modified."""
    path = Path(filename)
    if not path.is_file():
        return 1

    stats = path.stat()
    liveness_timestamp = stats.st_mtime
    current_timestamp = time.time()
    time_diff = current_timestamp - liveness_timestamp
    if time_diff > 60:
        return 1

    return 0


if __name__ == "__main__":
    exit_code: int

    parser = argparse.ArgumentParser(description="k8s readiness/liveness probe")
    parser.add_argument(
        "--probe",
        type=str,
        choices=["readiness", "liveness"],
        help="The type of probe",
        required=True,
    )
    parser.add_argument("--filename", help="The filename to watch", required=True)
    args = parser.parse_args()

    if args.probe == "readiness":
        exit_code = main_readiness(args.filename)
    elif args.probe == "liveness":
        exit_code = main_liveness(args.filename)
    else:
        raise ValueError(f"Unknown probe type: {args.probe}")

    sys.exit(exit_code)
@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
+from pathlib import Path
from typing import Any
from typing import cast

@@ -121,3 +122,20 @@ def httpx_init_vespa_pool(
        http2=False,
        limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
    )


def make_probe_path(probe: str, hostname: str) -> Path:
    """templates the path for a k8s probe file.

    e.g. /tmp/onyx_k8s_indexing_readiness.txt
    """
    hostname_parts = hostname.split("@")
    if len(hostname_parts) != 2:
        raise ValueError(f"hostname could not be split! {hostname=}")

    name = hostname_parts[0]
    if not name:
        raise ValueError(f"name cannot be empty! {name=}")

    safe_name = "".join(c for c in name if c.isalnum()).rstrip()
    return Path(f"/tmp/onyx_k8s_{safe_name}_{probe}.txt")
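A minimal sketch of how the path helper above pairs with the probe script added earlier (illustrative only; in a cluster the kubelet would invoke celery_k8s_probe.py itself, and the hostname value here is made up):

    # Illustrative sketch, assuming both modules import as named in this change.
    import time
    from onyx.background.celery.celery_utils import make_probe_path
    from onyx.background.celery.celery_k8s_probe import main_liveness

    path = make_probe_path("liveness", "indexing@worker-0")  # -> /tmp/onyx_k8s_indexing_liveness.txt
    path.touch()                              # what the LivenessProbe bootstep does every 15 seconds
    assert main_liveness(str(path)) == 0      # freshly touched file: probe passes
    time.sleep(61)                            # let the file go stale past the 60-second cutoff
    assert main_liveness(str(path)) == 1      # stale file: probe fails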
@@ -43,6 +43,7 @@ from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
+from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT

@@ -52,11 +53,9 @@ from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.retrieval.search_runner import (

@@ -95,9 +94,7 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -312,8 +309,7 @@ def _handle_internet_search_tool_response_summary(
def _get_force_search_settings(
    new_msg_req: CreateChatMessageRequest,
    tools: list[Tool],
-   user_file_ids: list[int],
-   user_folder_ids: list[int],
+   search_tool_override_kwargs: SearchToolOverrideKwargs | None,
) -> ForceUseTool:
    internet_search_available = any(
        isinstance(tool, InternetSearchTool) for tool in tools

@@ -321,45 +317,24 @@ def _get_force_search_settings(
    search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)

    if not internet_search_available and not search_tool_available:
        if new_msg_req.force_user_file_search:
            return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
        else:
            # Does not matter much which tool is set here as force is false and neither tool is available
            return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)

    tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
        # Does not matter much which tool is set here as force is false and neither tool is available
        return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
    # Currently, the internet search tool does not support query override
    args = (
        {"query": new_msg_req.query_override}
-       if new_msg_req.query_override and tool_name == SearchTool._NAME
+       if new_msg_req.query_override and search_tool_available
        else None
    )

    # Create override_kwargs for the search tool if user_file_ids are provided
    override_kwargs = None
    if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
        override_kwargs = SearchToolOverrideKwargs(
            force_no_rerank=False,
            alternate_db_session=None,
            retrieved_sections_callback=None,
            skip_query_analysis=False,
            user_file_ids=user_file_ids,
            user_folder_ids=user_folder_ids,
        )

    if new_msg_req.file_descriptors:
        # If user has uploaded files they're using, don't run any of the search tools
        return ForceUseTool(force_use=False, tool_name=tool_name)

    should_force_search = any(
        [
            new_msg_req.force_user_file_search,
            new_msg_req.retrieval_options
            and new_msg_req.retrieval_options.run_search
            == OptionalSearchSetting.ALWAYS,
            new_msg_req.search_doc_ids,
            new_msg_req.query_override is not None,
            DISABLE_LLM_CHOOSE_SEARCH,
            search_tool_override_kwargs is not None,
        ]
    )

@@ -369,62 +344,18 @@ def _get_force_search_settings(

    return ForceUseTool(
        force_use=True,
-       tool_name=tool_name,
+       tool_name=SearchTool._NAME,
        args=args,
-       override_kwargs=override_kwargs,
+       override_kwargs=search_tool_override_kwargs,
    )

    return ForceUseTool(
        force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
    )


def _get_user_knowledge_files(
    info: AnswerPostInfo,
    user_files: list[InMemoryChatFile],
    file_id_to_user_file: dict[str, InMemoryChatFile],
) -> Generator[UserKnowledgeFilePacket, None, None]:
    if not info.qa_docs_response:
        return

    logger.info(
        f"ORDERING: Processing search results for ordering {len(user_files)} user files"
    )

    # Extract document order from search results
    doc_order = []
    for doc in info.qa_docs_response.top_documents:
        doc_id = doc.document_id
        if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
            file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
            if file_id in file_id_to_user_file:
                doc_order.append(file_id)

    logger.info(f"ORDERING: Found {len(doc_order)} files from search results")

    # Add any files that weren't in search results at the end
    missing_files = [
        f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
    ]

    missing_files.extend(doc_order)
    doc_order = missing_files

    logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")

    # Reorder user files based on search results
    ordered_user_files = [
        file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
    ]

    yield UserKnowledgeFilePacket(
        user_files=[
            FileDescriptor(
                id=str(file.file_id),
                type=ChatFileType.USER_KNOWLEDGE,
            )
            for file in ordered_user_files
        ]
        force_use=False,
        tool_name=(
            SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
        ),
        args=args,
        override_kwargs=None,
    )
@@ -488,8 +419,6 @@ def _process_tool_response(
    retrieval_options: RetrievalDetails | None,
    user_file_files: list[UserFile] | None,
    user_files: list[InMemoryChatFile] | None,
    file_id_to_user_file: dict[str, InMemoryChatFile],
    search_for_ordering_only: bool,
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
    level, level_question_num = (
        (packet.level, packet.level_question_num)

@@ -501,21 +430,8 @@ def _process_tool_response(
    assert level_question_num is not None
    info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]

    # Skip LLM relevance processing entirely for ordering-only mode
    if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
        logger.info(
            "Fast path: Completely bypassing section relevance processing for ordering-only mode"
        )
        # Skip this packet entirely since it would trigger LLM processing
        return info_by_subq

    # TODO: don't need to dedupe here when we do it in agent flow
    if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
        if search_for_ordering_only:
            logger.info(
                "Fast path: Skipping document deduplication for ordering-only mode"
            )

        (
            info.qa_docs_response,
            info.reference_db_search_docs,

@@ -525,34 +441,15 @@ def _process_tool_response(
            db_session=db_session,
            selected_search_docs=selected_db_search_docs,
            # Deduping happens at the last step to avoid harming quality by dropping content early on
            # Skip deduping completely for ordering-only mode to save time
            dedupe_docs=bool(
                not search_for_ordering_only
                and retrieval_options
                and retrieval_options.dedupe_docs
            ),
            user_files=user_file_files if search_for_ordering_only else [],
            loaded_user_files=(user_files if search_for_ordering_only else []),
            dedupe_docs=bool(retrieval_options and retrieval_options.dedupe_docs),
            user_files=[],
            loaded_user_files=[],
        )

        # If we're using search just for ordering user files
        if search_for_ordering_only and user_files:
            yield from _get_user_knowledge_files(
                info=info,
                user_files=user_files,
                file_id_to_user_file=file_id_to_user_file,
            )

        yield info.qa_docs_response
    elif packet.id == SECTION_RELEVANCE_LIST_ID:
        relevance_sections = packet.response

        if search_for_ordering_only:
            logger.info(
                "Performance: Skipping relevance filtering for ordering-only mode"
            )
            return info_by_subq

        if info.reference_db_search_docs is None:
            logger.warning("No reference docs found for relevance filtering")
            return info_by_subq
@@ -665,8 +562,6 @@ def stream_chat_message_objects(

    try:
        # Move these variables inside the try block
        file_id_to_user_file = {}

        user_id = user.id if user is not None else None

        chat_session = get_chat_session_by_id(

@@ -840,60 +735,23 @@ def stream_chat_message_objects(
        for folder in persona.user_folders:
            user_folder_ids.append(folder.id)

        # Initialize flag for user file search
        use_search_for_user_files = False

        user_files: list[InMemoryChatFile] | None = None
        search_for_ordering_only = False
        user_file_files: list[UserFile] | None = None
        if user_file_ids or user_folder_ids:
            # Load user files
            user_files = load_in_memory_chat_files(
                user_file_ids or [],
                user_folder_ids or [],
                db_session,
            )
            user_file_files = get_user_files(
                user_file_ids or [],
                user_folder_ids or [],
                db_session,
            )
            # Store mapping of file_id to file for later reordering
            if user_files:
                file_id_to_user_file = {file.file_id: file for file in user_files}

            # Calculate token count for the files
            from onyx.db.user_documents import calculate_user_files_token_count
            from onyx.chat.prompt_builder.citations_prompt import (
                compute_max_document_tokens_for_persona,
            )

            total_tokens = calculate_user_files_token_count(
                user_file_ids or [],
                user_folder_ids or [],
                db_session,
            )

            # Calculate available tokens for documents based on prompt, user input, etc.
            available_tokens = compute_max_document_tokens_for_persona(
                db_session=db_session,
                persona=persona,
                actual_user_input=message_text,  # Use the actual user message
            )

            logger.debug(
                f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
            )

            # ALWAYS use search for user files, but track if we need it for context or just ordering
            use_search_for_user_files = True
            # If files are small enough for context, we'll just use search for ordering
            search_for_ordering_only = total_tokens <= available_tokens

            if search_for_ordering_only:
                # Add original user files to context since they fit
                if user_files:
                    latest_query_files.extend(user_files)
        # Load in user files into memory and create search tool override kwargs if needed
        # if we have enough tokens and no folders, we don't need to use search
        # we can just pass them into the prompt directly
        (
            in_memory_user_files,
            user_file_models,
            search_tool_override_kwargs_for_user_files,
        ) = parse_user_files(
            user_file_ids=user_file_ids,
            user_folder_ids=user_folder_ids,
            db_session=db_session,
            persona=persona,
            actual_user_input=message_text,
            user_id=user_id,
        )
        if not search_tool_override_kwargs_for_user_files:
            latest_query_files.extend(in_memory_user_files)

        if user_message:
            attach_files_to_chat_message(
@@ -1052,10 +910,13 @@ def stream_chat_message_objects(
            prompt_config=prompt_config,
            db_session=db_session,
            user=user,
-           user_knowledge_present=bool(user_files or user_folder_ids),
            llm=llm,
            fast_llm=fast_llm,
-           use_file_search=new_msg_req.force_user_file_search,
+           run_search_setting=(
+               retrieval_options.run_search
+               if retrieval_options
+               else OptionalSearchSetting.AUTO
+           ),
            search_tool_config=SearchToolConfig(
                answer_style_config=answer_style_config,
                document_pruning_config=document_pruning_config,
@@ -1086,128 +947,23 @@ def stream_chat_message_objects(
|
||||
tools.extend(tool_list)
|
||||
|
||||
force_use_tool = _get_force_search_settings(
|
||||
new_msg_req, tools, user_file_ids, user_folder_ids
|
||||
new_msg_req, tools, search_tool_override_kwargs_for_user_files
|
||||
)
|
||||
|
||||
# Set force_use if user files exceed token limit
|
||||
if use_search_for_user_files:
|
||||
try:
|
||||
# Check if search tool is available in the tools list
|
||||
search_tool_available = any(
|
||||
isinstance(tool, SearchTool) for tool in tools
|
||||
)
|
||||
|
||||
# If no search tool is available, add one
|
||||
if not search_tool_available:
|
||||
logger.info("No search tool available, creating one for user files")
|
||||
# Create a basic search tool config
|
||||
search_tool_config = SearchToolConfig(
|
||||
answer_style_config=answer_style_config,
|
||||
document_pruning_config=document_pruning_config,
|
||||
retrieval_options=retrieval_options or RetrievalDetails(),
|
||||
)
|
||||
|
||||
# Create and add the search tool
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
bypass_acl=bypass_acl,
|
||||
)
|
||||
|
||||
# Add the search tool to the tools list
|
||||
tools.append(search_tool)
|
||||
|
||||
logger.info(
|
||||
"Added search tool for user files that exceed token limit"
|
||||
)
|
||||
|
||||
# Now set force_use_tool.force_use to True
|
||||
force_use_tool.force_use = True
|
||||
force_use_tool.tool_name = SearchTool._NAME
|
||||
|
||||
# Set query argument if not already set
|
||||
if not force_use_tool.args:
|
||||
force_use_tool.args = {"query": final_msg.message}
|
||||
|
||||
# Pass the user file IDs to the search tool
if user_file_ids or user_folder_ids:
# Create a BaseFilters object with user_file_ids
if not retrieval_options:
retrieval_options = RetrievalDetails()
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()

# Set user file and folder IDs in the filters
retrieval_options.filters.user_file_ids = user_file_ids
retrieval_options.filters.user_folder_ids = user_folder_ids

# Create override kwargs for the search tool

override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
)

# Set the override kwargs in the force_use_tool
force_use_tool.override_kwargs = override_kwargs

if search_for_ordering_only:
logger.info(
"Fast path: Configured search tool with optimized settings for ordering-only"
)
logger.info(
"Fast path: Skipping reranking and query analysis for ordering-only mode"
)
logger.info(
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
else:
logger.info(
"Configured search tool to use ",
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
)
except Exception as e:
logger.exception(
f"Error configuring search tool for user files: {str(e)}"
)
use_search_for_user_files = False

# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not use_search_for_user_files and user_files:
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
id=str(file.file_id), type=file.file_type, name=file.filename
)
for file in user_files
for file in in_memory_user_files
]
)

if search_for_ordering_only:
logger.info(
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
)

prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
@@ -1265,10 +1021,8 @@ def stream_chat_message_objects(
selected_db_search_docs=selected_db_search_docs,
info_by_subq=info_by_subq,
retrieval_options=retrieval_options,
user_file_files=user_file_files,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
search_for_ordering_only=search_for_ordering_only,
user_file_files=user_file_models,
user_files=in_memory_user_files,
)

elif isinstance(packet, StreamStopInfo):
@@ -9,12 +9,12 @@ from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
@@ -120,7 +120,8 @@ def build_citations_system_message(

def build_citations_user_message(
message: HumanMessage,
user_query: str,
files: list[InMemoryChatFile],
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
@@ -135,7 +136,6 @@ def build_citations_user_message(
history_block = (
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
)
query, img_urls = message_to_prompt_and_imgs(message)

if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@@ -146,7 +146,7 @@ def build_citations_user_message(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=user_query,
history_block=history_block,
)
else:
@@ -154,16 +154,17 @@ def build_citations_user_message(
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
context_type=context_type,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=user_query,
history_block=history_block,
)

user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=(
build_content_with_imgs(user_prompt, img_urls=img_urls)
if img_urls
else user_prompt
build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
)
)
backend/onyx/chat/user_files/parse_user_files.py (new file, 106 lines)
@@ -0,0 +1,106 @@
from uuid import UUID

from sqlalchemy.orm import Session

from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files_as_user
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.utils.logger import setup_logger

logger = setup_logger()

def parse_user_files(
user_file_ids: list[int],
user_folder_ids: list[int],
db_session: Session,
persona: Persona,
actual_user_input: str,
# should only be None if auth is disabled
user_id: UUID | None,
) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]:
"""
Parse user files and folders into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs or folders are present.

Args:
user_file_ids: List of user file IDs to load
user_folder_ids: List of user folder IDs to load
db_session: Database session
persona: Persona to calculate available tokens
actual_user_input: User's input message for token calculation
user_id: User ID to validate file ownership

Returns:
Tuple of (
loaded user files,
user file models,
search tool override kwargs if token
overflow or folders present
)
"""
# Return empty results if no files or folders specified
if not user_file_ids and not user_folder_ids:
return [], [], None

# Load user files from the database into memory
user_files = load_in_memory_chat_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)

user_file_models = get_user_files_as_user(
user_file_ids or [],
user_folder_ids or [],
user_id,
db_session,
)

# Calculate token count for the files, need to import here to avoid circular import
# TODO: fix this
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)

total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
db_session,
)

# Calculate available tokens for documents based on prompt, user input, etc.
available_tokens = compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
actual_user_input=actual_user_input,
)

logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)

have_enough_tokens = total_tokens <= available_tokens

# If we have enough tokens and no folders, we don't need search
# we can just pass them into the prompt directly
if have_enough_tokens and not user_folder_ids:
# No search tool override needed - files can be passed directly
return user_files, user_file_models, None

# Token overflow or folders present - need to use search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=have_enough_tokens,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=have_enough_tokens,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)

return user_files, user_file_models, override_kwargs
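A minimal sketch of how a caller might consume this helper, based on the call site shown earlier in this diff. The surrounding names (latest_query_files, force_use_tool) are borrowed from that context; the exact wiring in the chat pipeline may differ.

# Sketch only: assumes the session, persona, and ID variables from the call site above are in scope.
in_memory_files, user_file_models, override_kwargs = parse_user_files(
    user_file_ids=user_file_ids,
    user_folder_ids=user_folder_ids,
    db_session=db_session,
    persona=persona,
    actual_user_input=message_text,
    user_id=user_id,
)

if override_kwargs is None:
    # Everything fits in the prompt: attach the files directly.
    latest_query_files.extend(in_memory_files)
else:
    # Token overflow or folders present: route retrieval through the search tool.
    force_use_tool.override_kwargs = override_kwargs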
@@ -170,169 +170,169 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
) # 2000

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)

AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 45 # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)

AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 8 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 8 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 10 # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 9 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 40 # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 20 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 60 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)

AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 12 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 12 # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 6 # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)

AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 8 # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)

AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 12 # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)

AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)

AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 12 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
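Every constant in the block above follows the same read-from-environment pattern, with only the default value bumped. A hedged illustration of how one of these settings resolves (the constant name is taken from the block above; the empty-string fallthrough is standard Python truthiness, not something stated in the diff):

import os

AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 45  # new default, in seconds
# An unset or empty env var falls through to the default because "" is falsy.
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
    os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
    or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)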
@@ -21,6 +21,9 @@ from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import process_attachment
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
is_atlassian_date_error,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -76,10 +79,6 @@ ONE_DAY = ONE_HOUR * 24
MAX_CACHED_IDS = 100

def _should_propagate_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)

class ConfluenceCheckpoint(ConnectorCheckpoint):

next_page_url: str | None
@@ -367,7 +366,7 @@ class ConfluenceConnector(
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
if _should_propagate_error(e):
if is_atlassian_date_error(e): # propagate error to be caught and retried
raise
return ConnectorFailure(
failed_document=DocumentFailure(
@@ -446,7 +445,9 @@ class ConfluenceConnector(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if _should_propagate_error(e):
if is_atlassian_date_error(
e
): # propagate error to be caught and retried
raise
return ConnectorFailure(
failed_document=DocumentFailure(
@@ -536,7 +537,7 @@ class ConfluenceConnector(
try:
return self._fetch_document_batches(checkpoint, start, end)
except Exception as e:
if _should_propagate_error(e) and start is not None:
if is_atlassian_date_error(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"

@@ -86,3 +86,7 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"

def is_atlassian_date_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
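Both the Confluence and Jira hunks in this diff consume the shared helper the same way: catch broadly, test the message, and either re-raise or retry with an adjusted window. A hedged sketch of that pattern (the fetch callable here is illustrative, not a real connector method):

from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
    is_atlassian_date_error,
)

def fetch_with_date_error_propagation(fetch):
    # `fetch` is any zero-argument callable standing in for a connector retrieval call.
    try:
        return fetch()
    except Exception as e:
        if is_atlassian_date_error(e):
            # Propagate so the caller can retry (e.g. with a widened time range).
            raise
        return None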
@@ -14,6 +14,7 @@ from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
@@ -219,6 +220,18 @@ def _get_batch_rate_limited(
)

def _get_userinfo(user: NamedUser) -> dict[str, str]:
return {
k: v
for k, v in {
"login": user.login,
"name": user.name,
"email": user.email,
}.items()
if v is not None
}

def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -226,7 +239,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
semantic_identifier=f"{pull_request.number}: {pull_request.title}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
@@ -236,8 +249,49 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
else None
),
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "PullRequest",
"id": pull_request.number,
"merged": pull_request.merged,
"state": pull_request.state,
"user": _get_userinfo(pull_request.user) if pull_request.user else None,
"assignees": [
_get_userinfo(assignee) for assignee in pull_request.assignees
],
"repo": (
pull_request.base.repo.full_name if pull_request.base else None
),
"num_commits": str(pull_request.commits),
"num_files_changed": str(pull_request.changed_files),
"labels": [label.name for label in pull_request.labels],
"created_at": (
pull_request.created_at.replace(tzinfo=timezone.utc)
if pull_request.created_at
else None
),
"updated_at": (
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
"closed_at": (
pull_request.closed_at.replace(tzinfo=timezone.utc)
if pull_request.closed_at
else None
),
"merged_at": (
pull_request.merged_at.replace(tzinfo=timezone.utc)
if pull_request.merged_at
else None
),
"merged_by": (
_get_userinfo(pull_request.merged_by)
if pull_request.merged_by
else None
),
}.items()
if v is not None
},
)

@@ -252,11 +306,39 @@ def _convert_issue_to_document(issue: Issue) -> Document:
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
semantic_identifier=f"{issue.number}: {issue.title}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
metadata={
"state": issue.state,
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "Issue",
"id": issue.number,
"state": issue.state,
"user": _get_userinfo(issue.user) if issue.user else None,
"assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
"repo": issue.repository.full_name if issue.repository else None,
"labels": [label.name for label in issue.labels],
"created_at": (
issue.created_at.replace(tzinfo=timezone.utc)
if issue.created_at
else None
),
"updated_at": (
issue.updated_at.replace(tzinfo=timezone.utc)
if issue.updated_at
else None
),
"closed_at": (
issue.closed_at.replace(tzinfo=timezone.utc)
if issue.closed_at
else None
),
"closed_by": (
_get_userinfo(issue.closed_by) if issue.closed_by else None
),
}.items()
if v is not None
},
)
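Both converters above normalize their metadata the same way: build a dict of raw values, drop the None entries, and stringify what remains (lists become lists of strings). A standalone, hedged illustration of that comprehension with made-up sample values:

raw = {
    "object_type": "Issue",
    "id": 42,
    "labels": ["bug", "p1"],
    "closed_by": None,  # dropped by the filter below
}

metadata = {
    k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
    for k, v in raw.items()
    if v is not None
}
# Result: {'object_type': 'Issue', 'id': '42', 'labels': ['bug', 'p1']}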
@@ -27,6 +27,7 @@ from onyx.connectors.google_drive.doc_conversion import build_slim_document
from onyx.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from onyx.connectors.google_drive.doc_conversion import onyx_document_id_from_drive_file
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
@@ -220,6 +221,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self._primary_admin_email: str | None = None

self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds_dict: dict[str, Any] | None = None

# ids of folders and shared drives that have been traversed
self._retrieved_folder_and_drive_ids: set[str] = set()
@@ -273,6 +275,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
source=DocumentSource.GOOGLE_DRIVE,
)

self._creds_dict = new_creds_dict

return new_creds_dict

def _update_traversed_parent_ids(self, folder_id: str) -> None:
@@ -919,8 +923,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
).timestamp(),
current_folder_or_drive_id=file.parent_id,
)
if file.drive_file["id"] not in checkpoint.all_retrieved_file_ids:
checkpoint.all_retrieved_file_ids.add(file.drive_file["id"])
document_id = onyx_document_id_from_drive_file(file.drive_file)
if document_id not in checkpoint.all_retrieved_file_ids:
checkpoint.all_retrieved_file_ids.add(document_id)
yield file

def _manage_oauth_retrieval(
@@ -1135,6 +1140,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids

logger.info(
f"num drive files retrieved: {len(checkpoint.all_retrieved_file_ids)}"
)
if checkpoint.completion_stage == DriveRetrievalStage.DONE:
checkpoint.has_more = False
return checkpoint
@@ -1183,6 +1192,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
end=end,
callback=callback,
)
logger.info("Drive perm sync: Slim doc retrieval complete")

except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):

@@ -62,6 +62,10 @@ GOOGLE_MIME_TYPES = {
}

def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
return file[WEB_VIEW_LINK_KEY]

def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
) -> str:
@@ -380,7 +384,6 @@ def _convert_drive_item_to_document(
"""
Main entry point for converting a Google Drive file => Document object.
"""
doc_id = file.get(WEB_VIEW_LINK_KEY, "")
sections: list[TextSection | ImageSection] = []
# Only construct these services when needed
drive_service = lazy_eval(
@@ -389,6 +392,7 @@ def _convert_drive_item_to_document(
docs_service = lazy_eval(
lambda: get_google_docs_service(creds, user_email=retriever_email)
)
doc_id = "unknown"

try:
# skip shortcuts or folders
@@ -441,7 +445,7 @@ def _convert_drive_item_to_document(
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None

doc_id = file[WEB_VIEW_LINK_KEY]
doc_id = onyx_document_id_from_drive_file(file)

# Create the document
return Document(
@@ -488,7 +492,7 @@ def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file[WEB_VIEW_LINK_KEY],
id=onyx_document_id_from_drive_file(file),
perm_sync_data={
"doc_id": file.get("id"),
"drive_id": file.get("driveId"),
@@ -1,8 +1,16 @@
from collections.abc import Callable
from typing import Any

from google.auth.exceptions import RefreshError  # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials  # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials  # type: ignore
from googleapiclient.discovery import build  # type: ignore
from googleapiclient.discovery import Resource  # type: ignore

from onyx.utils.logger import setup_logger

logger = setup_logger()

class GoogleDriveService(Resource):
pass
@@ -20,6 +28,56 @@ class GmailService(Resource):
pass

class RefreshableDriveObject:
"""
Running Google drive service retrieval functions
involves accessing methods of the service object (ie. files().list())
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval function until execute() is called.
"""

def __init__(
self,
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
creds: ServiceAccountCredentials | OAuthCredentials,
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
):
self.call_stack = call_stack
self.creds = creds
self.creds_getter = creds_getter

def __getattr__(self, name: str) -> Any:
if name == "execute":
return self.make_refreshable_execute()
return RefreshableDriveObject(
lambda creds: getattr(self.call_stack(creds), name),
self.creds,
self.creds_getter,
)

def __call__(self, *args: Any, **kwargs: Any) -> Any:
return RefreshableDriveObject(
lambda creds: self.call_stack(creds)(*args, **kwargs),
self.creds,
self.creds_getter,
)

def make_refreshable_execute(self) -> Callable:
def execute(*args: Any, **kwargs: Any) -> Any:
try:
return self.call_stack(self.creds).execute(*args, **kwargs)
except RefreshError as e:
logger.warning(
f"RefreshError, going to attempt a creds refresh and retry: {e}"
)
# Refresh the access token
self.creds = self.creds_getter()
return self.call_stack(self.creds).execute(*args, **kwargs)

return execute
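A hedged sketch of how the wrapper above appears intended to be used: attribute lookups and calls are recorded lazily, and only execute() actually hits the API, refreshing credentials once on RefreshError. The get_drive_service and refresh_creds names below are placeholders, not functions from this diff:

# Placeholder helpers: assume they return a Drive service object and fresh credentials.
wrapped = RefreshableDriveObject(
    call_stack=lambda creds: get_drive_service(creds),
    creds=initial_creds,
    creds_getter=refresh_creds,
)

# Builds up `service.files().list(pageSize=10)` lazily; the API request (and the
# refresh-and-retry on RefreshError) only happens inside execute().
response = wrapped.files().list(pageSize=10).execute()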
def _get_google_service(
service_name: str,
service_version: str,

@@ -87,6 +87,9 @@ class BasicExpertInfo(BaseModel):

return "Unknown"

def get_email(self) -> str | None:
return self.email or None

def __eq__(self, other: Any) -> bool:
if not isinstance(other, BasicExpertInfo):
return False

@@ -12,6 +12,9 @@ from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
is_atlassian_date_error,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -40,6 +43,8 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

ONE_HOUR = 3600

JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
@@ -55,6 +60,14 @@ _FIELD_KEY = "key"
_FIELD_CREATED = "created"
_FIELD_DUEDATE = "duedate"
_FIELD_ISSUETYPE = "issuetype"
_FIELD_PARENT = "parent"
_FIELD_ASSIGNEE_EMAIL = "assignee_email"
_FIELD_REPORTER_EMAIL = "reporter_email"
_FIELD_PROJECT = "project"
_FIELD_PROJECT_NAME = "project_name"
_FIELD_UPDATED = "updated"
_FIELD_RESOLUTION_DATE = "resolutiondate"
_FIELD_RESOLUTION_DATE_KEY = "resolution_date"

def _perform_jql_search(
@@ -126,6 +139,9 @@ def process_jira_issue(
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_REPORTER_EMAIL] = email

except Exception:
# Author should exist but if not, doesn't matter
pass
@@ -135,6 +151,8 @@ def process_jira_issue(
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
except Exception:
# Author should exist but if not, doesn't matter
pass
@@ -149,10 +167,32 @@ def process_jira_issue(
metadata_dict[_FIELD_LABELS] = labels
if created := best_effort_get_field_from_issue(issue, _FIELD_CREATED):
metadata_dict[_FIELD_CREATED] = created
if updated := best_effort_get_field_from_issue(issue, _FIELD_UPDATED):
metadata_dict[_FIELD_UPDATED] = updated
if duedate := best_effort_get_field_from_issue(issue, _FIELD_DUEDATE):
metadata_dict[_FIELD_DUEDATE] = duedate
if issuetype := best_effort_get_field_from_issue(issue, _FIELD_ISSUETYPE):
metadata_dict[_FIELD_ISSUETYPE] = issuetype.name
if resolutiondate := best_effort_get_field_from_issue(
issue, _FIELD_RESOLUTION_DATE
):
metadata_dict[_FIELD_RESOLUTION_DATE_KEY] = resolutiondate

try:
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
if parent:
metadata_dict[_FIELD_PARENT] = parent.key
except Exception:
# Parent should exist but if not, doesn't matter
pass
try:
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
if project:
metadata_dict[_FIELD_PROJECT_NAME] = project.name
metadata_dict[_FIELD_PROJECT] = project.key
except Exception:
# Project should exist.
logger.error(f"Project should exist but does not for {issue.key}")

return Document(
id=page_url,
@@ -240,7 +280,17 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
try:
return self._load_from_checkpoint(jql, checkpoint)
except Exception as e:
if is_atlassian_date_error(e):
jql = self._get_jql_query(start - ONE_HOUR, end)
return self._load_from_checkpoint(jql, checkpoint)
raise e

def _load_from_checkpoint(
self, jql: str, checkpoint: JiraConnectorCheckpoint
) -> CheckpointOutput[JiraConnectorCheckpoint]:
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
current_offset = starting_offset
@@ -9,8 +9,11 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from http.client import IncompleteRead
from http.client import RemoteDisconnected
from typing import Any
from typing import cast
from urllib.error import URLError

from pydantic import BaseModel
from redis import Redis
@@ -18,6 +21,9 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.http_retry.builtin_interval_calculators import (
FixedValueRetryIntervalCalculator,
)
from typing_extensions import override

from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -45,10 +51,10 @@ from onyx.connectors.models import EntityFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_paginated_slack_api_call
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_pool import get_redis_client
@@ -78,7 +84,7 @@ def _collect_paginated_channels(
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
@@ -135,14 +141,13 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")

for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -159,7 +164,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -317,8 +322,7 @@ def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
response = client.conversations_info(
channel=channel_id,
)
return cast(ChannelType, response["channel"])
@@ -335,8 +339,7 @@ def _get_messages(
# have to be in the channel in order to read messages
if not channel["is_member"]:
try:
make_slack_api_call_w_retries(
client.conversations_join,
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
@@ -349,8 +352,7 @@ def _get_messages(
raise
logger.info(f"Successfully joined '{channel['name']}'")

response = make_slack_api_call_w_retries(
client.conversations_history,
response = client.conversations_history(
channel=channel["id"],
oldest=oldest,
latest=latest,
@@ -379,6 +381,9 @@ def _message_to_doc(
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# NOTE: if thread_ts is present, there's a thread we need to process
# ... otherwise, we can skip it
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
@@ -527,6 +532,7 @@ class SlackConnector(
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
use_redis: bool = True,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
@@ -539,6 +545,7 @@ class SlackConnector(
self.user_cache: dict[str, BasicExpertInfo | None] = {}
self.credentials_provider: CredentialsProviderInterface | None = None
self.credential_prefix: str | None = None
self.use_redis: bool = use_redis
# self.delay_lock: str | None = None # the redis key for the shared lock
# self.delay_key: str | None = None # the redis key for the shared delay

@@ -563,10 +570,19 @@ class SlackConnector(

# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
connection_error_retry_handler = ConnectionErrorRetryHandler()
connection_error_retry_handler = ConnectionErrorRetryHandler(
max_retry_count=max_retry_count,
interval_calculator=FixedValueRetryIntervalCalculator(),
error_types=[
URLError,
ConnectionResetError,
RemoteDisconnected,
IncompleteRead,
],
)

onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
max_retry_count=max_retry_count,
delay_lock=delay_lock,
delay_key=delay_key,
r=r,
)
@@ -575,7 +591,13 @@ class SlackConnector(
onyx_rate_limit_error_retry_handler,
]

client = WebClient(token=token, retry_handlers=custom_retry_handlers)
client = OnyxSlackWebClient(
delay_lock=delay_lock,
delay_key=delay_key,
r=r,
token=token,
retry_handlers=custom_retry_handlers,
)
return client

@property
@@ -599,16 +621,32 @@ class SlackConnector(
if not tenant_id:
raise ValueError("tenant_id cannot be None!")

self.redis = get_redis_client(tenant_id=tenant_id)

self.credential_prefix = SlackConnector.make_credential_prefix(
credentials_provider.get_provider_key()
)

bot_token = credentials["slack_bot_token"]
self.client = SlackConnector.make_slack_web_client(
self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
)

if self.use_redis:
self.redis = get_redis_client(tenant_id=tenant_id)
self.credential_prefix = SlackConnector.make_credential_prefix(
credentials_provider.get_provider_key()
)

self.client = SlackConnector.make_slack_web_client(
self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
)
else:
connection_error_retry_handler = ConnectionErrorRetryHandler(
max_retry_count=self.MAX_RETRIES,
interval_calculator=FixedValueRetryIntervalCalculator(),
error_types=[
URLError,
ConnectionResetError,
RemoteDisconnected,
IncompleteRead,
],
)

self.client = WebClient(
token=bot_token, retry_handlers=[connection_error_retry_handler]
)

# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
@@ -651,6 +689,8 @@ class SlackConnector(
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
num_channels_remaining = 0

if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")

@@ -664,7 +704,9 @@ class SlackConnector(
raw_channels, self.channels, self.channel_regex_enabled
)
logger.info(
f"Channels: all={len(raw_channels)} post_filtering={len(filtered_channels)}"
f"Channels - initial checkpoint: "
f"all={len(raw_channels)} "
f"post_filtering={len(filtered_channels)}"
)

checkpoint.channel_ids = [c["id"] for c in filtered_channels]
@@ -677,6 +719,17 @@ class SlackConnector(
return checkpoint

final_channel_ids = checkpoint.channel_ids
for channel_id in final_channel_ids:
if channel_id not in checkpoint.channel_completion_map:
num_channels_remaining += 1

logger.info(
f"Channels - current status: "
f"processed={len(final_channel_ids) - num_channels_remaining} "
f"remaining={num_channels_remaining=} "
f"total={len(final_channel_ids)}"
)

channel = checkpoint.current_channel
if channel is None:
raise ValueError("current_channel key not set in checkpoint")
@@ -688,18 +741,32 @@ class SlackConnector(
oldest = str(start) if start else None
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
seen_thread_ts = set(checkpoint.seen_thread_ts)

logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)

try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)

logger.info(
f"Retrieved messages: "
f"{len(message_batch)=} "
f"{channel=} "
f"{oldest=} "
f"{latest=}"
)

new_latest = message_batch[-1]["ts"] if message_batch else latest

num_threads_start = len(seen_thread_ts)
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
# NOTE(rkuo): this seems to be assuming the slack sdk is thread safe.
# That's a very bold assumption! Likely not correct.

futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
@@ -736,7 +803,12 @@ class SlackConnector(
yield failure

num_threads_processed = len(seen_thread_ts) - num_threads_start
logger.info(f"Processed {num_threads_processed} threads.")
logger.info(
f"Message processing stats: "
f"batch_len={len(message_batch)} "
f"batch_yielded={num_threads_processed} "
f"total_threads_seen={len(seen_thread_ts)}"
)

checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
@@ -751,6 +823,7 @@ class SlackConnector(
),
None,
)

if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint.current_channel = new_channel
@@ -758,8 +831,6 @@ class SlackConnector(
checkpoint.current_channel = None

checkpoint.has_more = checkpoint.current_channel is not None
return checkpoint

except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
@@ -773,7 +844,8 @@ class SlackConnector(
failure_message=str(e),
exception=e,
)
return checkpoint

return checkpoint

def validate_connector_settings(self) -> None:
"""
@@ -1,11 +1,8 @@
|
||||
import math
|
||||
import random
|
||||
import time
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from slack_sdk.http_retry.handler import RetryHandler
|
||||
from slack_sdk.http_retry.request import HttpRequest
|
||||
from slack_sdk.http_retry.response import HttpResponse
|
||||
@@ -20,28 +17,23 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"""
|
||||
This class uses Redis to share a rate limit among multiple threads.
|
||||
|
||||
Threads that encounter a rate limit will observe the shared delay, increment the
|
||||
shared delay with the retry value, and use the new shared value as a wait interval.
|
||||
As currently implemented, this code is already surrounded by a lock in Redis
|
||||
via an override of _perform_urllib_http_request in OnyxSlackWebClient.
|
||||
|
||||
This has the effect of serializing calls when a rate limit is hit, which is what
|
||||
needs to happens if the server punishes us with additional limiting when we make
|
||||
a call too early. We believe this is what Slack is doing based on empirical
|
||||
observation, meaning we see indefinite hangs if we're too aggressive.
|
||||
This just sets the desired retry delay with TTL in redis. In conjunction with
|
||||
a custom subclass of the client, the value is read and obeyed prior to an API call
|
||||
and also serialized.
|
||||
|
||||
Another way to do this is just to do exponential backoff. Might be easier?
|
||||
|
||||
Adapted from slack's RateLimitErrorRetryHandler.
|
||||
"""
|
||||
|
||||
LOCK_TTL = 60 # used to serialize access to the retry TTL
|
||||
LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock
|
||||
|
||||
"""RetryHandler that does retries for rate limited errors."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
max_retry_count: int,
|
||||
delay_lock: str,
|
||||
delay_key: str,
|
||||
r: Redis,
|
||||
):
|
||||
@@ -51,7 +43,6 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
"""
|
||||
super().__init__(max_retry_count=max_retry_count)
|
||||
self._redis: Redis = r
|
||||
self._delay_lock = delay_lock
|
||||
self._delay_key = delay_key
|
||||
|
||||
def _can_retry(
|
||||
@@ -72,8 +63,18 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
response: Optional[HttpResponse] = None,
|
||||
error: Optional[Exception] = None,
|
||||
) -> None:
|
||||
"""It seems this function is responsible for the wait to retry ... aka we
|
||||
actually sleep in this function."""
|
||||
"""As initially designed by the SDK authors, this function is responsible for
|
||||
the wait to retry ... aka we actually sleep in this function.
|
||||
|
||||
This doesn't work well with multiple clients because every thread is unaware
|
||||
of the current retry value until it actually calls the endpoint.
|
||||
|
||||
We're combining this with an actual subclass of the slack web client so
|
||||
that the delay is used BEFORE calling an API endpoint. The subclassed client
|
||||
has already taken the lock in redis when this method is called.
|
||||
"""
|
||||
ttl_ms: int | None = None
|
||||
|
||||
retry_after_value: list[str] | None = None
|
||||
retry_after_header_name: Optional[str] = None
|
||||
duration_s: float = 1.0 # seconds
|
||||
@@ -112,48 +113,22 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
|
||||
retry_after_value[0]
|
||||
) # will raise ValueError if somehow we can't convert to int
|
||||
jitter = retry_after_value_int * 0.25 * random.random()
|
||||
duration_s = math.ceil(retry_after_value_int + jitter)
|
||||
duration_s = retry_after_value_int + jitter
|
||||
except ValueError:
|
||||
duration_s += random.random()
|
||||
|
||||
# lock and extend the ttl
|
||||
lock: RedisLock = self._redis.lock(
|
||||
self._delay_lock,
|
||||
timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
|
||||
thread_local=False,
|
||||
)
|
||||
|
||||
acquired = lock.acquire(
|
||||
blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
|
||||
)
|
||||
|
||||
ttl_ms: int | None = None
|
||||
|
||||
try:
|
||||
if acquired:
|
||||
# if we can get the lock, then read and extend the ttl
|
||||
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
|
||||
if ttl_ms < 0: # negative values are error status codes ... see docs
|
||||
ttl_ms = 0
|
||||
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
|
||||
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
|
||||
else:
|
||||
# if we can't get the lock, just go ahead.
|
||||
# TODO: if we know our actual parallelism, multiplying by that
|
||||
# would be a pretty good idea
|
||||
ttl_ms_new = int(duration_s * 1000.0)
|
||||
finally:
|
||||
if acquired:
|
||||
lock.release()
|
||||
# Read and extend the ttl
|
||||
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
|
||||
if ttl_ms < 0: # negative values are error status codes ... see docs
|
||||
ttl_ms = 0
|
||||
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
|
||||
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
|
||||
|
||||
logger.warning(
|
||||
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
|
||||
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt setting delay: "
|
||||
f"current_attempt={state.current_attempt} "
|
||||
f"retry-after={retry_after_value} "
|
||||
f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
|
||||
f"{ttl_ms_new=}"
|
||||
)
|
||||
|
||||
# TODO: would be good to take an event var and sleep in short increments to
|
||||
# allow for a clean exit / exception
|
||||
time.sleep(ttl_ms_new / 1000.0)
|
||||
|
||||
state.increment_current_attempt()
|
||||
|
||||
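# Illustrative sketch (not part of the diff): a minimal standalone version of the
# shared-delay pattern the handler above implements. The Retry-After value is folded
# into a Redis key's TTL so every worker observes the same pending wait. The key name
# and Redis connection details below are assumptions for illustration only.
import time

from redis import Redis

r = Redis(host="localhost", port=6379)
DELAY_KEY = "slack:retry_delay"  # hypothetical key name


def extend_shared_delay(retry_after_s: float) -> float:
    """Add retry_after_s onto whatever shared delay is already pending."""
    ttl_ms = int(r.pttl(DELAY_KEY))
    if ttl_ms < 0:  # -1 / -2 mean "no TTL" / "no key"
        ttl_ms = 0
    new_ttl_ms = ttl_ms + int(retry_after_s * 1000)
    if new_ttl_ms > 0:
        # the value is irrelevant; the TTL itself carries the delay
        r.set(DELAY_KEY, "1", px=new_ttl_ms)
    return new_ttl_ms / 1000.0


def wait_for_shared_delay() -> None:
    """Called before an API request: sleep out any delay other workers recorded."""
    pending_ms = int(r.pttl(DELAY_KEY))
    if pending_ms > 0:
        time.sleep(pending_ms / 1000.0)
# end of sketch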
116
backend/onyx/connectors/slack/onyx_slack_web_client.py
Normal file
@@ -0,0 +1,116 @@
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from urllib.request import Request
|
||||
|
||||
from redis import Redis
|
||||
from redis.lock import Lock as RedisLock
|
||||
from slack_sdk import WebClient
|
||||
|
||||
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_BLOCKING_TIMEOUT
|
||||
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT
|
||||
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TTL
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class OnyxSlackWebClient(WebClient):
|
||||
"""Use in combination with the Onyx Retry Handler.
|
||||
|
||||
This client wrapper enforces a proper retry delay through redis BEFORE the api call
|
||||
so that multiple clients can synchronize and rate limit properly.
|
||||
|
||||
The retry handler writes the correct delay value to redis so that it is can be used
|
||||
by this wrapper.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, delay_lock: str, delay_key: str, r: Redis, *args: Any, **kwargs: Any
|
||||
) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
self._delay_key = delay_key
|
||||
self._delay_lock = delay_lock
|
||||
self._redis: Redis = r
|
||||
self.num_requests: int = 0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def _perform_urllib_http_request(
|
||||
self, *, url: str, args: Dict[str, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""By locking around the base class method, we ensure that both the delay from
|
||||
Redis and parsing/writing of retry values to Redis are handled properly in
|
||||
one place"""
|
||||
# lock and extend the ttl
|
||||
lock: RedisLock = self._redis.lock(
|
||||
self._delay_lock,
|
||||
timeout=ONYX_SLACK_LOCK_TTL,
|
||||
)
|
||||
|
||||
# try to acquire the lock
|
||||
start = time.monotonic()
|
||||
while True:
|
||||
acquired = lock.acquire(blocking_timeout=ONYX_SLACK_LOCK_BLOCKING_TIMEOUT)
|
||||
if acquired:
|
||||
break
|
||||
|
||||
# if we couldn't acquire the lock but it exists, there's at least some activity
|
||||
# so keep trying...
|
||||
if self._redis.exists(self._delay_lock):
|
||||
continue
|
||||
|
||||
if time.monotonic() - start > ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT:
|
||||
raise RuntimeError(
|
||||
f"OnyxSlackWebClient._perform_urllib_http_request - "
|
||||
f"timed out waiting for lock: {ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT=}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = super()._perform_urllib_http_request(url=url, args=args)
|
||||
finally:
|
||||
if lock.owned():
|
||||
lock.release()
|
||||
else:
|
||||
logger.warning(
|
||||
"OnyxSlackWebClient._perform_urllib_http_request lock not owned on release"
|
||||
)
|
||||
|
||||
time.monotonic() - start
|
||||
# logger.info(
|
||||
# f"OnyxSlackWebClient._perform_urllib_http_request: Releasing lock: {elapsed=}"
|
||||
# )
|
||||
|
||||
return result
|
||||
|
||||
def _perform_urllib_http_request_internal(
|
||||
self,
|
||||
url: str,
|
||||
req: Request,
|
||||
) -> Dict[str, Any]:
|
||||
"""Overrides the internal method which is mostly the direct call to
|
||||
urllib/urlopen ... so this is a good place to perform our delay."""
|
||||
|
||||
# read and execute the delay
|
||||
delay_ms = cast(int, self._redis.pttl(self._delay_key))
|
||||
if delay_ms < 0: # negative values are error status codes ... see docs
|
||||
delay_ms = 0
|
||||
|
||||
if delay_ms > 0:
|
||||
logger.warning(
|
||||
f"OnyxSlackWebClient._perform_urllib_http_request_internal delay: "
|
||||
f"{delay_ms=} "
|
||||
f"{self.num_requests=}"
|
||||
)
|
||||
|
||||
time.sleep(delay_ms / 1000.0)
|
||||
|
||||
result = super()._perform_urllib_http_request_internal(url, req)
|
||||
|
||||
with self._lock:
|
||||
self.num_requests += 1
|
||||
|
||||
# the delay key should have naturally expired by this point
|
||||
return result
|
||||
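# Illustrative sketch (not part of the diff): how the two pieces above are presumably
# wired together. The custom client sleeps out the shared delay before each request,
# while the retry handler writes new delays back to Redis. Key names, the Redis
# connection, the bot token, and the retry handler's module path are placeholders.
from redis import Redis

from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient
from onyx.connectors.slack.onyx_redis_slack_retry_handler import (  # module path assumed
    OnyxRedisSlackRetryHandler,
)

r = Redis(host="localhost", port=6379)
DELAY_LOCK = "slack:retry_delay_lock"  # hypothetical key names
DELAY_KEY = "slack:retry_delay"

retry_handler = OnyxRedisSlackRetryHandler(
    max_retry_count=7,
    delay_lock=DELAY_LOCK,
    delay_key=DELAY_KEY,
    r=r,
)

client = OnyxSlackWebClient(
    delay_lock=DELAY_LOCK,
    delay_key=DELAY_KEY,
    r=r,
    token="xoxb-...",  # placeholder bot token
    retry_handlers=[retry_handler],  # slack_sdk's WebClient accepts a list of retry handlers
)

# Any API call now serializes around the shared Redis delay/lock.
channels = client.conversations_list(limit=100)
# end of sketch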
@@ -21,6 +21,11 @@ basic_retry_wrapper = retry_builder(tries=7)
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900

# used to serialize access to the retry TTL
ONYX_SLACK_LOCK_TTL = 1800 # how long the lock is allowed to idle before it expires
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock per wait attempt
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600 # how long to wait for the lock in total


@lru_cache()
|
||||
def get_base_url(token: str) -> str:
|
||||
@@ -44,6 +49,18 @@ def get_message_link(
|
||||
return link
|
||||
|
||||
|
||||
def make_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return call(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(call)(**kwargs)
|
||||
|
||||
|
||||
def _make_slack_api_call_paginated(
|
||||
call: Callable[..., SlackResponse],
|
||||
) -> Callable[..., Generator[dict[str, Any], None, None]]:
|
||||
@@ -119,17 +136,18 @@ def _make_slack_api_call_paginated(
|
||||
|
||||
# return rate_limited_call
|
||||
|
||||
|
||||
def make_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> SlackResponse:
|
||||
return basic_retry_wrapper(call)(**kwargs)
|
||||
# temporarily disabling due to using a different retry approach
|
||||
# might be permanent if everything works out
|
||||
# def make_slack_api_call_w_retries(
|
||||
# call: Callable[..., SlackResponse], **kwargs: Any
|
||||
# ) -> SlackResponse:
|
||||
# return basic_retry_wrapper(call)(**kwargs)
|
||||
|
||||
|
||||
def make_paginated_slack_api_call_w_retries(
|
||||
call: Callable[..., SlackResponse], **kwargs: Any
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
|
||||
# def make_paginated_slack_api_call_w_retries(
|
||||
# call: Callable[..., SlackResponse], **kwargs: Any
|
||||
# ) -> Generator[dict[str, Any], None, None]:
|
||||
# return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
|
||||
|
||||
|
||||
def expert_info_from_slack_id(
|
||||
|
||||
@@ -111,11 +111,14 @@ class BaseFilters(BaseModel):
|
||||
document_set: list[str] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
tags: list[Tag] | None = None
|
||||
|
||||
|
||||
class UserFileFilters(BaseModel):
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
|
||||
|
||||
class IndexFilters(BaseFilters):
|
||||
class IndexFilters(BaseFilters, UserFileFilters):
|
||||
access_control_list: list[str] | None
|
||||
tenant_id: str | None = None
|
||||
|
||||
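# Illustrative sketch (not part of the diff): with UserFileFilters mixed into
# IndexFilters above, user-file scoping rides alongside the other index filters.
# Field values are examples; fields of BaseFilters not shown in this hunk are
# assumed to keep their defaults, and the module path is assumed.
from onyx.context.search.models import IndexFilters  # module path assumed

filters = IndexFilters(
    user_file_ids=[42],
    user_folder_ids=[7],
    access_control_list=None,  # required field with no default in the model above
    tenant_id=None,
)
# end of sketch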
@@ -150,6 +153,7 @@ class SearchRequest(ChunkContext):
|
||||
search_type: SearchType = SearchType.SEMANTIC
|
||||
|
||||
human_selected_filters: BaseFilters | None = None
|
||||
user_file_filters: UserFileFilters | None = None
|
||||
enable_auto_detect_filters: bool | None = None
|
||||
persona: Persona | None = None
|
||||
|
||||
|
||||
@@ -165,47 +165,6 @@ class SearchPipeline:
|
||||
|
||||
return cast(list[InferenceChunk], self._retrieved_chunks)
|
||||
|
||||
def get_ordering_only_chunks(
|
||||
self,
|
||||
query: str,
|
||||
user_file_ids: list[int] | None = None,
|
||||
user_folder_ids: list[int] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""Optimized method that only retrieves chunks for ordering purposes.
|
||||
Skips all extra processing and uses minimal configuration to speed up retrieval.
|
||||
"""
|
||||
logger.info("Fast path: Using optimized chunk retrieval for ordering-only mode")
|
||||
|
||||
# Create minimal filters with just user file/folder IDs
|
||||
filters = IndexFilters(
|
||||
user_file_ids=user_file_ids or [],
|
||||
user_folder_ids=user_folder_ids or [],
|
||||
access_control_list=None,
|
||||
)
|
||||
|
||||
# Use a simplified query that skips all unnecessary processing
|
||||
minimal_query = SearchQuery(
|
||||
query=query,
|
||||
search_type=SearchType.SEMANTIC,
|
||||
filters=filters,
|
||||
# Set minimal options needed for retrieval
|
||||
evaluation_type=LLMEvaluationType.SKIP,
|
||||
recency_bias_multiplier=1.0,
|
||||
chunks_above=0, # No need for surrounding context
|
||||
chunks_below=0, # No need for surrounding context
|
||||
processed_keywords=[], # Empty list instead of None
|
||||
rerank_settings=None,
|
||||
hybrid_alpha=0.0,
|
||||
max_llm_filter_sections=0,
|
||||
)
|
||||
|
||||
# Retrieve chunks using the minimal configuration
|
||||
return retrieve_chunks(
|
||||
query=minimal_query,
|
||||
document_index=self.document_index,
|
||||
db_session=self.db_session,
|
||||
)
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def _get_sections(self) -> list[InferenceSection]:
|
||||
"""Returns an expanded section from each of the chunks.
|
||||
@@ -458,10 +417,6 @@ class SearchPipeline:
|
||||
self.search_query.evaluation_type == LLMEvaluationType.SKIP
|
||||
or DISABLE_LLM_DOC_RELEVANCE
|
||||
):
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info(
|
||||
"Fast path: Skipping section relevance evaluation for ordering-only mode"
|
||||
)
|
||||
return None
|
||||
|
||||
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:
|
||||
|
||||
@@ -372,7 +372,6 @@ def filter_sections(
|
||||
# Log evaluation type to help with debugging
|
||||
logger.info(f"filter_sections called with evaluation_type={query.evaluation_type}")
|
||||
|
||||
# Fast path: immediately return empty list for SKIP evaluation type (ordering-only mode)
|
||||
if query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
return []
|
||||
|
||||
@@ -408,16 +407,6 @@ def search_postprocessing(
|
||||
llm: LLM,
|
||||
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
|
||||
) -> Iterator[list[InferenceSection] | list[SectionRelevancePiece]]:
|
||||
# Fast path for ordering-only: detect it by checking if evaluation_type is SKIP
|
||||
if search_query.evaluation_type == LLMEvaluationType.SKIP:
|
||||
logger.info(
|
||||
"Fast path: Detected ordering-only mode, bypassing all post-processing"
|
||||
)
|
||||
# Immediately yield the sections without any processing and an empty relevance list
|
||||
yield retrieved_sections
|
||||
yield cast(list[SectionRelevancePiece], [])
|
||||
return
|
||||
|
||||
post_processing_tasks: list[FunctionCall] = []
|
||||
|
||||
if not retrieved_sections:
|
||||
|
||||
@@ -164,14 +164,15 @@ def retrieval_preprocessing(
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
user_file_ids = preset_filters.user_file_ids or []
|
||||
user_folder_ids = preset_filters.user_folder_ids or []
|
||||
user_file_filters = search_request.user_file_filters
|
||||
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
|
||||
user_folder_ids = (
|
||||
(user_file_filters.user_folder_ids or []) if user_file_filters else []
|
||||
)
|
||||
if persona and persona.user_files:
|
||||
user_file_ids = user_file_ids + [
|
||||
file.id
|
||||
for file in persona.user_files
|
||||
if file.id not in (preset_filters.user_file_ids or [])
|
||||
]
|
||||
user_file_ids = list(
|
||||
set(user_file_ids) | set([file.id for file in persona.user_files])
|
||||
)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
|
||||
@@ -62,7 +62,7 @@ def download_nltk_data() -> None:
|
||||
resources = {
|
||||
"stopwords": "corpora/stopwords",
|
||||
# "wordnet": "corpora/wordnet", # Not in use
|
||||
"punkt": "tokenizers/punkt",
|
||||
"punkt_tab": "tokenizers/punkt_tab",
|
||||
}
|
||||
|
||||
for resource_name, resource_path in resources.items():
|
||||
|
||||
@@ -234,6 +234,10 @@ def delete_messages_and_files_from_chat_session(
|
||||
logger.info(f"Deleting file with name: {lobj_name}")
|
||||
delete_lobj_by_name(lobj_name, db_session)
|
||||
|
||||
# Delete ChatMessage records - CASCADE constraints will automatically handle:
|
||||
# - AgentSubQuery records (via AgentSubQuestion)
|
||||
# - AgentSubQuestion records
|
||||
# - ChatMessage__StandardAnswer relationship records
|
||||
db_session.execute(
|
||||
delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
|
||||
)
|
||||
|
||||
@@ -423,12 +423,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
|
||||
dbapi_connection = connection.connection
|
||||
cursor = dbapi_connection.cursor()
|
||||
try:
|
||||
# NOTE: don't use `text()` here since we're using the cursor directly
|
||||
cursor.execute(f'SET search_path = "{tenant_id}"')
|
||||
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
|
||||
cursor.execute(
|
||||
text(
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
|
||||
)
|
||||
except Exception:
|
||||
raise RuntimeError(f"search_path not set for {tenant_id}")
|
||||
|
||||
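# Illustrative sketch (not part of the diff): the hunk above drops the text()
# wrapper because the statement runs on the raw DBAPI cursor, which expects a
# plain SQL string rather than a SQLAlchemy text() construct. A standalone
# psycopg2 example of the same idea; connection details and the tenant id are
# placeholders, and the tenant id is assumed to be validated before interpolation.
import psycopg2

conn = psycopg2.connect(
    host="localhost", port=5432, dbname="onyx", user="postgres", password="..."
)
tenant_id = "tenant_abc"
with conn.cursor() as cur:
    cur.execute(f'SET search_path = "{tenant_id}"')  # plain SQL string, no text()
    cur.execute("SET SESSION idle_in_transaction_session_timeout = 60000")
conn.commit()
# end of sketch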
@@ -354,7 +354,7 @@ class AgentSubQuery__SearchDoc(Base):
|
||||
__tablename__ = "agent__sub_query__search_doc"
|
||||
|
||||
sub_query_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("agent__sub_query.id"), primary_key=True
|
||||
ForeignKey("agent__sub_query.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
search_doc_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("search_doc.id"), primary_key=True
|
||||
@@ -405,7 +405,7 @@ class ChatMessage__StandardAnswer(Base):
|
||||
__tablename__ = "chat_message__standard_answer"
|
||||
|
||||
chat_message_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id"), primary_key=True
|
||||
ForeignKey("chat_message.id", ondelete="CASCADE"), primary_key=True
|
||||
)
|
||||
standard_answer_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("standard_answer.id"), primary_key=True
|
||||
@@ -1430,7 +1430,9 @@ class AgentSubQuestion(Base):
|
||||
__tablename__ = "agent__sub_question"
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
|
||||
primary_question_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("chat_message.id", ondelete="CASCADE")
|
||||
)
|
||||
chat_session_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
|
||||
)
|
||||
@@ -1464,7 +1466,7 @@ class AgentSubQuery(Base):
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
parent_question_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("agent__sub_question.id")
|
||||
ForeignKey("agent__sub_question.id", ondelete="CASCADE")
|
||||
)
|
||||
chat_session_id: Mapped[UUID] = mapped_column(
|
||||
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
|
||||
|
||||
@@ -36,7 +36,7 @@ MAX_OR_CONDITIONS = 10
|
||||
# up from 500ms for now, since we've seen quite a few timeouts
|
||||
# in the long term, we are looking to improve the performance of Vespa
|
||||
# so that we can bring this back to default
|
||||
VESPA_TIMEOUT = "3s"
|
||||
VESPA_TIMEOUT = "10s"
|
||||
BATCH_SIZE = 128 # Specific to Vespa
|
||||
|
||||
TENANT_ID = "tenant_id"
|
||||
|
||||
@@ -301,7 +301,7 @@ def read_pdf_file(
|
||||
|
||||
|
||||
def docx_to_text_and_images(
|
||||
file: IO[Any],
|
||||
file: IO[Any], file_name: str = ""
|
||||
) -> tuple[str, Sequence[tuple[bytes, str]]]:
|
||||
"""
|
||||
Extract text from a docx. If embed_images=True, also extract inline images.
|
||||
@@ -310,7 +310,11 @@ def docx_to_text_and_images(
|
||||
paragraphs = []
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
doc = docx.Document(file)
|
||||
try:
|
||||
doc = docx.Document(file)
|
||||
except BadZipFile as e:
|
||||
logger.warning(f"Failed to extract text from {file_name or 'docx file'}: {e}")
|
||||
return "", []
|
||||
|
||||
# Grab text from paragraphs
|
||||
for paragraph in doc.paragraphs:
|
||||
@@ -360,6 +364,13 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
|
||||
else:
|
||||
logger.warning(error_str)
|
||||
return ""
|
||||
except Exception as e:
|
||||
if "File contains no valid workbook part" in str(e):
|
||||
logger.error(
|
||||
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
|
||||
)
|
||||
return ""
|
||||
raise e
|
||||
|
||||
text_content = []
|
||||
for sheet in workbook.worksheets:
|
||||
|
||||
@@ -2,6 +2,7 @@ import base64
|
||||
from collections.abc import Callable
|
||||
from io import BytesIO
|
||||
from typing import cast
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
import requests
|
||||
@@ -12,11 +13,11 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import UserFile
|
||||
from onyx.db.models import UserFolder
|
||||
from onyx.file_processing.extract_file_text import IMAGE_MEDIA_TYPES
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.file_store.models import ChatFileType
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
from onyx.file_store.models import InMemoryChatFile
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.utils.b64 import get_image_type
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
@@ -119,27 +120,37 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
|
||||
if not user_file:
|
||||
raise ValueError(f"User file with id {file_id} not found")
|
||||
|
||||
# Try to load plaintext version first
|
||||
# Get the file record to determine the appropriate chat file type
|
||||
file_store = get_default_file_store(db_session)
|
||||
file_record = file_store.read_file_record(user_file.file_id)
|
||||
|
||||
# Determine appropriate chat file type based on the original file's MIME type
|
||||
chat_file_type = mime_type_to_chat_file_type(file_record.file_type)
|
||||
|
||||
# Try to load plaintext version first
|
||||
plaintext_file_name = user_file_id_to_plaintext_file_name(file_id)
|
||||
|
||||
# check for plain text normalized version first, then use original file otherwise
|
||||
try:
|
||||
file_io = file_store.read_file(plaintext_file_name, mode="b")
|
||||
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
|
||||
plaintext_chat_file_type = (
|
||||
ChatFileType.PLAIN_TEXT
|
||||
if chat_file_type != ChatFileType.IMAGE
|
||||
else chat_file_type
|
||||
)
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
content=file_io.read(),
|
||||
file_type=ChatFileType.USER_KNOWLEDGE,
|
||||
file_type=plaintext_chat_file_type,
|
||||
filename=user_file.name,
|
||||
)
|
||||
status = "plaintext"
|
||||
return chat_file
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load plaintext for user file {user_file.id}: {e}")
|
||||
# Fall back to original file if plaintext not available
|
||||
file_io = file_store.read_file(user_file.file_id, mode="b")
|
||||
file_record = file_store.read_file_record(user_file.file_id)
|
||||
if file_record.file_type in IMAGE_MEDIA_TYPES:
|
||||
chat_file_type = ChatFileType.IMAGE
|
||||
|
||||
chat_file = InMemoryChatFile(
|
||||
file_id=str(user_file.file_id),
|
||||
@@ -235,6 +246,26 @@ def get_user_files(
|
||||
return user_files
|
||||
|
||||
|
||||
def get_user_files_as_user(
|
||||
user_file_ids: list[int],
|
||||
user_folder_ids: list[int],
|
||||
user_id: UUID | None,
|
||||
db_session: Session,
|
||||
) -> list[UserFile]:
|
||||
"""
|
||||
Fetches all UserFile database records for a given user.
|
||||
"""
|
||||
user_files = get_user_files(user_file_ids, user_folder_ids, db_session)
|
||||
for user_file in user_files:
|
||||
# Note: if user_id is None, then every file's user_id should be None as well
# (since auth must be disabled in this case)
|
||||
if user_file.user_id != user_id:
|
||||
raise ValueError(
|
||||
f"User {user_id} does not have access to file {user_file.id}"
|
||||
)
|
||||
return user_files
|
||||
|
||||
|
||||
def save_file_from_url(url: str) -> str:
|
||||
"""NOTE: using multiple sessions here, since this is often called
|
||||
using multithreading. In practice, sharing a session has resulted in
|
||||
|
||||
@@ -264,7 +264,7 @@ class DefaultMultiLLM(LLM):
|
||||
):
|
||||
self._timeout = timeout
|
||||
if timeout is None:
|
||||
if model_is_reasoning_model(model_name):
|
||||
if model_is_reasoning_model(model_name, model_provider):
|
||||
self._timeout = QA_TIMEOUT * 10 # Reasoning models are slow
|
||||
else:
|
||||
self._timeout = QA_TIMEOUT
|
||||
|
||||
@@ -108,7 +108,7 @@ VERTEXAI_DEFAULT_MODEL = "gemini-2.0-flash"
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.0-flash-lite"
|
||||
VERTEXAI_MODEL_NAMES = [
|
||||
# 2.5 pro models
|
||||
"gemini-2.5-pro-exp-03-25",
|
||||
"gemini-2.5-pro-preview-05-06",
|
||||
# 2.0 flash-lite models
|
||||
VERTEXAI_DEFAULT_FAST_MODEL,
|
||||
"gemini-2.0-flash-lite-001",
|
||||
|
||||
@@ -663,12 +663,34 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
def model_is_reasoning_model(model_name: str) -> bool:
|
||||
_REASONING_MODEL_NAMES = [
|
||||
"o1",
|
||||
"o1-mini",
|
||||
"o3-mini",
|
||||
"deepseek-reasoner",
|
||||
"deepseek-r1",
|
||||
]
|
||||
return model_name.lower() in _REASONING_MODEL_NAMES
|
||||
def model_is_reasoning_model(model_name: str, model_provider: str) -> bool:
|
||||
model_map = get_model_map()
|
||||
try:
|
||||
model_obj = find_model_obj(
|
||||
model_map,
|
||||
model_provider,
|
||||
model_name,
|
||||
)
|
||||
if model_obj and "supports_reasoning" in model_obj:
|
||||
return model_obj["supports_reasoning"]
|
||||
|
||||
# Fallback: try using litellm.supports_reasoning() for newer models
|
||||
try:
|
||||
logger.debug("Falling back to `litellm.supports_reasoning`")
|
||||
full_model_name = (
|
||||
f"{model_provider}/{model_name}"
|
||||
if model_provider not in model_name
|
||||
else model_name
|
||||
)
|
||||
return litellm.supports_reasoning(model=full_model_name)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to check if {model_provider}/{model_name} supports reasoning"
|
||||
)
|
||||
return False
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed to get model object for {model_provider}/{model_name}"
|
||||
)
|
||||
return False
|
||||
|
||||
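# Illustrative sketch (not part of the diff): how the reasoning check above feeds the
# timeout selection shown earlier in DefaultMultiLLM. choose_timeout is a hypothetical
# helper; the 10x multiplier comes from this diff, and the import path is assumed.
from onyx.llm.utils import model_is_reasoning_model  # module path assumed


def choose_timeout(model_name: str, model_provider: str, qa_timeout: int) -> int:
    # Reasoning models (o1/o3/deepseek-r1 style) respond much more slowly,
    # so they get a far larger budget.
    if model_is_reasoning_model(model_name, model_provider):
        return qa_timeout * 10
    return qa_timeout


# e.g. choose_timeout("o3-mini", "openai", qa_timeout=60) -> 600
# end of sketch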
@@ -185,9 +185,7 @@ class EmbeddingModel:
|
||||
) -> list[Embedding]:
|
||||
text_batches = batch_list(texts, batch_size)
|
||||
|
||||
logger.debug(
|
||||
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
|
||||
)
|
||||
logger.debug(f"Encoding {len(texts)} texts in {len(text_batches)} batches")
|
||||
|
||||
embeddings: list[Embedding] = []
|
||||
|
||||
|
||||
@@ -64,7 +64,7 @@ TENANT_HEARTBEAT_INTERVAL = (
|
||||
15 # How often pods send heartbeats to indicate they are still processing a tenant
|
||||
)
|
||||
TENANT_HEARTBEAT_EXPIRATION = (
|
||||
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
60 # How long before a tenant's heartbeat expires, allowing other pods to take over
|
||||
)
|
||||
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens
|
||||
|
||||
|
||||
@@ -137,7 +137,10 @@ def handle_generate_answer_button(
|
||||
raise ValueError("Missing thread_ts in the payload")
|
||||
|
||||
thread_messages = read_slack_thread(
|
||||
channel=channel_id, thread=thread_ts, client=client.web_client
|
||||
tenant_id=client._tenant_id,
|
||||
channel=channel_id,
|
||||
thread=thread_ts,
|
||||
client=client.web_client,
|
||||
)
|
||||
# remove all assistant messages till we get to the last user message
|
||||
# we want the new answer to be generated off of the last "question" in
|
||||
|
||||
@@ -419,6 +419,11 @@ def handle_regular_answer(
|
||||
skip_ai_feedback=skip_ai_feedback,
|
||||
)
|
||||
|
||||
# NOTE(rkuo): Slack has a maximum block list size of 50.
|
||||
# we should modify build_slack_response_blocks to respect the max
|
||||
# but enforcing the hard limit here is the last resort.
|
||||
all_blocks = all_blocks[:50]
|
||||
|
||||
try:
|
||||
respond_in_thread_or_channel(
|
||||
client=client,
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import signal
|
||||
import sys
|
||||
@@ -11,11 +10,12 @@ from types import FrameType
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Set
|
||||
|
||||
import psycopg2.errors
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import start_http_server
|
||||
from redis.lock import Lock
|
||||
from redis.lock import Lock as RedisLock
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
from slack_sdk.http_retry import ConnectionErrorRetryHandler
|
||||
@@ -86,7 +86,7 @@ from onyx.onyxbot.slack.models import SlackMessageInfo
|
||||
from onyx.onyxbot.slack.utils import check_message_limit
|
||||
from onyx.onyxbot.slack.utils import decompose_action_id
|
||||
from onyx.onyxbot.slack.utils import get_channel_name_from_id
|
||||
from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
|
||||
from onyx.onyxbot.slack.utils import get_onyx_bot_auth_ids
|
||||
from onyx.onyxbot.slack.utils import read_slack_thread
|
||||
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
|
||||
from onyx.onyxbot.slack.utils import rephrase_slack_message
|
||||
@@ -105,7 +105,6 @@ from shared_configs.configs import SLACK_CHANNEL_ID
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Prometheus metric for HPA
|
||||
@@ -135,7 +134,7 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str] = set()
|
||||
self.tenant_ids: set[str] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
@@ -144,8 +143,11 @@ class SlackbotHandler:
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
self.pod_id = os.environ.get("HOSTNAME", "unknown_pod")
|
||||
self._shutdown_event = Event()
|
||||
|
||||
self._lock = threading.Lock()
|
||||
|
||||
logger.info(f"Pod ID: {self.pod_id}")
|
||||
|
||||
# Set up signal handlers for graceful shutdown
|
||||
@@ -169,12 +171,8 @@ class SlackbotHandler:
|
||||
|
||||
self.acquire_thread.start()
|
||||
self.heartbeat_thread.start()
|
||||
logger.info("Background threads started")
|
||||
|
||||
def get_pod_id(self) -> str:
|
||||
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
|
||||
logger.info(f"Retrieved pod ID: {pod_id}")
|
||||
return pod_id
|
||||
logger.info("Background threads started")
|
||||
|
||||
def acquire_tenants_loop(self) -> None:
|
||||
while not self._shutdown_event.is_set():
|
||||
@@ -194,12 +192,18 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
|
||||
|
||||
def heartbeat_loop(self) -> None:
|
||||
"""This heartbeats into redis.
|
||||
|
||||
NOTE(rkuo): this is not thread-safe with acquire_tenants_loop and will
|
||||
occasionally exception. Fix it!
|
||||
"""
|
||||
while not self._shutdown_event.is_set():
|
||||
try:
|
||||
self.send_heartbeats()
|
||||
logger.debug(
|
||||
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
|
||||
)
|
||||
with self._lock:
|
||||
tenant_ids = self.tenant_ids.copy()
|
||||
|
||||
SlackbotHandler.send_heartbeats(self.pod_id, tenant_ids)
|
||||
logger.debug(f"Sent heartbeats for {len(tenant_ids)} active tenants")
|
||||
except Exception as e:
|
||||
logger.exception(f"Error in heartbeat loop: {e}")
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
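# Illustrative sketch (not part of the diff): the fix above snapshots the tenant set
# under a lock, then heartbeats each tenant into Redis with a TTL so other pods can
# detect a dead worker. The key prefix, TTL value, and Redis client are placeholders.
import threading
import time

from redis import Redis

_lock = threading.Lock()
_tenant_ids: set[str] = set()
HEARTBEAT_TTL = 60  # seconds; illustrative, mirrors TENANT_HEARTBEAT_EXPIRATION in this diff


def send_heartbeats(pod_id: str, r: Redis) -> None:
    with _lock:
        tenant_ids = _tenant_ids.copy()  # snapshot so iteration can't race a concurrent mutation
    now = int(time.time())
    for tenant_id in tenant_ids:
        # one key per pod; if it is missing after HEARTBEAT_TTL, the pod is presumed dead
        r.set(f"slackbot_heartbeat:{pod_id}:{tenant_id}", now, ex=HEARTBEAT_TTL)
# end of sketch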
@@ -224,7 +228,7 @@ class SlackbotHandler:
|
||||
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
|
||||
)
|
||||
if tenant_bot_pair in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_bot_pair].close())
|
||||
self.socket_clients[tenant_bot_pair].close()
|
||||
del self.socket_clients[tenant_bot_pair]
|
||||
del self.slack_bot_tokens[tenant_bot_pair]
|
||||
return
|
||||
@@ -252,9 +256,20 @@ class SlackbotHandler:
|
||||
|
||||
# Close any existing connection first
|
||||
if tenant_bot_pair in self.socket_clients:
|
||||
asyncio.run(self.socket_clients[tenant_bot_pair].close())
|
||||
self.socket_clients[tenant_bot_pair].close()
|
||||
|
||||
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
|
||||
socket_client = self.start_socket_client(
|
||||
bot.id, tenant_id, slack_bot_tokens
|
||||
)
|
||||
if socket_client:
|
||||
# Ensure tenant is tracked as active
|
||||
self.socket_clients[tenant_id, bot.id] = socket_client
|
||||
|
||||
logger.info(
|
||||
f"Started SocketModeClient: {tenant_id=} {socket_client.bot_name=} {bot.id=}"
|
||||
)
|
||||
|
||||
self.tenant_ids.add(tenant_id)
|
||||
|
||||
def acquire_tenants(self) -> None:
|
||||
"""
|
||||
@@ -264,6 +279,8 @@ class SlackbotHandler:
|
||||
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
|
||||
"""
|
||||
|
||||
token: Token[str | None]
|
||||
|
||||
# tenants that are disabled (e.g. their trial is over and haven't subscribed)
|
||||
# for non-cloud, this will return an empty set
|
||||
gated_tenants = fetch_ee_implementation_or_noop(
|
||||
@@ -271,16 +288,14 @@ class SlackbotHandler:
|
||||
"get_gated_tenants",
|
||||
set(),
|
||||
)()
|
||||
all_tenants = [
|
||||
all_active_tenants = [
|
||||
tenant_id
|
||||
for tenant_id in get_all_tenant_ids()
|
||||
if tenant_id not in gated_tenants
|
||||
]
|
||||
|
||||
token: Token[str | None]
|
||||
|
||||
# 1) Try to acquire locks for new tenants
|
||||
for tenant_id in all_tenants:
|
||||
for tenant_id in all_active_tenants:
|
||||
if (
|
||||
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
|
||||
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
|
||||
@@ -295,14 +310,18 @@ class SlackbotHandler:
|
||||
# Respect max tenant limit per pod
|
||||
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
|
||||
logger.info(
|
||||
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
|
||||
f"Max tenants per pod reached, not acquiring more: {MAX_TENANTS_PER_POD=}"
|
||||
)
|
||||
break
|
||||
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
# Acquire a Redis lock (non-blocking)
|
||||
rlock = redis_client.lock(
|
||||
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
|
||||
# thread_local=False because the shutdown event is handled
|
||||
# on an arbitrary thread
|
||||
rlock: RedisLock = redis_client.lock(
|
||||
OnyxRedisLocks.SLACK_BOT_LOCK,
|
||||
timeout=TENANT_LOCK_EXPIRATION,
|
||||
thread_local=False,
|
||||
)
|
||||
lock_acquired = rlock.acquire(blocking=False)
|
||||
|
||||
@@ -333,6 +352,10 @@ class SlackbotHandler:
|
||||
except KvKeyNotFoundError:
|
||||
# No Slackbot tokens, pass
|
||||
pass
|
||||
except psycopg2.errors.UndefinedTable:
|
||||
logger.error(
|
||||
"Undefined table error in fetch_slack_bots. Tenant schema may need fixing."
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
|
||||
@@ -409,10 +432,11 @@ class SlackbotHandler:
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
"""
|
||||
socket_client_list = list(self.socket_clients.items())
|
||||
# Close all socket clients for this tenant
|
||||
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
|
||||
for (t_id, slack_bot_id), client in socket_client_list:
|
||||
if t_id == tenant_id:
|
||||
asyncio.run(client.close())
|
||||
client.close()
|
||||
del self.socket_clients[(t_id, slack_bot_id)]
|
||||
del self.slack_bot_tokens[(t_id, slack_bot_id)]
|
||||
logger.info(
|
||||
@@ -423,19 +447,22 @@ class SlackbotHandler:
|
||||
if tenant_id in self.tenant_ids:
|
||||
self.tenant_ids.remove(tenant_id)
|
||||
|
||||
def send_heartbeats(self) -> None:
|
||||
@staticmethod
|
||||
def send_heartbeats(pod_id: str, tenant_ids: set[str]) -> None:
|
||||
current_time = int(time.time())
|
||||
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
|
||||
for tenant_id in self.tenant_ids:
|
||||
logger.debug(f"Sending heartbeats for {len(tenant_ids)} active tenants")
|
||||
for tenant_id in tenant_ids:
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
|
||||
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{pod_id}"
|
||||
redis_client.set(
|
||||
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
) -> TenantSocketModeClient | None:
|
||||
"""Returns the socket client if this succeeds"""
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
)
|
||||
@@ -450,18 +477,21 @@ class SlackbotHandler:
|
||||
bot_name = (
|
||||
user_info["user"]["real_name"] or user_info["user"]["name"]
|
||||
)
|
||||
logger.info(
|
||||
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
|
||||
)
|
||||
socket_client.bot_name = bot_name
|
||||
# logger.info(
|
||||
# f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
|
||||
# )
|
||||
except SlackApiError as e:
|
||||
# Only error out if we get a not_authed error
|
||||
if "not_authed" in str(e):
|
||||
self.tenant_ids.add(tenant_id)
|
||||
# for some reason we want to add the tenant to the list when this happens?
|
||||
logger.error(
|
||||
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
|
||||
"Error: {e}"
|
||||
f"Authentication error - Invalid or expired credentials: "
|
||||
f"{tenant_id=} {slack_bot_id=}. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
return
|
||||
return None
|
||||
|
||||
# Log other Slack API errors but continue
|
||||
logger.error(
|
||||
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
@@ -477,23 +507,30 @@ class SlackbotHandler:
|
||||
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
|
||||
|
||||
# Establish a WebSocket connection to the Socket Mode servers
|
||||
logger.info(
|
||||
f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
# logger.debug(
|
||||
# f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
# )
|
||||
socket_client.connect()
|
||||
self.socket_clients[tenant_id, slack_bot_id] = socket_client
|
||||
# Ensure tenant is tracked as active
|
||||
self.tenant_ids.add(tenant_id)
|
||||
logger.info(
|
||||
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
)
|
||||
# logger.info(
|
||||
# f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
# )
|
||||
|
||||
def stop_socket_clients(self) -> None:
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
|
||||
asyncio.run(client.close())
|
||||
return socket_client
|
||||
|
||||
@staticmethod
|
||||
def stop_socket_clients(
|
||||
pod_id: str, socket_clients: Dict[tuple[str, int], TenantSocketModeClient]
|
||||
) -> None:
|
||||
socket_client_list = list(socket_clients.items())
|
||||
length = len(socket_client_list)
|
||||
|
||||
x = 0
|
||||
for (tenant_id, slack_bot_id), client in socket_client_list:
|
||||
x += 1
|
||||
client.close()
|
||||
logger.info(
|
||||
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
|
||||
f"Stopped SocketModeClient {x}/{length}: "
|
||||
f"{pod_id=} {tenant_id=} {slack_bot_id=}"
|
||||
)
|
||||
|
||||
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
|
||||
@@ -502,11 +539,15 @@ class SlackbotHandler:
|
||||
|
||||
logger.info("Shutting down gracefully")
|
||||
self.running = False
|
||||
self._shutdown_event.set()
|
||||
self._shutdown_event.set() # set the shutdown event
|
||||
|
||||
# wait for threads to detect the event and exit
|
||||
self.acquire_thread.join(timeout=60.0)
|
||||
self.heartbeat_thread.join(timeout=60.0)
|
||||
|
||||
# Stop all socket clients
|
||||
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
|
||||
self.stop_socket_clients()
|
||||
SlackbotHandler.stop_socket_clients(self.pod_id, self.socket_clients)
|
||||
|
||||
# Release locks for all tenants we currently hold
|
||||
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
|
||||
@@ -533,7 +574,13 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
"""True to keep going, False to ignore this Slack request"""
|
||||
|
||||
# skip cases where the bot is disabled in the web UI
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
|
||||
tenant_id, client.web_client
|
||||
)
|
||||
logger.info(f"prefilter_requests: {bot_token_user_id=} {bot_token_bot_id=}")
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
slack_bot = fetch_slack_bot(
|
||||
db_session=db_session, slack_bot_id=client.slack_bot_id
|
||||
@@ -580,7 +627,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
|
||||
if (
|
||||
msg in _SLACK_GREETINGS_TO_IGNORE
|
||||
or remove_onyx_bot_tag(msg, client=client.web_client)
|
||||
or remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
|
||||
in _SLACK_GREETINGS_TO_IGNORE
|
||||
):
|
||||
channel_specific_logger.error(
|
||||
@@ -599,15 +646,38 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
)
|
||||
return False
|
||||
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
|
||||
tenant_id, client.web_client
|
||||
)
|
||||
if event_type == "message":
|
||||
is_onyx_bot_msg = False
|
||||
is_tagged = False
|
||||
|
||||
event_user = event.get("user", "")
|
||||
event_bot_id = event.get("bot_id", "")
|
||||
|
||||
# temporary debugging
|
||||
if tenant_id == "tenant_i-04224818da13bf695":
|
||||
logger.warning(
|
||||
f"{tenant_id=} "
|
||||
f"{bot_token_user_id=} "
|
||||
f"{bot_token_bot_id=} "
|
||||
f"{event=}"
|
||||
)
|
||||
|
||||
is_dm = event.get("channel_type") == "im"
|
||||
is_tagged = bot_tag_id and f"<@{bot_tag_id}>" in msg
|
||||
is_onyx_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
|
||||
if bot_token_user_id and f"<@{bot_token_user_id}>" in msg:
|
||||
is_tagged = True
|
||||
|
||||
if bot_token_user_id and bot_token_user_id in event_user:
|
||||
is_onyx_bot_msg = True
|
||||
|
||||
if bot_token_bot_id and bot_token_bot_id in event_bot_id:
|
||||
is_onyx_bot_msg = True
|
||||
|
||||
# OnyxBot should never respond to itself
|
||||
if is_onyx_bot_msg:
|
||||
logger.info("Ignoring message from OnyxBot")
|
||||
logger.info("Ignoring message from OnyxBot (self-message)")
|
||||
return False
|
||||
|
||||
# DMs with the bot don't pick up the @OnyxBot so we have to keep the
|
||||
@@ -632,7 +702,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
)
|
||||
|
||||
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
|
||||
if (not bot_tag_id or bot_tag_id not in msg) and (
|
||||
if (not bot_token_user_id or bot_token_user_id not in msg) and (
|
||||
not slack_channel_config
|
||||
or not slack_channel_config.channel_config.get("respond_to_bots")
|
||||
):
|
||||
@@ -692,7 +762,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
|
||||
if not check_message_limit():
|
||||
return False
|
||||
|
||||
logger.debug(f"Handling Slack request with Payload: '{req.payload}'")
|
||||
logger.debug(f"Handling Slack request: {client.bot_name=} '{req.payload=}'")
|
||||
return True
|
||||
|
||||
|
||||
@@ -731,15 +801,16 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
|
||||
def build_request_details(
|
||||
req: SocketModeRequest, client: TenantSocketModeClient
|
||||
) -> SlackMessageInfo:
|
||||
tagged: bool = False
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
if req.type == "events_api":
|
||||
event = cast(dict[str, Any], req.payload["event"])
|
||||
msg = cast(str, event["text"])
|
||||
channel = cast(str, event["channel"])
|
||||
|
||||
# Check for both app_mention events and messages containing bot tag
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
|
||||
tagged = (event.get("type") == "app_mention") or (
|
||||
event.get("type") == "message" and bot_tag_id and f"<@{bot_tag_id}>" in msg
|
||||
)
|
||||
bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, client.web_client)
|
||||
message_ts = event.get("ts")
|
||||
thread_ts = event.get("thread_ts")
|
||||
sender_id = event.get("user") or None
|
||||
@@ -748,7 +819,7 @@ def build_request_details(
|
||||
)
|
||||
email = expert_info.email if expert_info else None
|
||||
|
||||
msg = remove_onyx_bot_tag(msg, client=client.web_client)
|
||||
msg = remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
|
||||
|
||||
if DANSWER_BOT_REPHRASE_MESSAGE:
|
||||
logger.info(f"Rephrasing Slack message. Original message: {msg}")
|
||||
@@ -760,12 +831,24 @@ def build_request_details(
|
||||
else:
|
||||
logger.info(f"Received Slack message: {msg}")
|
||||
|
||||
event_type = event.get("type")
|
||||
if event_type == "app_mention":
|
||||
tagged = True
|
||||
|
||||
if event_type == "message":
|
||||
if bot_token_user_id:
|
||||
if f"<@{bot_token_user_id}>" in msg:
|
||||
tagged = True
|
||||
|
||||
if tagged:
|
||||
logger.debug("User tagged OnyxBot")
|
||||
|
||||
if thread_ts != message_ts and thread_ts is not None:
|
||||
thread_messages = read_slack_thread(
|
||||
channel=channel, thread=thread_ts, client=client.web_client
|
||||
tenant_id=tenant_id,
|
||||
channel=channel,
|
||||
thread=thread_ts,
|
||||
client=client.web_client,
|
||||
)
|
||||
else:
|
||||
sender_display_name = None
|
||||
@@ -842,12 +925,24 @@ def process_message(
|
||||
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
logger.debug(
|
||||
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
|
||||
)
|
||||
if req.type == "events_api":
|
||||
event = cast(dict[str, Any], req.payload["event"])
|
||||
event_type = event.get("type")
|
||||
msg = cast(str, event.get("text", ""))
|
||||
logger.info(
|
||||
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} "
|
||||
f"{event_type=} {msg=}"
|
||||
)
|
||||
else:
|
||||
logger.info(
|
||||
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=}"
|
||||
)
|
||||
|
||||
# Throw out requests that can't or shouldn't be handled
|
||||
if not prefilter_requests(req, client):
|
||||
logger.info(
|
||||
f"process_message prefiltered: {tenant_id=} {req.type=} {req.envelope_id=}"
|
||||
)
|
||||
return
|
||||
|
||||
details = build_request_details(req, client)
|
||||
@@ -890,6 +985,10 @@ def process_message(
|
||||
if notify_no_answer:
|
||||
apologize_for_fail(details, client)
|
||||
|
||||
logger.info(
|
||||
f"process_message finished: success={not failed} {tenant_id=} {req.type=} {req.envelope_id=}"
|
||||
)
|
||||
|
||||
|
||||
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
|
||||
response = SocketModeResponse(envelope_id=req.envelope_id)
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
@@ -48,17 +49,38 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
slack_token_user_ids: dict[str, str | None] = {}
|
||||
slack_token_bot_ids: dict[str, str | None] = {}
|
||||
slack_token_lock = threading.Lock()
|
||||
|
||||
_DANSWER_BOT_SLACK_BOT_ID: str | None = None
|
||||
_DANSWER_BOT_MESSAGE_COUNT: int = 0
|
||||
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
|
||||
|
||||
|
||||
def get_onyx_bot_slack_bot_id(web_client: WebClient) -> Any:
|
||||
global _DANSWER_BOT_SLACK_BOT_ID
|
||||
if _DANSWER_BOT_SLACK_BOT_ID is None:
|
||||
_DANSWER_BOT_SLACK_BOT_ID = web_client.auth_test().get("user_id")
|
||||
return _DANSWER_BOT_SLACK_BOT_ID
|
||||
def get_onyx_bot_auth_ids(
|
||||
tenant_id: str, web_client: WebClient
|
||||
) -> tuple[str | None, str | None]:
|
||||
"""Returns a tuple of user_id and bot_id."""
|
||||
|
||||
user_id: str | None
|
||||
bot_id: str | None
|
||||
|
||||
global slack_token_user_ids
|
||||
global slack_token_bot_ids
|
||||
|
||||
with slack_token_lock:
|
||||
user_id = slack_token_user_ids.get(tenant_id)
|
||||
bot_id = slack_token_bot_ids.get(tenant_id)
|
||||
|
||||
if user_id is None or bot_id is None:
|
||||
response = web_client.auth_test()
|
||||
user_id = response.get("user_id")
|
||||
bot_id = response.get("bot_id")
|
||||
with slack_token_lock:
|
||||
slack_token_user_ids[tenant_id] = user_id
|
||||
slack_token_bot_ids[tenant_id] = bot_id
|
||||
|
||||
return user_id, bot_id
|
||||
|
||||
|
||||
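# Illustrative sketch (not part of the diff): how the cached (user_id, bot_id) pair
# above gets used by prefilter_requests later in this diff. A message is treated as
# the bot's own if either ID matches the incoming event. is_self_message is a
# hypothetical helper written only to show the check in isolation.
from typing import Any


def is_self_message(
    event: dict[str, Any], bot_user_id: str | None, bot_bot_id: str | None
) -> bool:
    event_user = event.get("user", "")
    event_bot_id = event.get("bot_id", "")
    if bot_user_id and bot_user_id in event_user:
        return True
    if bot_bot_id and bot_bot_id in event_bot_id:
        return True
    return False


# e.g. is_self_message({"user": "U123", "bot_id": ""}, "U123", "B456") -> True
# end of sketch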
def check_message_limit() -> bool:
|
||||
@@ -117,35 +139,38 @@ def update_emote_react(
|
||||
remove: bool,
|
||||
client: WebClient,
|
||||
) -> None:
|
||||
try:
|
||||
if not message_ts:
|
||||
logger.error(
|
||||
f"Tried to remove a react in {channel} but no message specified"
|
||||
)
|
||||
return
|
||||
if not message_ts:
|
||||
action = "remove" if remove else "add"
|
||||
logger.error(f"update_emote_react - no message specified: {channel=} {action=}")
|
||||
return
|
||||
|
||||
if remove:
|
||||
if remove:
|
||||
try:
|
||||
client.reactions_remove(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
else:
|
||||
client.reactions_add(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
if remove:
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Failed to remove Reaction due to: {e}")
|
||||
else:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
return
|
||||
|
||||
try:
|
||||
client.reactions_add(
|
||||
name=emoji,
|
||||
channel=channel,
|
||||
timestamp=message_ts,
|
||||
)
|
||||
except SlackApiError as e:
|
||||
logger.error(f"Was not able to react to user message due to: {e}")
|
||||
|
||||
return
|
||||
|
||||
|
||||
def remove_onyx_bot_tag(message_str: str, client: WebClient) -> str:
|
||||
bot_tag_id = get_onyx_bot_slack_bot_id(web_client=client)
|
||||
return re.sub(rf"<@{bot_tag_id}>\s*", "", message_str)
|
||||
def remove_onyx_bot_tag(tenant_id: str, message_str: str, client: WebClient) -> str:
|
||||
bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, web_client=client)
|
||||
return re.sub(rf"<@{bot_token_user_id}>\s*", "", message_str)
|
||||
|
||||
|
||||
def _check_for_url_in_block(block: Block) -> bool:
|
||||
@@ -215,7 +240,8 @@ def respond_in_thread_or_channel(
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
|
||||
blocks_str = str(blocks)[:1024] # truncate block logging
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
|
||||
logger.warning("Trying again without blocks that have urls")
|
||||
|
||||
if not blocks:
|
||||
@@ -252,7 +278,8 @@ def respond_in_thread_or_channel(
|
||||
unfurl_media=unfurl,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
|
||||
blocks_str = str(blocks)[:1024] # truncate block logging
|
||||
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
|
||||
logger.warning("Trying again without blocks that have urls")
|
||||
|
||||
if not blocks:
|
||||
@@ -515,7 +542,7 @@ def fetch_user_semantic_id_from_id(
|
||||
|
||||
|
||||
def read_slack_thread(
|
||||
channel: str, thread: str, client: WebClient
|
||||
tenant_id: str, channel: str, thread: str, client: WebClient
|
||||
) -> list[ThreadMessage]:
|
||||
thread_messages: list[ThreadMessage] = []
|
||||
response = client.conversations_replies(channel=channel, ts=thread)
|
||||
@@ -529,9 +556,22 @@ def read_slack_thread(
|
||||
)
|
||||
message_type = MessageType.USER
|
||||
else:
|
||||
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)
|
||||
blocks: Any
|
||||
if reply.get("user") == self_slack_bot_id:
|
||||
is_onyx_bot_response = False
|
||||
|
||||
reply_user = reply.get("user")
|
||||
reply_bot_id = reply.get("bot_id")
|
||||
|
||||
self_slack_bot_user_id, self_slack_bot_bot_id = get_onyx_bot_auth_ids(
|
||||
tenant_id, client
|
||||
)
|
||||
if reply_user is not None and reply_user == self_slack_bot_user_id:
|
||||
is_onyx_bot_response = True
|
||||
|
||||
if reply_bot_id is not None and reply_bot_id == self_slack_bot_bot_id:
|
||||
is_onyx_bot_response = True
|
||||
|
||||
if is_onyx_bot_response:
|
||||
# OnyxBot response
|
||||
message_type = MessageType.ASSISTANT
|
||||
user_sem_id = "Assistant"
|
||||
@@ -573,7 +613,7 @@ def read_slack_thread(
|
||||
logger.warning("Skipping Slack thread message, no text found")
|
||||
continue
|
||||
|
||||
message = remove_onyx_bot_tag(message, client=client)
|
||||
message = remove_onyx_bot_tag(tenant_id, message, client=client)
|
||||
thread_messages.append(
|
||||
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
|
||||
)
|
||||
@@ -676,6 +716,7 @@ class TenantSocketModeClient(SocketModeClient):
|
||||
super().__init__(*args, **kwargs)
|
||||
self._tenant_id = tenant_id
|
||||
self.slack_bot_id = slack_bot_id
|
||||
self.bot_name: str = "Unnamed"
|
||||
|
||||
@contextmanager
|
||||
def _set_tenant_context(self) -> Generator[None, None, None]:
|
||||
|
||||
@@ -51,7 +51,10 @@ def llm_eval_section(
|
||||
messages = _get_usefulness_messages()
|
||||
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
|
||||
model_output = message_to_string(llm.invoke(filled_llm_prompt))
|
||||
logger.debug(model_output)
|
||||
|
||||
# NOTE(rkuo): all this does is print "Yes useful" or "Not useful"
# disabling because it's spammy, restore and give more context if this is needed
# logger.debug(model_output)
|
||||
|
||||
return _extract_usefulness(model_output)
|
||||
|
||||
@@ -64,6 +67,8 @@ def llm_batch_eval_sections(
|
||||
metadata_list: list[dict[str, str | list[str]]],
|
||||
use_threads: bool = True,
|
||||
) -> list[bool]:
|
||||
answer: list[bool]
|
||||
|
||||
if DISABLE_LLM_DOC_RELEVANCE:
|
||||
raise RuntimeError(
|
||||
"LLM Doc Relevance is globally disabled, "
|
||||
@@ -86,12 +91,13 @@ def llm_batch_eval_sections(
|
||||
)
|
||||
|
||||
# In case of failure/timeout, don't throw out the section
|
||||
return [True if item is None else item for item in parallel_results]
|
||||
answer = [True if item is None else item for item in parallel_results]
|
||||
return answer
|
||||
|
||||
else:
|
||||
return [
|
||||
llm_eval_section(query, section_content, llm, title, metadata)
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
answer = [
|
||||
llm_eval_section(query, section_content, llm, title, metadata)
|
||||
for section_content, title, metadata in zip(
|
||||
section_contents, titles, metadata_list
|
||||
)
|
||||
]
|
||||
return answer
|
||||
|
||||
@@ -403,7 +403,7 @@ def get_docs_sync_status(
|
||||
def get_cc_pair_indexing_errors(
|
||||
cc_pair_id: int,
|
||||
include_resolved: bool = Query(False),
|
||||
page: int = Query(0, ge=0),
|
||||
page_num: int = Query(0, ge=0),
|
||||
page_size: int = Query(10, ge=1, le=100),
|
||||
_: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
@@ -413,7 +413,7 @@ def get_cc_pair_indexing_errors(
|
||||
Args:
|
||||
cc_pair_id: ID of the connector-credential pair to get errors for
|
||||
include_resolved: Whether to include resolved errors in the results
|
||||
page: Page number for pagination, starting at 0
|
||||
page_num: Page number for pagination, starting at 0
|
||||
page_size: Number of errors to return per page
|
||||
_: Current user, must be curator or admin
|
||||
db_session: Database session
|
||||
@@ -431,7 +431,7 @@ def get_cc_pair_indexing_errors(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
unresolved_only=not include_resolved,
|
||||
page=page,
|
||||
page=page_num,
|
||||
page_size=page_size,
|
||||
)
|
||||
return PaginatedReturn(
|
||||
|
||||
@@ -30,6 +30,7 @@ from onyx.chat.prompt_builder.citations_prompt import (
|
||||
compute_max_document_tokens_for_persona,
|
||||
)
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.chat_configs import HARD_DELETE_CHATS
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MessageType
|
||||
@@ -203,6 +204,7 @@ def update_chat_session_model(
|
||||
def get_chat_session(
|
||||
session_id: UUID,
|
||||
is_shared: bool = False,
|
||||
include_deleted: bool = False,
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionDetailResponse:
|
||||
@@ -213,6 +215,7 @@ def get_chat_session(
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
is_shared=is_shared,
|
||||
include_deleted=include_deleted,
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError("Chat session does not exist or has been deleted")
|
||||
@@ -253,6 +256,7 @@ def get_chat_session(
|
||||
time_created=chat_session.time_created,
|
||||
shared_status=chat_session.shared_status,
|
||||
current_temperature_override=chat_session.temperature_override,
|
||||
deleted=chat_session.deleted,
|
||||
)
|
||||
|
||||
|
||||
@@ -357,12 +361,19 @@ def delete_all_chat_sessions(
|
||||
@router.delete("/delete-chat-session/{session_id}")
|
||||
def delete_chat_session_by_id(
|
||||
session_id: UUID,
|
||||
hard_delete: bool | None = None,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user is not None else None
|
||||
try:
|
||||
delete_chat_session(user_id, session_id, db_session)
|
||||
# Use the provided hard_delete parameter if specified, otherwise use the default config
|
||||
actual_hard_delete = (
|
||||
hard_delete if hard_delete is not None else HARD_DELETE_CHATS
|
||||
)
|
||||
delete_chat_session(
|
||||
user_id, session_id, db_session, hard_delete=actual_hard_delete
|
||||
)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
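The route now lets an explicit hard_delete query parameter override the HARD_DELETE_CHATS default; a minimal sketch of that precedence outside the route:

def resolve_hard_delete(hard_delete: bool | None, hard_delete_chats_default: bool) -> bool:
    # an explicit request parameter wins; otherwise fall back to the server-wide config
    return hard_delete if hard_delete is not None else hard_delete_chats_default

assert resolve_hard_delete(None, False) is False
assert resolve_hard_delete(True, False) is True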
|
||||
|
||||
|
||||
@@ -137,8 +137,6 @@ class CreateChatMessageRequest(ChunkContext):
|
||||
# https://platform.openai.com/docs/guides/structured-outputs/introduction
|
||||
structured_response_format: dict | None = None
|
||||
|
||||
force_user_file_search: bool = False
|
||||
|
||||
# If true, ignores most of the search options and uses pro search instead.
|
||||
# TODO: decide how many of the above options we want to pass through to pro search
|
||||
use_agentic_search: bool = False
|
||||
@@ -274,6 +272,7 @@ class ChatSessionDetailResponse(BaseModel):
|
||||
shared_status: ChatSessionSharedStatus
|
||||
current_alternate_model: str | None
|
||||
current_temperature_override: float | None
|
||||
deleted: bool = False
|
||||
|
||||
|
||||
# This one is not used anymore
|
||||
|
||||
@@ -75,8 +75,6 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
precomputed_keywords: list[str] | None = None
|
||||
user_file_ids: list[int] | None = None
|
||||
user_folder_ids: list[int] | None = None
|
||||
# Flag for fast path when search is only needed for ordering
|
||||
ordering_only: bool | None = None
|
||||
document_sources: list[DocumentSource] | None = None
|
||||
time_cutoff: datetime | None = None
|
||||
expanded_queries: QueryExpansions | None = None
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
|
||||
from onyx.configs.chat_configs import BING_API_KEY
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import OptionalSearchSetting
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
@@ -141,12 +142,11 @@ def construct_tools(
|
||||
user: User | None,
|
||||
llm: LLM,
|
||||
fast_llm: LLM,
|
||||
use_file_search: bool,
|
||||
run_search_setting: OptionalSearchSetting,
|
||||
search_tool_config: SearchToolConfig | None = None,
|
||||
internet_search_tool_config: InternetSearchToolConfig | None = None,
|
||||
image_generation_tool_config: ImageGenerationToolConfig | None = None,
|
||||
custom_tool_config: CustomToolConfig | None = None,
|
||||
user_knowledge_present: bool = False,
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs"""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
@@ -163,7 +163,10 @@ def construct_tools(
|
||||
)
|
||||
|
||||
# Handle Search Tool
|
||||
if tool_cls.__name__ == SearchTool.__name__ and not user_knowledge_present:
|
||||
if (
|
||||
tool_cls.__name__ == SearchTool.__name__
|
||||
and run_search_setting != OptionalSearchSetting.NEVER
|
||||
):
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
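The search tool is now gated on the run_search setting rather than on whether user knowledge is present; a small sketch of the condition (OptionalSearchSetting.NEVER comes from the hunk, the AUTO value is an assumption):

from onyx.context.search.enums import OptionalSearchSetting

run_search_setting = OptionalSearchSetting.AUTO  # assumed example value
attach_search_tool = run_search_setting != OptionalSearchSetting.NEVER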
|
||||
|
||||
@@ -256,33 +259,6 @@ def construct_tools(
|
||||
for tool_list in tool_dict.values():
|
||||
tools.extend(tool_list)
|
||||
|
||||
if use_file_search:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
db_session=db_session,
|
||||
user=user,
|
||||
persona=persona,
|
||||
retrieval_options=search_tool_config.retrieval_options,
|
||||
prompt_config=prompt_config,
|
||||
llm=llm,
|
||||
fast_llm=fast_llm,
|
||||
pruning_config=search_tool_config.document_pruning_config,
|
||||
answer_style_config=search_tool_config.answer_style_config,
|
||||
selected_sections=search_tool_config.selected_sections,
|
||||
chunks_above=search_tool_config.chunks_above,
|
||||
chunks_below=search_tool_config.chunks_below,
|
||||
full_doc=search_tool_config.full_doc,
|
||||
evaluation_type=(
|
||||
LLMEvaluationType.BASIC
|
||||
if persona.llm_relevance_filter
|
||||
else LLMEvaluationType.SKIP
|
||||
),
|
||||
rerank_settings=search_tool_config.rerank_settings,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
)
|
||||
tool_dict[1] = [search_tool]
|
||||
|
||||
# factor in tool definition size when pruning
|
||||
if search_tool_config:
|
||||
search_tool_config.document_pruning_config.tool_num_tokens = (
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
import copy
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
@@ -25,13 +23,13 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
from onyx.context.search.enums import QueryFlow
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import RerankingDetails
|
||||
from onyx.context.search.models import RetrievalDetails
|
||||
from onyx.context.search.models import SearchRequest
|
||||
from onyx.context.search.models import UserFileFilters
|
||||
from onyx.context.search.pipeline import SearchPipeline
|
||||
from onyx.context.search.pipeline import section_relevance_list_impl
|
||||
from onyx.db.models import Persona
|
||||
@@ -295,7 +293,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
skip_query_analysis = False
|
||||
user_file_ids = None
|
||||
user_folder_ids = None
|
||||
ordering_only = False
|
||||
document_sources = None
|
||||
time_cutoff = None
|
||||
expanded_queries = None
|
||||
@@ -308,46 +305,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
user_file_ids = override_kwargs.user_file_ids
|
||||
user_folder_ids = override_kwargs.user_folder_ids
|
||||
ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
|
||||
document_sources = override_kwargs.document_sources
|
||||
time_cutoff = override_kwargs.time_cutoff
|
||||
expanded_queries = override_kwargs.expanded_queries
|
||||
|
||||
# Fast path for ordering-only search
|
||||
if ordering_only:
|
||||
yield from self._run_ordering_only_search(
|
||||
query, user_file_ids, user_folder_ids
|
||||
)
|
||||
return
|
||||
|
||||
if self.selected_sections:
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
return
|
||||
|
||||
# Create a copy of the retrieval options with user_file_ids if provided
|
||||
retrieval_options = copy.deepcopy(self.retrieval_options)
|
||||
if (user_file_ids or user_folder_ids) and retrieval_options:
|
||||
# Create a copy to avoid modifying the original
|
||||
filters = (
|
||||
retrieval_options.filters.model_copy()
|
||||
if retrieval_options.filters
|
||||
else BaseFilters()
|
||||
)
|
||||
filters.user_file_ids = user_file_ids
|
||||
retrieval_options = retrieval_options.model_copy(
|
||||
update={"filters": filters}
|
||||
)
|
||||
elif user_file_ids or user_folder_ids:
|
||||
# Create new retrieval options with user_file_ids
|
||||
filters = BaseFilters(
|
||||
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
|
||||
)
|
||||
retrieval_options = RetrievalDetails(filters=filters)
|
||||
|
||||
retrieval_options = self.retrieval_options or RetrievalDetails()
|
||||
if document_sources or time_cutoff:
|
||||
# Get retrieval_options and filters, or create if they don't exist
|
||||
retrieval_options = retrieval_options or RetrievalDetails()
|
||||
retrieval_options.filters = retrieval_options.filters or BaseFilters()
|
||||
# if empty, just start with an empty filters object
|
||||
if not retrieval_options.filters:
|
||||
retrieval_options.filters = BaseFilters()
|
||||
|
||||
# Handle document sources
|
||||
if document_sources:
|
||||
@@ -370,6 +340,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
human_selected_filters=(
|
||||
retrieval_options.filters if retrieval_options else None
|
||||
),
|
||||
user_file_filters=UserFileFilters(
|
||||
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
|
||||
),
|
||||
persona=self.persona,
|
||||
offset=(retrieval_options.offset if retrieval_options else None),
|
||||
limit=retrieval_options.limit if retrieval_options else None,
|
||||
@@ -451,105 +424,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
prompt_config=self.prompt_config,
|
||||
)
|
||||
|
||||
def _run_ordering_only_search(
|
||||
self,
|
||||
query: str,
|
||||
user_file_ids: list[int] | None,
|
||||
user_folder_ids: list[int] | None,
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
"""Optimized search that only retrieves document order with minimal processing."""
|
||||
start_time = time.time()
|
||||
|
||||
logger.info("Fast path: Starting optimized ordering-only search")
|
||||
|
||||
# Create temporary search pipeline for optimized retrieval
|
||||
search_pipeline = SearchPipeline(
|
||||
search_request=SearchRequest(
|
||||
query=query,
|
||||
evaluation_type=LLMEvaluationType.SKIP, # Force skip evaluation
|
||||
persona=self.persona,
|
||||
# Minimal configuration needed
|
||||
chunks_above=0,
|
||||
chunks_below=0,
|
||||
),
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
fast_llm=self.fast_llm,
|
||||
skip_query_analysis=True, # Skip unnecessary analysis
|
||||
db_session=self.db_session,
|
||||
bypass_acl=self.bypass_acl,
|
||||
prompt_config=self.prompt_config,
|
||||
contextual_pruning_config=self.contextual_pruning_config,
|
||||
)
|
||||
|
||||
# Log what we're doing
|
||||
logger.info(
|
||||
f"Fast path: Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
|
||||
)
|
||||
|
||||
# Get chunks using the optimized method in SearchPipeline
|
||||
retrieval_start = time.time()
|
||||
retrieved_chunks = search_pipeline.get_ordering_only_chunks(
|
||||
query=query, user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
|
||||
)
|
||||
retrieval_time = time.time() - retrieval_start
|
||||
|
||||
logger.info(
|
||||
f"Fast path: Retrieved {len(retrieved_chunks)} chunks in {retrieval_time:.2f}s"
|
||||
)
|
||||
|
||||
# Convert chunks to minimal sections (we don't need full content)
|
||||
minimal_sections = []
|
||||
for chunk in retrieved_chunks:
|
||||
# Create a minimal section with just center_chunk
|
||||
minimal_section = InferenceSection(
|
||||
center_chunk=chunk,
|
||||
chunks=[chunk],
|
||||
combined_content=chunk.content, # Use the chunk content as combined content
|
||||
)
|
||||
minimal_sections.append(minimal_section)
|
||||
|
||||
# Log document IDs found for debugging
|
||||
doc_ids = [chunk.document_id for chunk in retrieved_chunks]
|
||||
logger.info(
|
||||
f"Fast path: Document IDs in order: {doc_ids[:5]}{'...' if len(doc_ids) > 5 else ''}"
|
||||
)
|
||||
|
||||
# Yield just the required responses for document ordering
|
||||
yield ToolResponse(
|
||||
id=SEARCH_RESPONSE_SUMMARY_ID,
|
||||
response=SearchResponseSummary(
|
||||
rephrased_query=query,
|
||||
top_sections=minimal_sections,
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=SearchType.SEMANTIC,
|
||||
final_filters=IndexFilters(
|
||||
user_file_ids=user_file_ids or [],
|
||||
user_folder_ids=user_folder_ids or [],
|
||||
access_control_list=None,
|
||||
),
|
||||
recency_bias_multiplier=1.0,
|
||||
),
|
||||
)
|
||||
|
||||
# For fast path, don't trigger any LLM evaluation for relevance
|
||||
logger.info(
|
||||
"Fast path: Skipping section relevance evaluation to optimize performance"
|
||||
)
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=None,
|
||||
)
|
||||
|
||||
# We need to yield this for the caller to extract document order
|
||||
minimal_docs = [
|
||||
llm_doc_from_inference_section(section) for section in minimal_sections
|
||||
]
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=minimal_docs)
|
||||
|
||||
total_time = time.time() - start_time
|
||||
logger.info(f"Fast path: Completed ordering-only search in {total_time:.2f}s")
|
||||
|
||||
|
||||
# Allows yielding the same responses as a SearchTool without being a SearchTool.
|
||||
# SearchTool passed in to allow for access to SearchTool properties.
|
||||
@@ -568,10 +442,6 @@ def yield_search_responses(
|
||||
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
|
||||
search_tool: SearchTool,
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
# Get the search query to check if we're in ordering-only mode
|
||||
# We can infer this from the reranked_sections not containing any relevance scoring
|
||||
is_ordering_only = search_tool.evaluation_type == LLMEvaluationType.SKIP
|
||||
|
||||
yield ToolResponse(
|
||||
id=SEARCH_RESPONSE_SUMMARY_ID,
|
||||
response=SearchResponseSummary(
|
||||
@@ -584,48 +454,26 @@ def yield_search_responses(
|
||||
),
|
||||
)
|
||||
|
||||
section_relevance: list[SectionRelevancePiece] | None = None
|
||||
|
||||
# Skip section relevance in ordering-only mode
|
||||
if is_ordering_only:
|
||||
logger.info(
|
||||
"Fast path: Skipping section relevance evaluation in yield_search_responses"
|
||||
)
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=None,
|
||||
)
|
||||
else:
|
||||
section_relevance = get_section_relevance()
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=section_relevance,
|
||||
)
|
||||
section_relevance = get_section_relevance()
|
||||
yield ToolResponse(
|
||||
id=SECTION_RELEVANCE_LIST_ID,
|
||||
response=section_relevance,
|
||||
)
|
||||
|
||||
final_context_sections = get_final_context_sections()
|
||||
|
||||
# Skip pruning sections in ordering-only mode
|
||||
if is_ordering_only:
|
||||
logger.info("Fast path: Skipping section pruning in ordering-only mode")
|
||||
llm_docs = [
|
||||
llm_doc_from_inference_section(section)
|
||||
for section in final_context_sections
|
||||
]
|
||||
else:
|
||||
# Use the section_relevance we already computed above
|
||||
pruned_sections = prune_sections(
|
||||
sections=final_context_sections,
|
||||
section_relevance_list=section_relevance_list_impl(
|
||||
section_relevance, final_context_sections
|
||||
),
|
||||
prompt_config=search_tool.prompt_config,
|
||||
llm_config=search_tool.llm.config,
|
||||
question=query,
|
||||
contextual_pruning_config=search_tool.contextual_pruning_config,
|
||||
)
|
||||
llm_docs = [
|
||||
llm_doc_from_inference_section(section) for section in pruned_sections
|
||||
]
|
||||
# Use the section_relevance we already computed above
|
||||
pruned_sections = prune_sections(
|
||||
sections=final_context_sections,
|
||||
section_relevance_list=section_relevance_list_impl(
|
||||
section_relevance, final_context_sections
|
||||
),
|
||||
prompt_config=search_tool.prompt_config,
|
||||
llm_config=search_tool.llm.config,
|
||||
question=query,
|
||||
contextual_pruning_config=search_tool.contextual_pruning_config,
|
||||
)
|
||||
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
from typing import cast
|
||||
|
||||
from langchain_core.messages import HumanMessage
|
||||
|
||||
from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import PromptConfig
|
||||
@@ -10,7 +8,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
|
||||
build_citations_system_message,
|
||||
)
|
||||
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
|
||||
from onyx.llm.utils import build_content_with_imgs
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import ToolResponse
|
||||
|
||||
@@ -45,12 +42,8 @@ def build_next_prompt_for_search_like_tool(
|
||||
build_citations_user_message(
|
||||
# make sure to use the original user query here in order to avoid duplication
|
||||
# of the task prompt
|
||||
message=HumanMessage(
|
||||
content=build_content_with_imgs(
|
||||
prompt_builder.raw_user_query,
|
||||
prompt_builder.raw_user_uploaded_files,
|
||||
)
|
||||
),
|
||||
user_query=prompt_builder.raw_user_query,
|
||||
files=prompt_builder.raw_user_uploaded_files,
|
||||
prompt_config=prompt_config,
|
||||
context_docs=final_context_documents,
|
||||
all_doc_useful=(
|
||||
|
||||
@@ -8,6 +8,7 @@ from onyx.db.connector import check_connectors_exist
|
||||
from onyx.db.document import check_docs_exist
|
||||
from onyx.db.models import LLMProvider
|
||||
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
|
||||
from onyx.llm.utils import find_model_obj
|
||||
from onyx.llm.utils import get_model_map
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
@@ -35,6 +36,10 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo
|
||||
model_supports
|
||||
and model_provider != ANTHROPIC_PROVIDER_NAME
|
||||
and model_name not in litellm.anthropic_models
|
||||
and (
|
||||
model_provider != BEDROCK_PROVIDER_NAME
|
||||
or not any(name in model_name for name in litellm.anthropic_models)
|
||||
)
|
||||
)
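The added clause also rules out Anthropic models served through Bedrock, where the provider name differs but the model name still embeds an Anthropic model id; an illustrative check (the model ids here are examples, not the real litellm list):

anthropic_models = ["claude-3-5-sonnet-20240620"]  # stand-in for litellm.anthropic_models

bedrock_model_name = "anthropic.claude-3-5-sonnet-20240620-v1:0"  # example Bedrock id
assert any(name in bedrock_model_name for name in anthropic_models)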
|
||||
|
||||
|
||||
|
||||
@@ -8,3 +8,8 @@ filterwarnings =
|
||||
ignore::DeprecationWarning
|
||||
ignore::cryptography.utils.CryptographyDeprecationWarning
|
||||
ignore::PendingDeprecationWarning:ddtrace.internal.module
|
||||
# .test.env is gitignored.
|
||||
# After installing pytest-dotenv,
|
||||
# you can use it to test credentials locally.
|
||||
env_files =
|
||||
.test.env
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
aioboto3==14.0.0
|
||||
aiohttp==3.11.16
|
||||
alembic==1.10.4
|
||||
asyncpg==0.27.0
|
||||
asyncpg==0.30.0
|
||||
atlassian-python-api==3.41.16
|
||||
beautifulsoup4==4.12.3
|
||||
boto3==1.36.23
|
||||
|
||||
@@ -12,6 +12,7 @@ pandas==2.2.3
|
||||
posthog==3.7.4
|
||||
pre-commit==3.2.2
|
||||
pytest-asyncio==0.22.0
|
||||
pytest-dotenv==0.5.2
|
||||
pytest-xdist==3.6.1
|
||||
pytest==8.3.5
|
||||
reorder-python-imports-black==3.14.0
|
||||
|
||||
@@ -21,6 +21,7 @@ if True: # noqa: E402
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import SqlEngine
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
@@ -30,6 +31,8 @@ if True: # noqa: E402
|
||||
|
||||
|
||||
class TenantMetadata(BaseModel):
|
||||
first_email: str | None
|
||||
user_count: int
|
||||
num_docs: int
|
||||
num_chunks: int
|
||||
|
||||
@@ -39,7 +42,7 @@ class SQLAlchemyDebugging:
|
||||
def __init__(self) -> None:
|
||||
pass
|
||||
|
||||
def top_chunks(self, k: int = 10) -> None:
|
||||
def top_chunks(self, filename: str, k: int = 10) -> None:
|
||||
tenants_to_total_chunks: dict[str, TenantMetadata] = {}
|
||||
|
||||
logger.info("Fetching all tenant id's.")
|
||||
@@ -56,6 +59,14 @@ class SQLAlchemyDebugging:
|
||||
|
||||
try:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
first_email = None
|
||||
|
||||
first_user = db_session.query(User).first()
|
||||
if first_user:
|
||||
first_email = first_user.email
|
||||
|
||||
user_count = db_session.query(User).count()
|
||||
|
||||
# Calculate the total number of document rows for the current tenant
|
||||
total_documents = db_session.query(Document).count()
|
||||
# marginally useful to skip some tenants ... maybe we can improve on this
|
||||
@@ -69,15 +80,20 @@ class SQLAlchemyDebugging:
|
||||
total_chunks = db_session.query(
|
||||
func.sum(Document.chunk_count)
|
||||
).scalar()
|
||||
|
||||
total_chunks = total_chunks or 0
|
||||
|
||||
logger.info(
|
||||
f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
|
||||
f"first_email={first_email} user_count={user_count} "
|
||||
f"docs={total_documents} chunks={total_chunks}"
|
||||
)
|
||||
|
||||
tenants_to_total_chunks[tenant_id] = TenantMetadata(
|
||||
num_docs=total_documents, num_chunks=total_chunks
|
||||
first_email=first_email,
|
||||
user_count=user_count,
|
||||
num_docs=total_documents,
|
||||
num_chunks=total_chunks,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing tenant '{tenant_id}': {e}")
|
||||
@@ -91,14 +107,23 @@ class SQLAlchemyDebugging:
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
csv_filename = "tenants_by_num_docs.csv"
|
||||
with open(csv_filename, "w") as csvfile:
|
||||
with open(filename, "w") as csvfile:
|
||||
writer = csv.writer(csvfile)
|
||||
writer.writerow(["tenant_id", "num_docs", "num_chunks"]) # Write header
|
||||
writer.writerow(
|
||||
["tenant_id", "first_user_email", "num_user", "num_docs", "num_chunks"]
|
||||
) # Write header
|
||||
# Write data rows (using the sorted list)
|
||||
for tenant_id, metadata in sorted_tenants:
|
||||
writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])
|
||||
logger.info(f"Successfully wrote statistics to {csv_filename}")
|
||||
writer.writerow(
|
||||
[
|
||||
tenant_id,
|
||||
metadata.first_email,
|
||||
metadata.user_count,
|
||||
metadata.num_docs,
|
||||
metadata.num_chunks,
|
||||
]
|
||||
)
|
||||
logger.info(f"Successfully wrote statistics to {filename}")
|
||||
|
||||
# output top k by chunks
|
||||
top_k_tenants = heapq.nlargest(
|
||||
@@ -118,6 +143,14 @@ def main() -> None:
|
||||
|
||||
parser.add_argument("--report", help="Generate the given report")
|
||||
|
||||
parser.add_argument(
|
||||
"--filename",
|
||||
type=str,
|
||||
default="tenants_by_num_docs.csv",
|
||||
help="Generate the given report",
|
||||
required=False,
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logger.info(f"{args}")
|
||||
@@ -140,7 +173,7 @@ def main() -> None:
|
||||
debugger = SQLAlchemyDebugging()
|
||||
|
||||
if args.report == "top-chunks":
|
||||
debugger.top_chunks(10)
|
||||
debugger.top_chunks(args.filename, 10)
|
||||
else:
|
||||
logger.info("No action.")
|
||||
|
||||
|
||||
backend/scripts/resume_paused_connectors.py (new file, 77 lines)
@@ -0,0 +1,77 @@
|
||||
import argparse
|
||||
|
||||
import requests
|
||||
|
||||
API_SERVER_URL = "http://localhost:3000"
|
||||
API_KEY = "onyx-api-key" # API key here, if auth is enabled
|
||||
|
||||
|
||||
def resume_paused_connectors(
|
||||
api_server_url: str,
|
||||
api_key: str | None,
|
||||
specific_connector_sources: list[str] | None = None,
|
||||
) -> None:
|
||||
headers = {"Content-Type": "application/json"}
|
||||
if api_key:
|
||||
headers["Authorization"] = f"Bearer {api_key}"
|
||||
|
||||
# Get all paused connectors
|
||||
response = requests.get(
|
||||
f"{api_server_url}/api/manage/admin/connector/indexing-status",
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Convert the response to a list of ConnectorIndexingStatus objects
|
||||
connectors = [cc_pair for cc_pair in response.json()]
|
||||
|
||||
# If a specific connector is provided, filter the connectors to only include that one
|
||||
if specific_connector_sources:
|
||||
connectors = [
|
||||
connector
|
||||
for connector in connectors
|
||||
if connector["connector"]["source"] in specific_connector_sources
|
||||
]
|
||||
|
||||
for connector in connectors:
|
||||
if connector["cc_pair_status"] == "PAUSED":
|
||||
print(f"Resuming connector: {connector['name']}")
|
||||
response = requests.put(
|
||||
f"{api_server_url}/api/manage/admin/cc-pair/{connector['cc_pair_id']}/status",
|
||||
json={"status": "ACTIVE"},
|
||||
headers=headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
print(f"Resumed connector: {connector['name']}")
|
||||
|
||||
else:
|
||||
print(f"Connector {connector['name']} is not paused")
|
||||
|
||||
|
||||
def main() -> None:
|
||||
parser = argparse.ArgumentParser(description="Resume paused connectors")
|
||||
parser.add_argument(
|
||||
"--api_server_url",
|
||||
type=str,
|
||||
default=API_SERVER_URL,
|
||||
help="The URL of the API server to use. If not provided, will use the default.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--api_key",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The API key to use for authentication. If not provided, no authentication will be used.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--connector_sources",
|
||||
type=str.lower,
|
||||
nargs="+",
|
||||
help="The sources of the connectors to resume. If not provided, will resume all paused connectors.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
resume_paused_connectors(args.api_server_url, args.api_key, args.connector_sources)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
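If the helper is imported rather than run as a script, it can be called directly; a hypothetical invocation (the import path, URL, and source names are assumptions):

from scripts.resume_paused_connectors import resume_paused_connectors  # assumed import path

resume_paused_connectors(
    api_server_url="http://localhost:3000",
    api_key=None,  # no auth configured
    specific_connector_sources=["slack", "github"],
)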
|
||||
@@ -30,25 +30,48 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
|
||||
start=0,
|
||||
end=time.time(),
|
||||
)
|
||||
assert len(docs) > 0 # We expect at least one PR to exist
|
||||
assert len(docs) > 1 # We expect at least one PR and one Issue to exist
|
||||
|
||||
# Test the first document's structure
|
||||
doc = docs[0]
|
||||
pr_doc = docs[0]
|
||||
issue_doc = docs[-1]
|
||||
|
||||
# Verify basic document properties
|
||||
assert doc.source == DocumentSource.GITHUB
|
||||
assert doc.secondary_owners is None
|
||||
assert doc.from_ingestion_api is False
|
||||
assert doc.additional_info is None
|
||||
assert pr_doc.source == DocumentSource.GITHUB
|
||||
assert pr_doc.secondary_owners is None
|
||||
assert pr_doc.from_ingestion_api is False
|
||||
assert pr_doc.additional_info is None
|
||||
|
||||
# Verify GitHub-specific properties
|
||||
assert "github.com" in doc.id # Should be a GitHub URL
|
||||
assert doc.metadata is not None
|
||||
assert "state" in doc.metadata
|
||||
assert "merged" in doc.metadata
|
||||
assert "github.com" in pr_doc.id # Should be a GitHub URL
|
||||
|
||||
# Verify PR-specific properties
|
||||
assert pr_doc.metadata is not None
|
||||
assert pr_doc.metadata.get("object_type") == "PullRequest"
|
||||
assert "id" in pr_doc.metadata
|
||||
assert "merged" in pr_doc.metadata
|
||||
assert "state" in pr_doc.metadata
|
||||
assert "user" in pr_doc.metadata
|
||||
assert "assignees" in pr_doc.metadata
|
||||
assert pr_doc.metadata.get("repo") == "onyx-dot-app/documentation"
|
||||
assert "num_commits" in pr_doc.metadata
|
||||
assert "num_files_changed" in pr_doc.metadata
|
||||
assert "labels" in pr_doc.metadata
|
||||
assert "created_at" in pr_doc.metadata
|
||||
|
||||
# Verify Issue-specific properties
|
||||
assert issue_doc.metadata is not None
|
||||
assert issue_doc.metadata.get("object_type") == "Issue"
|
||||
assert "id" in issue_doc.metadata
|
||||
assert "state" in issue_doc.metadata
|
||||
assert "user" in issue_doc.metadata
|
||||
assert "assignees" in issue_doc.metadata
|
||||
assert issue_doc.metadata.get("repo") == "onyx-dot-app/documentation"
|
||||
assert "labels" in issue_doc.metadata
|
||||
assert "created_at" in issue_doc.metadata
|
||||
|
||||
# Verify sections
|
||||
assert len(doc.sections) == 1
|
||||
section = doc.sections[0]
|
||||
assert section.link == doc.id # Section link should match document ID
|
||||
assert len(pr_doc.sections) == 1
|
||||
section = pr_doc.sections[0]
|
||||
assert section.link == pr_doc.id # Section link should match document ID
|
||||
assert isinstance(section.text, str) # Should have some text content
|
||||
|
||||
@@ -59,11 +59,19 @@ def test_jira_connector_basic(
|
||||
assert story.source == DocumentSource.JIRA
|
||||
assert story.metadata == {
|
||||
"priority": "Medium",
|
||||
"status": "Backlog",
|
||||
"status": "Done",
|
||||
"resolution": "Done",
|
||||
"resolution_date": "2025-05-29T15:33:31.031-0700",
|
||||
"reporter": "Chris Weaver",
|
||||
"assignee": "Chris Weaver",
|
||||
"issuetype": "Story",
|
||||
"created": "2025-04-16T16:44:06.716-0700",
|
||||
"reporter_email": "chris@onyx.app",
|
||||
"assignee_email": "chris@onyx.app",
|
||||
"project_name": "DailyConnectorTestProject",
|
||||
"project": "AS",
|
||||
"parent": "AS-4",
|
||||
"updated": "2025-05-29T15:33:31.085-0700",
|
||||
}
|
||||
assert story.secondary_owners is None
|
||||
assert story.title == "AS-3 test123small"
|
||||
@@ -86,6 +94,11 @@ def test_jira_connector_basic(
|
||||
"assignee": "Chris Weaver",
|
||||
"issuetype": "Epic",
|
||||
"created": "2025-04-16T16:55:53.068-0700",
|
||||
"reporter_email": "founders@onyx.app",
|
||||
"assignee_email": "chris@onyx.app",
|
||||
"project_name": "DailyConnectorTestProject",
|
||||
"project": "AS",
|
||||
"updated": "2025-05-29T14:43:05.312-0700",
|
||||
}
|
||||
assert epic.secondary_owners is None
|
||||
assert epic.title == "AS-4 EPIC"
|
||||
|
||||
@@ -31,6 +31,7 @@ def slack_connector(
|
||||
connector = SlackConnector(
|
||||
channels=[channel] if channel else None,
|
||||
channel_regex_enabled=False,
|
||||
use_redis=False,
|
||||
)
|
||||
connector.client = mock_slack_client
|
||||
connector.set_credentials_provider(credentials_provider=slack_credentials_provider)
|
||||
|
||||
@@ -108,7 +108,7 @@ def azure_embedding_model() -> EmbeddingModel:
|
||||
return EmbeddingModel(
|
||||
server_host="localhost",
|
||||
server_port=9000,
|
||||
model_name="text-embedding-3-large",
|
||||
model_name="text-embedding-3-small",
|
||||
normalize=True,
|
||||
query_prefix=None,
|
||||
passage_prefix=None,
|
||||
|
||||
@@ -60,6 +60,7 @@ class ChatSessionManager:
|
||||
prompt_override: PromptOverride | None = None,
|
||||
alternate_assistant_id: int | None = None,
|
||||
use_existing_user_message: bool = False,
|
||||
use_agentic_search: bool = False,
|
||||
) -> StreamedResponse:
|
||||
chat_message_req = CreateChatMessageRequest(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -76,6 +77,7 @@ class ChatSessionManager:
|
||||
prompt_override=prompt_override,
|
||||
alternate_assistant_id=alternate_assistant_id,
|
||||
use_existing_user_message=use_existing_user_message,
|
||||
use_agentic_search=use_agentic_search,
|
||||
)
|
||||
|
||||
headers = (
|
||||
@@ -175,3 +177,136 @@ class ChatSessionManager:
|
||||
),
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a chat session and all its related records (messages, agent data, etc.)
|
||||
Uses the default deletion method configured on the server.
|
||||
|
||||
Returns True if deletion was successful, False otherwise.
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def soft_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Soft delete a chat session (marks as deleted but keeps in database).
|
||||
|
||||
Returns True if deletion was successful, False otherwise.
|
||||
"""
|
||||
# There is no dedicated soft-delete endpoint, so we call the regular delete
# endpoint with hard_delete=false as a query parameter.
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def hard_delete(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Hard delete a chat session (completely removes from database).
|
||||
|
||||
Returns True if deletion was successful, False otherwise.
|
||||
"""
|
||||
response = requests.delete(
|
||||
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
return response.ok
|
||||
|
||||
@staticmethod
|
||||
def verify_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been deleted by attempting to retrieve it.
|
||||
|
||||
Returns True if the chat session is confirmed deleted, False if it still exists.
|
||||
"""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
# Chat session should return 400 if it doesn't exist
|
||||
return response.status_code == 400
|
||||
|
||||
@staticmethod
|
||||
def verify_soft_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
|
||||
|
||||
Returns True if the chat session is soft deleted, False otherwise.
|
||||
"""
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
# Chat exists, check if it's marked as deleted
|
||||
chat_data = response.json()
|
||||
return chat_data.get("deleted", False) is True
|
||||
return False
|
||||
|
||||
@staticmethod
|
||||
def verify_hard_deleted(
|
||||
chat_session: DATestChatSession,
|
||||
user_performing_action: DATestUser | None = None,
|
||||
) -> bool:
|
||||
"""
|
||||
Verify that a chat session has been hard deleted (completely removed from DB).
|
||||
|
||||
Returns True if the chat session is hard deleted, False otherwise.
|
||||
"""
|
||||
# Try to get the chat session with include_deleted=true
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
|
||||
headers=(
|
||||
user_performing_action.headers
|
||||
if user_performing_action
|
||||
else GENERAL_HEADERS
|
||||
),
|
||||
)
|
||||
|
||||
# For hard delete, even with include_deleted=true, the record should not exist
|
||||
return response.status_code != 200
|
||||
|
||||
@@ -29,7 +29,6 @@ class UserManager:
|
||||
def create(
|
||||
name: str | None = None,
|
||||
email: str | None = None,
|
||||
is_first_user: bool = False,
|
||||
) -> DATestUser:
|
||||
if name is None:
|
||||
name = f"test{str(uuid4())}"
|
||||
@@ -51,14 +50,14 @@ class UserManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
role = UserRole.ADMIN if is_first_user else UserRole.BASIC
|
||||
|
||||
test_user = DATestUser(
|
||||
id=response.json()["id"],
|
||||
email=email,
|
||||
password=password,
|
||||
headers=deepcopy(GENERAL_HEADERS),
|
||||
role=role,
|
||||
# fill as basic for now, the `login_as_user` call will
|
||||
# fill it in correctly
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
print(f"Created user {test_user.email}")
|
||||
@@ -93,6 +92,17 @@ class UserManager:
|
||||
# Set cookies in the headers
|
||||
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
|
||||
test_user.cookies = {"fastapiusersauth": session_cookie}
|
||||
|
||||
# Get user role from /me endpoint
|
||||
me_response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me",
|
||||
headers=test_user.headers,
|
||||
cookies=test_user.cookies,
|
||||
)
|
||||
me_response.raise_for_status()
|
||||
role = UserRole(me_response.json()["role"])
|
||||
test_user.role = role
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -16,6 +16,8 @@ from tests.integration.common_utils.reset import reset_all_multitenant
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
from tests.integration.common_utils.vespa import vespa_fixture
|
||||
|
||||
BASIC_USER_NAME = "basic_user"
|
||||
|
||||
|
||||
def load_env_vars(env_file: str = ".env") -> None:
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -81,7 +83,7 @@ def new_admin_user(reset: None) -> DATestUser | None:
|
||||
@pytest.fixture
|
||||
def admin_user() -> DATestUser:
|
||||
try:
|
||||
user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
|
||||
user = UserManager.create(name=ADMIN_USER_NAME)
|
||||
|
||||
# if there are other users for some reason, reset and try again
|
||||
if not UserManager.is_role(user, UserRole.ADMIN):
|
||||
@@ -115,6 +117,44 @@ def admin_user() -> DATestUser:
|
||||
raise RuntimeError("Failed to create or login as admin user")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_user(
|
||||
# make sure the admin user exists first to ensure this new user
|
||||
# gets the BASIC role
|
||||
admin_user: DATestUser,
|
||||
) -> DATestUser:
|
||||
try:
|
||||
user = UserManager.create(name=BASIC_USER_NAME)
|
||||
|
||||
# Validate that the user has the BASIC role
|
||||
if user.role != UserRole.BASIC:
|
||||
raise RuntimeError(
|
||||
f"Created user {BASIC_USER_NAME} does not have BASIC role"
|
||||
)
|
||||
|
||||
return user
|
||||
except Exception as e:
|
||||
print(f"Failed to create basic user, trying to login as existing user: {e}")
|
||||
|
||||
# Try to login as existing basic user
|
||||
user = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(BASIC_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.BASIC,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
|
||||
# Validate that the logged-in user has the BASIC role
|
||||
if not UserManager.is_role(user, UserRole.BASIC):
|
||||
raise RuntimeError(f"User {BASIC_USER_NAME} does not have BASIC role")
|
||||
|
||||
return user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def reset_multitenant() -> None:
|
||||
reset_all_multitenant()
|
||||
|
||||
@@ -17,8 +17,7 @@ from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.connectors.slack.connector import default_msg_filter
|
||||
from onyx.connectors.slack.connector import get_channel_messages
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call
|
||||
|
||||
|
||||
def _get_slack_channel_id(channel: dict[str, Any]) -> str:
|
||||
@@ -40,7 +39,7 @@ def _get_non_general_channels(
|
||||
channel_types.append("public_channel")
|
||||
|
||||
conversations: list[dict[str, Any]] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
for result in make_paginated_slack_api_call(
|
||||
slack_client.conversations_list,
|
||||
exclude_archived=False,
|
||||
types=channel_types,
|
||||
@@ -64,7 +63,7 @@ def _clear_slack_conversation_members(
|
||||
) -> None:
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
member_ids: list[str] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
for result in make_paginated_slack_api_call(
|
||||
slack_client.conversations_members,
|
||||
channel=channel_id,
|
||||
):
|
||||
@@ -140,15 +139,13 @@ def _build_slack_channel_from_name(
|
||||
if channel:
|
||||
# If channel is provided, we rename it
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_rename,
|
||||
channel_response = slack_client.conversations_rename(
|
||||
channel=channel_id,
|
||||
name=channel_name,
|
||||
)
|
||||
else:
|
||||
# Otherwise, we create a new channel
|
||||
channel_response = make_slack_api_call_w_retries(
|
||||
slack_client.conversations_create,
|
||||
channel_response = slack_client.conversations_create(
|
||||
name=channel_name,
|
||||
is_private=is_private,
|
||||
)
|
||||
@@ -219,10 +216,13 @@ class SlackManager:
|
||||
|
||||
@staticmethod
|
||||
def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
|
||||
users_results = make_slack_api_call_w_retries(
|
||||
users: list[dict[str, Any]] = []
|
||||
|
||||
for users_results in make_paginated_slack_api_call(
|
||||
slack_client.users_list,
|
||||
)
|
||||
users: list[dict[str, Any]] = users_results.get("members", [])
|
||||
):
|
||||
users.extend(users_results.get("members", []))
|
||||
|
||||
user_email_id_map = {}
|
||||
for user in users:
|
||||
if not (email := user.get("profile", {}).get("email")):
|
||||
@@ -253,8 +253,7 @@ class SlackManager:
|
||||
slack_client: WebClient, channel: dict[str, Any], message: str
|
||||
) -> None:
|
||||
channel_id = _get_slack_channel_id(channel)
|
||||
make_slack_api_call_w_retries(
|
||||
slack_client.chat_postMessage,
|
||||
slack_client.chat_postMessage(
|
||||
channel=channel_id,
|
||||
text=message,
|
||||
)
|
||||
@@ -274,7 +273,7 @@ class SlackManager:
|
||||
) -> None:
|
||||
channel_types = ["private_channel", "public_channel"]
|
||||
channels: list[dict[str, Any]] = []
|
||||
for result in make_paginated_slack_api_call_w_retries(
|
||||
for result in make_paginated_slack_api_call(
|
||||
slack_client.conversations_list,
|
||||
exclude_archived=False,
|
||||
types=channel_types,
|
||||
|
||||
backend/tests/integration/tests/chat/test_chat_deletion.py (new file, 429 lines)
@@ -0,0 +1,429 @@
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.chat import ChatSessionManager
|
||||
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestLLMProvider
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.fixture(scope="module", autouse=True)
|
||||
def reset_for_module() -> None:
|
||||
"""Reset all data once before running any tests in this module."""
|
||||
reset_all()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
|
||||
return LLMProviderManager.create(user_performing_action=admin_user)
|
||||
|
||||
|
||||
def test_soft_delete_chat_session(
|
||||
basic_user: DATestUser, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
"""
|
||||
Test soft deletion of a chat session.
|
||||
Soft delete should mark the chat as deleted but keep it in the database.
|
||||
"""
|
||||
# Create a chat session
|
||||
test_chat_session = ChatSessionManager.create(
|
||||
persona_id=0, # Use default persona
|
||||
description="Test chat session for soft deletion",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Send a message to create some data
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=test_chat_session.id,
|
||||
message="Explain the concept of machine learning in detail",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify that the message was processed successfully
|
||||
assert len(response.full_message) > 0, "Chat response should not be empty"
|
||||
|
||||
# Verify that the chat session can be retrieved before deletion
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
assert len(chat_history) > 0, "Chat session should have messages"
|
||||
|
||||
# Test soft deletion of the chat session
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify that the deletion was successful
|
||||
assert deletion_success, "Chat session soft deletion should succeed"
|
||||
|
||||
# Verify that the chat session is soft deleted (marked as deleted but still in DB)
|
||||
assert ChatSessionManager.verify_soft_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Chat session should be soft deleted"
|
||||
|
||||
# Verify that normal access is blocked
|
||||
assert ChatSessionManager.verify_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Chat session should not be accessible normally after soft delete"
|
||||
|
||||
|
||||
def test_hard_delete_chat_session(
|
||||
basic_user: DATestUser, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
"""
|
||||
Test hard deletion of a chat session.
|
||||
Hard delete should completely remove the chat from the database.
|
||||
"""
|
||||
# Create a chat session
|
||||
test_chat_session = ChatSessionManager.create(
|
||||
persona_id=0, # Use default persona
|
||||
description="Test chat session for hard deletion",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Send a message to create some data
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=test_chat_session.id,
|
||||
message="Explain the concept of machine learning in detail",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify that the message was processed successfully
|
||||
assert len(response.full_message) > 0, "Chat response should not be empty"
|
||||
|
||||
# Verify that the chat session can be retrieved before deletion
|
||||
chat_history = ChatSessionManager.get_chat_history(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
assert len(chat_history) > 0, "Chat session should have messages"
|
||||
|
||||
# Test hard deletion of the chat session
|
||||
deletion_success = ChatSessionManager.hard_delete(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify that the deletion was successful
|
||||
assert deletion_success, "Chat session hard deletion should succeed"
|
||||
|
||||
# Verify that the chat session is hard deleted (completely removed from DB)
|
||||
assert ChatSessionManager.verify_hard_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Chat session should be hard deleted"
|
||||
|
||||
# Verify that the chat session is not accessible at all
|
||||
assert ChatSessionManager.verify_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Chat session should not be accessible after hard delete"
|
||||
|
||||
# Verify it's not soft deleted (since it doesn't exist at all)
|
||||
assert not ChatSessionManager.verify_soft_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Hard deleted chat should not be found as soft deleted"
|
||||
|
||||
|
||||
def test_soft_delete_with_agentic_search(
|
||||
basic_user: DATestUser, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
"""
|
||||
Test soft deletion of a chat session with agent behavior (sub-questions and sub-queries).
|
||||
Verifies that soft delete preserves all related agent records in the database.
|
||||
"""
|
||||
# Create a chat session
|
||||
test_chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description="Test agentic search soft deletion",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Send a message using ChatSessionManager with agentic search enabled
|
||||
# This will create AgentSubQuestion and AgentSubQuery records
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=test_chat_session.id,
|
||||
message="What are the key principles of software engineering?",
|
||||
user_performing_action=basic_user,
|
||||
use_agentic_search=True,
|
||||
)
|
||||
|
||||
# Verify that the message was processed successfully
|
||||
assert len(response.full_message) > 0, "Chat response should not be empty"
|
||||
|
||||
# Test soft deletion
|
||||
deletion_success = ChatSessionManager.soft_delete(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify successful soft deletion
|
||||
assert deletion_success, "Chat soft deletion should succeed"
|
||||
|
||||
# Verify chat session is soft deleted
|
||||
assert ChatSessionManager.verify_soft_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Soft deleted chat session should be marked as deleted in DB"
|
||||
|
||||
# Verify chat session is not accessible normally
|
||||
assert ChatSessionManager.verify_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Soft deleted chat session should not be accessible"
|
||||
|
||||
|
||||
def test_hard_delete_with_agentic_search(
|
||||
basic_user: DATestUser, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
"""
|
||||
Test hard deletion of a chat session with agent behavior (sub-questions and sub-queries).
|
||||
Verifies that hard delete removes all related agent records from the database via CASCADE.
|
||||
"""
|
||||
# Create a chat session
|
||||
test_chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description="Test agentic search hard deletion",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Send a message using ChatSessionManager with agentic search enabled
|
||||
# This will create AgentSubQuestion and AgentSubQuery records
|
||||
response = ChatSessionManager.send_message(
|
||||
chat_session_id=test_chat_session.id,
|
||||
message="What are the key principles of software engineering?",
|
||||
user_performing_action=basic_user,
|
||||
use_agentic_search=True,
|
||||
)
|
||||
|
||||
# Verify that the message was processed successfully
|
||||
assert len(response.full_message) > 0, "Chat response should not be empty"
|
||||
|
||||
# Test hard deletion
|
||||
deletion_success = ChatSessionManager.hard_delete(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Verify successful hard deletion
|
||||
assert deletion_success, "Chat hard deletion should succeed"
|
||||
|
||||
# Verify chat session is hard deleted (completely removed)
|
||||
assert ChatSessionManager.verify_hard_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Hard deleted chat session should be completely removed from DB"
|
||||
|
||||
# Verify chat session is not accessible
|
||||
assert ChatSessionManager.verify_deleted(
|
||||
chat_session=test_chat_session,
|
||||
user_performing_action=basic_user,
|
||||
), "Hard deleted chat session should not be accessible"
|
||||
|
||||
|
||||
def test_multiple_soft_deletions(
|
||||
basic_user: DATestUser, llm_provider: DATestLLMProvider
|
||||
) -> None:
|
||||
"""
|
||||
Test multiple chat session soft deletions to ensure proper handling
|
||||
when there are multiple related records.
|
||||
"""
|
||||
chat_sessions = []
|
||||
|
||||
# Create multiple chat sessions with potential agent behavior
|
||||
for i in range(3):
|
||||
chat_session = ChatSessionManager.create(
|
||||
persona_id=0,
|
||||
description=f"Test chat session {i} for multi-soft-deletion",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
# Send a message to create some data
|
||||
ChatSessionManager.send_message(
|
||||
chat_session_id=chat_session.id,
|
||||
message=f"Tell me about topic {i} with detailed analysis",
|
||||
user_performing_action=basic_user,
|
||||
)
|
||||
|
||||
chat_sessions.append(chat_session)
|
||||
|
||||
# Soft delete all chat sessions
|
||||
    for chat_session in chat_sessions:
        deletion_success = ChatSessionManager.soft_delete(
            chat_session=chat_session,
            user_performing_action=basic_user,
        )
        assert deletion_success, f"Failed to soft delete chat {chat_session.id}"

    # Verify all chat sessions are soft deleted
    for chat_session in chat_sessions:
        assert ChatSessionManager.verify_soft_deleted(
            chat_session=chat_session,
            user_performing_action=basic_user,
        ), f"Chat {chat_session.id} should be soft deleted"

        assert ChatSessionManager.verify_deleted(
            chat_session=chat_session,
            user_performing_action=basic_user,
        ), f"Chat {chat_session.id} should not be accessible normally"


def test_multiple_hard_deletions_with_agent_data(
    basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
    """
    Test multiple chat session hard deletions to ensure CASCADE deletes work correctly
    when there are multiple related records.
    """
    chat_sessions = []

    # Create multiple chat sessions with potential agent behavior
    for i in range(3):
        chat_session = ChatSessionManager.create(
            persona_id=0,
            description=f"Test chat session {i} for multi-hard-deletion",
            user_performing_action=basic_user,
        )

        # Send a message to create some data
        ChatSessionManager.send_message(
            chat_session_id=chat_session.id,
            message=f"Tell me about topic {i} with detailed analysis",
            user_performing_action=basic_user,
        )

        chat_sessions.append(chat_session)

    # Hard delete all chat sessions
    for chat_session in chat_sessions:
        deletion_success = ChatSessionManager.hard_delete(
            chat_session=chat_session,
            user_performing_action=basic_user,
        )
        assert deletion_success, f"Failed to hard delete chat {chat_session.id}"

    # Verify all chat sessions are hard deleted
    for chat_session in chat_sessions:
        assert ChatSessionManager.verify_hard_deleted(
            chat_session=chat_session,
            user_performing_action=basic_user,
        ), f"Chat {chat_session.id} should be hard deleted"

        assert ChatSessionManager.verify_deleted(
            chat_session=chat_session,
            user_performing_action=basic_user,
        ), f"Chat {chat_session.id} should not be accessible"


def test_soft_vs_hard_delete_edge_cases(
    basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
    """
    Test edge cases for both soft and hard deletion to ensure robustness.
    """
    # Test 1: Soft delete a chat session with no messages
    empty_chat_session_soft = ChatSessionManager.create(
        persona_id=0,
        description="Empty chat session for soft delete",
        user_performing_action=basic_user,
    )

    # Soft delete without sending any messages
    deletion_success = ChatSessionManager.soft_delete(
        chat_session=empty_chat_session_soft,
        user_performing_action=basic_user,
    )
    assert deletion_success, "Empty chat session should be soft deletable"
    assert ChatSessionManager.verify_soft_deleted(
        chat_session=empty_chat_session_soft,
        user_performing_action=basic_user,
    ), "Empty chat session should be confirmed as soft deleted"

    # Test 2: Hard delete a chat session with no messages
    empty_chat_session_hard = ChatSessionManager.create(
        persona_id=0,
        description="Empty chat session for hard delete",
        user_performing_action=basic_user,
    )

    # Hard delete without sending any messages
    deletion_success = ChatSessionManager.hard_delete(
        chat_session=empty_chat_session_hard,
        user_performing_action=basic_user,
    )
    assert deletion_success, "Empty chat session should be hard deletable"
    assert ChatSessionManager.verify_hard_deleted(
        chat_session=empty_chat_session_hard,
        user_performing_action=basic_user,
    ), "Empty chat session should be confirmed as hard deleted"

    # Test 3: Soft delete a chat session with multiple messages
    multi_message_chat_soft = ChatSessionManager.create(
        persona_id=0,
        description="Multi-message chat session for soft delete",
        user_performing_action=basic_user,
    )

    # Send multiple messages to create more complex data
    for i in range(3):
        ChatSessionManager.send_message(
            chat_session_id=multi_message_chat_soft.id,
            message=f"Message {i}: Tell me about different aspects of this topic",
            user_performing_action=basic_user,
        )

    # Verify messages exist
    chat_history = ChatSessionManager.get_chat_history(
        chat_session=multi_message_chat_soft,
        user_performing_action=basic_user,
    )
    assert len(chat_history) >= 3, "Chat should have multiple messages"

    # Soft delete the chat with multiple messages
    deletion_success = ChatSessionManager.soft_delete(
        chat_session=multi_message_chat_soft,
        user_performing_action=basic_user,
    )
    assert deletion_success, "Multi-message chat session should be soft deletable"
    assert ChatSessionManager.verify_soft_deleted(
        chat_session=multi_message_chat_soft,
        user_performing_action=basic_user,
    ), "Multi-message chat session should be confirmed as soft deleted"

    # Test 4: Hard delete a chat session with multiple messages
    multi_message_chat_hard = ChatSessionManager.create(
        persona_id=0,
        description="Multi-message chat session for hard delete",
        user_performing_action=basic_user,
    )

    # Send multiple messages to create more complex data
    for i in range(3):
        ChatSessionManager.send_message(
            chat_session_id=multi_message_chat_hard.id,
            message=f"Message {i}: Tell me about different aspects of this topic",
            user_performing_action=basic_user,
        )

    # Verify messages exist
    chat_history = ChatSessionManager.get_chat_history(
        chat_session=multi_message_chat_hard,
        user_performing_action=basic_user,
    )
    assert len(chat_history) >= 3, "Chat should have multiple messages"

    # Hard delete the chat with multiple messages
    deletion_success = ChatSessionManager.hard_delete(
        chat_session=multi_message_chat_hard,
        user_performing_action=basic_user,
    )
    assert deletion_success, "Multi-message chat session should be hard deletable"
    assert ChatSessionManager.verify_hard_deleted(
        chat_session=multi_message_chat_hard,
        user_performing_action=basic_user,
    ), "Multi-message chat session should be confirmed as hard deleted"
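The distinction these tests exercise: soft delete flips a flag so the session disappears from normal queries, while hard delete removes the row and relies on ON DELETE CASCADE to clean up dependent records. Below is a minimal sketch of that distinction, assuming a hypothetical two-table SQLAlchemy schema (ChatSession, ChatMessage) with a `deleted` flag and a cascading foreign key; it is illustrative only, not the actual Onyx schema or the ChatSessionManager implementation.

# Minimal sketch of soft vs. hard delete -- hypothetical schema, NOT the Onyx models.
from sqlalchemy import Boolean, Column, ForeignKey, Integer, String, delete, update
from sqlalchemy.orm import DeclarativeBase, Session


class Base(DeclarativeBase):
    pass


class ChatSession(Base):
    __tablename__ = "chat_session"
    id = Column(Integer, primary_key=True)
    description = Column(String)
    # Soft delete only flips this flag; the row (and its messages) stay in the DB.
    deleted = Column(Boolean, default=False, nullable=False)


class ChatMessage(Base):
    __tablename__ = "chat_message"
    id = Column(Integer, primary_key=True)
    # ON DELETE CASCADE is what lets a hard delete remove dependent rows.
    chat_session_id = Column(
        Integer, ForeignKey("chat_session.id", ondelete="CASCADE"), nullable=False
    )
    message = Column(String)


def soft_delete_chat_session(db: Session, chat_session_id: int) -> None:
    # Hide the session from normal queries but keep all of its data.
    db.execute(
        update(ChatSession)
        .where(ChatSession.id == chat_session_id)
        .values(deleted=True)
    )
    db.commit()


def hard_delete_chat_session(db: Session, chat_session_id: int) -> None:
    # Remove the session row entirely; the database cascades to chat_message rows.
    db.execute(delete(ChatSession).where(ChatSession.id == chat_session_id))
    db.commit()

Under this sketch, verify_soft_deleted corresponds to checking that the row still exists with deleted=True, while verify_hard_deleted corresponds to checking that the row and its dependent messages are gone.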
@@ -58,7 +58,6 @@ def test_index_attempt_pagination(reset: None) -> None:
    # Create an admin user to perform actions
    user_performing_action: DATestUser = UserManager.create(
        name="admin_performing_action",
        is_first_user=True,
    )

    # Create a CC pair to attach index attempts to

@@ -46,8 +46,7 @@ def _verify_user_pagination(
def test_user_pagination(reset: None) -> None:
    # Create an admin user to perform actions
    user_performing_action: DATestUser = UserManager.create(
        name="admin_performing_action",
        is_first_user=True,
        name="admin_performing_action"
    )

    # Create 9 admin users

@@ -842,7 +842,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
    assert all(isinstance(item, Document) for item in outputs[1].items)
    assert {
        item.semantic_identifier for item in cast(list[Document], outputs[1].items)
    } == {"PR 3 Repo 2", "PR 4 Repo 2"}
    } == {"3: PR 3 Repo 2", "4: PR 4 Repo 2"}
    cp1 = outputs[1].next_checkpoint
    assert (
        cp1.has_more
@@ -869,7 +869,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
    assert all(isinstance(item, Document) for item in outputs[3].items)
    assert {
        item.semantic_identifier for item in cast(list[Document], outputs[3].items)
    } == {"PR 1 Repo 1", "PR 2 Repo 1"}
    } == {"1: PR 1 Repo 1", "2: PR 2 Repo 1"}
    cp3 = outputs[3].next_checkpoint
    # This checkpoint is returned early because offset had items. has_more reflects state then.
    assert cp3.has_more  # still need to do issues
33
backend/tests/unit/onyx/llm/test_model_is_reasoning.py
Normal file
@@ -0,0 +1,33 @@
from onyx.llm.utils import model_is_reasoning_model


def test_model_is_reasoning_model() -> None:
    """Test that reasoning models are correctly identified and non-reasoning models are not"""

    # Models that should be identified as reasoning models
    reasoning_models = [
        ("o3", "openai"),
        ("o3-mini", "openai"),
        ("o4-mini", "openai"),
        ("deepseek-reasoner", "deepseek"),
        ("deepseek-r1", "openrouter/deepseek"),
        ("claude-sonnet-4-20250514", "anthropic"),
    ]

    # Models that should NOT be identified as reasoning models
    non_reasoning_models = [
        ("gpt-4o", "openai"),
        ("claude-3-5-sonnet-20240620", "anthropic"),
    ]

    # Test reasoning models
    for model_name, provider in reasoning_models:
        assert (
            model_is_reasoning_model(model_name, provider) is True
        ), f"Expected {provider}/{model_name} to be identified as a reasoning model"

    # Test non-reasoning models
    for model_name, provider in non_reasoning_models:
        assert (
            model_is_reasoning_model(model_name, provider) is False
        ), f"Expected {provider}/{model_name} to NOT be identified as a reasoning model"
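As a point of reference only, a simple name-based heuristic along the following lines would satisfy the assertions above. The prefix list, function name, and logic are assumptions for illustration, not the actual onyx.llm.utils implementation.

# Hypothetical name-based heuristic -- illustrative only, not the real
# onyx.llm.utils.model_is_reasoning_model.
_REASONING_NAME_HINTS = (
    "o1",
    "o3",
    "o4",
    "deepseek-reasoner",
    "deepseek-r1",
    "claude-sonnet-4",
)


def model_is_reasoning_model_sketch(model_name: str, model_provider: str) -> bool:
    # The provider is accepted for parity with the test's call signature but is
    # not needed by this particular heuristic.
    name = model_name.lower()
    # Match either an exact hint ("o3", "deepseek-r1", ...) or a hinted family
    # prefix ("o3-mini", "claude-sonnet-4-20250514", ...).
    return any(name == hint or name.startswith(f"{hint}-") for hint in _REASONING_NAME_HINTS)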
@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest

from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from onyx.tools.utils import explicit_tool_calling_supported


@@ -32,6 +33,40 @@ from onyx.tools.utils import explicit_tool_calling_supported
        # === Anthropic Scenarios (expected False due to base support being False) ===
        # Provider is Anthropic, base model does NOT claim FC support
        (ANTHROPIC_PROVIDER_NAME, "claude-2.1", False, [], False),
        # === Bedrock Scenarios ===
        # Bedrock provider with model name containing anthropic model name as substring -> False
        (
            BEDROCK_PROVIDER_NAME,
            "anthropic.claude-3-opus-20240229-v1:0",
            True,
            ["claude-3-opus-20240229"],
            False,
        ),
        # Bedrock provider with model name containing different anthropic model name as substring -> False
        (
            BEDROCK_PROVIDER_NAME,
            "aws-anthropic-claude-3-haiku-20240307",
            True,
            ["claude-3-haiku-20240307"],
            False,
        ),
        # Bedrock provider with model name NOT containing any anthropic model name as substring -> True
        (
            BEDROCK_PROVIDER_NAME,
            "amazon.titan-text-express-v1",
            True,
            ["claude-3-opus-20240229", "claude-3-haiku-20240307"],
            True,
        ),
        # Bedrock provider with model name NOT containing any anthropic model
        # name as substring, but base model doesn't support FC -> False
        (
            BEDROCK_PROVIDER_NAME,
            "amazon.titan-text-express-v1",
            False,
            ["claude-3-opus-20240229", "claude-3-haiku-20240307"],
            False,
        ),
        # === Non-Anthropic Scenarios ===
        # Non-Anthropic provider, base model claims FC support -> True
        ("openai", "gpt-4o", True, [], True),
@@ -73,6 +108,9 @@ def test_explicit_tool_calling_supported(
    We don't want to provide that list of tools because our UI doesn't support sequential
    tool calling yet for (a) and just looks bad for (b), so for now we just treat anthropic
    models as non-tool-calling.

    Additionally, for Bedrock provider, any model containing an anthropic model name as a
    substring should also return False for the same reasons.
    """
    mock_find_model_obj.return_value = {
        "supports_function_calling": mock_model_supports_fc
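The Bedrock rule described in the docstring above boils down to a substring check against known Anthropic model names. The following is a hedged sketch of that decision logic, using plain "anthropic"/"bedrock" strings and a hardcoded name list in place of the provider-name constants and the mocked litellm model data; it is not the actual onyx.tools.utils.explicit_tool_calling_supported implementation.

# Hypothetical sketch of the decision described in the docstring -- not the real
# onyx.tools.utils implementation. The literal provider strings and the name
# list below stand in for the constants and mocked data used by the test.
ANTHROPIC_MODEL_NAMES = ["claude-3-opus-20240229", "claude-3-haiku-20240307"]


def explicit_tool_calling_supported_sketch(
    provider: str, model_name: str, model_supports_fc: bool
) -> bool:
    # Anthropic models are treated as non-tool-calling for now (see docstring).
    if provider == "anthropic":
        return False
    # Bedrock models whose names contain an Anthropic model name are Anthropic
    # models under the hood, so they get the same treatment.
    if provider == "bedrock" and any(
        anthropic_name in model_name for anthropic_name in ANTHROPIC_MODEL_NAMES
    ):
        return False
    # Otherwise defer to whether the underlying model claims function-calling support.
    return model_supports_fc

Against the parametrized cases above, this sketch returns False for the two Bedrock/Anthropic names and True for "amazon.titan-text-express-v1" only when the base model claims function-calling support.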
3
ct.yaml
@@ -9,7 +9,8 @@ chart-repos:
  - vespa=https://onyx-dot-app.github.io/vespa-helm-charts
  - postgresql=https://charts.bitnami.com/bitnami

helm-extra-args: --debug --timeout 600s
# have seen postgres take 10 min to pull ... so 15 min seems like a good timeout?
helm-extra-args: --debug --timeout 900s

# nginx appears to not work on kind, likely due to lack of loadbalancer support
# helm-extra-set-args also only works on the command line, not in this yaml
@@ -131,7 +131,7 @@ Resources:
      OperatingSystemFamily: LINUX
      ContainerDefinitions:
        - Name: vespaengine
          Image: vespaengine/vespa:8.277.17
          Image: vespaengine/vespa:8.526.15
          Cpu: 0
          Essential: true
          PortMappings:
@@ -162,7 +162,9 @@ Resources:
              awslogs-region: !Ref AWS::Region
              awslogs-stream-prefix: ecs
          User: "1000"
          Environment: []
          Environment:
            - Name: VESPA_SKIP_UPGRADE_CHECK
              Value: "true"
          VolumesFrom: []
          SystemControls: []
      Volumes:
@@ -378,6 +378,7 @@ services:

  relational_db:
    image: postgres:15.2-alpine
    shm_size: 1g
    command: -c 'max_connections=250'
    restart: always
    environment:
@@ -390,8 +391,10 @@ services:

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
    image: vespaengine/vespa:8.526.15
    restart: always
    environment:
      - VESPA_SKIP_UPGRADE_CHECK=true
    ports:
      - "19071:19071"
      - "8081:8081"

@@ -324,6 +324,7 @@ services:

  relational_db:
    image: postgres:15.2-alpine
    shm_size: 1g
    command: -c 'max_connections=250'
    restart: always
    environment:
@@ -336,8 +337,10 @@ services:

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
    image: vespaengine/vespa:8.526.15
    restart: always
    environment:
      - VESPA_SKIP_UPGRADE_CHECK=true
    ports:
      - "19071:19071"
      - "8081:8081"

@@ -351,6 +351,7 @@ services:

  relational_db:
    image: postgres:15.2-alpine
    shm_size: 1g
    command: -c 'max_connections=250'
    restart: always
    environment:
@@ -363,8 +364,10 @@ services:

  # This container name cannot have an underscore in it due to Vespa expectations of the URL
  index:
    image: vespaengine/vespa:8.277.17
    image: vespaengine/vespa:8.526.15
    restart: always
    environment:
      - VESPA_SKIP_UPGRADE_CHECK=true
    ports:
      - "19071:19071"
      - "8081:8081"
Some files were not shown because too many files have changed in this diff