Compare commits

...

41 Commits

Author SHA1 Message Date
Rei Meguro
4bb3ee03a0 Update GitHub Connector metadata (#4769)
* feat: updated github metadata

* feat: nullity check

* feat: more metadata

* feat: userinfo

* test: connector test + more metadata

* feat: num files changed

* feat str

* feat: list of str
2025-06-04 18:33:14 +00:00
Maciej Bryński
1bb23d6837 Upgrade asyncpg for Python 3.12 (#4699) 2025-06-04 11:44:52 -07:00
joachim-danswer
f447359815 bump up agent timeouts across the board (#4821) 2025-06-04 14:36:46 +00:00
Weves
851e0b05f2 Small tweak to user invite flow 2025-06-04 08:09:33 -07:00
Chris Weaver
094cc940a4 Small embedding model cleanups (#4820)
* Small embedding model cleanups

* fix

* address greptile

* fix build
2025-06-04 00:10:44 +00:00
rkuo-danswer
51be9000bb Feature/vespa bump (#4819)
* bump cloudformation

* update kubernetes

* bump helm chart

* bump docker compose

* update chart.lock

* ai accident!

* bump vespa helm chart for fix

* increase timeout

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-06-04 00:03:01 +00:00
joachim-danswer
80ecdb711d New metadata for Jira for KG (#4785)
* new metadata components

* nits & tests
2025-06-03 20:12:56 +00:00
Chris Weaver
a599176bbf Improve reasoning detection (#4817)
* Improve reasoning detection

* Address greptile comments

* Fix mypy
2025-06-03 20:01:12 +00:00
rkuo-danswer
e0341b4c8a bumping docker push action version (#4816)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-06-03 12:47:01 -07:00
CaptainJeff
4c93fd448f fix: updating gemini models (#4806)
Co-authored-by: Jeffrey Drakos <jeffreydrakos@Jeffreys-MacBook-Pro-2.local>
2025-06-03 11:16:42 -07:00
Chris Weaver
84d916e210 Fix hard delete of agentic chats (#4803)
* Fix hard delete of agentic chats

* Update backend/tests/integration/tests/chat/test_chat_deletion.py

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

* Address Greptile comments

* fix tests

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-06-03 11:14:11 -07:00
Weves
f57ed2a8dd Adjust script 2025-06-02 18:39:00 -07:00
trial-danswer
713889babf [Connectors][Script] Resume Paused Connectors (#4798)
* [Connectors][Script] Resume Paused Connectors

* Addressing comment
2025-06-02 18:34:00 -07:00
Weves
58c641d8ec Remove ordering-only flow 2025-06-02 18:29:42 -07:00
Weves
94985e24c6 Adjust user file access 2025-06-02 17:28:49 -07:00
Evan Lohn
4c71a5f5ff drive perm sync logs + misc deployment improvements (#4788)
* some logs

* give postgress more memory

* give postgress more memory

* give postgress more memory

* revert

* give postgress more memory

* bump external access limit

* vespa timeout

* deployment consistency

* bump vespa version

* skip upgrade check

* retry permission by ids

* logs

* fix temp docx file issue

* fix drive file deduping

* RK comments

* mypy

* aggregate logs
2025-06-01 23:36:57 +00:00
rkuo-danswer
b19e3a500b try fixing slack bot (#4792)
* try fixing slack bot

* add logging

* just use if

* safe msg get

* .close isn't async

* enforce block list size limit

* various fixes and notes

* don't use self

* switch to punkt_tab

* fix return condition

* synchronize waiting, use non thread local redis locks

* fix log format, make collection copy more explicit for readability

* fix some logging

* unnecessary function

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-31 00:39:14 +00:00
Chris Weaver
267fe027f5 Fix failed docs table (#4800)
* Fix initial LLM provider set up

* Fix IndexAttemptErrorsModal pagination
2025-05-30 22:19:52 +00:00
Evan Lohn
0d4d8c0d64 jira daylight savings handling (#4797) 2025-05-30 19:13:38 +00:00
Chris Weaver
6f9d8c0cff Simplify passing in of file IDs for filtering (#4791)
* Simplify passing in of file IDs for filtering

* Address RK comments
2025-05-30 05:08:21 +00:00
Weves
5031096a2b Fix frozen add token rate limit migration 2025-05-29 22:22:36 -07:00
rkuo-danswer
797e113000 add a comment (#4789)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-29 14:11:19 -07:00
Raunak Bhagat
edc2892785 fix: Remove "Refining Answer" popup (#4783)
* Clean up logic

* Remove dead code

* Remove "Refining Answer" prompt
2025-05-29 19:55:38 +00:00
rkuo-danswer
ef4d5dcec3 new slack rate limiting approach (#4779)
* fix slack rate limit retry handler for groups

* trying to mitigate memory usage during csv download

* Revert "trying to mitigate memory usage during csv download"

This reverts commit 48262eacf6.

* integrated approach to rate limiting

* code review

* try no redis setting

* add pytest-dotenv

* add more debugging

* added comments

* add more stats

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-29 19:49:32 +00:00
Evan Lohn
0b5e3e5ee4 skip excel files that openpyxl fails on (#4787) 2025-05-29 18:09:46 +00:00
SubashMohan
f5afb3621e connector filter bug fix (#4771)
* connector filter bug fix

* refactor: use ValidStatuses type for last status filter

---------

Co-authored-by: Subash <subash@onyx.app>
2025-05-29 15:17:04 +00:00
rkuo-danswer
9f72826143 Bugfix/slack bot debugging (#4782)
* adding some logging

* better var name

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-28 18:43:11 +00:00
rkuo-danswer
ab7a4184df Feature/helm k8s probes 2 (#4766)
* add probes

* lint fixes

* add beat probes

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-28 05:20:24 +00:00
rkuo-danswer
16a14bac89 Feature/tenant reporting 2 (#4750)
* add more info

* fix headers

* add filename as param (merge)

* db manager entry in launch template

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-27 23:24:47 +00:00
Raunak Bhagat
baaf31513c fix: Create new grouping for CRM connectors (#4776)
* Create new grouping for CRM connectors
* Edit spacing
2025-05-27 06:51:34 -07:00
Rei Meguro
0b01d7f848 refactor: stream_llm_answer (#4772)
* refactor: stream_llm_answer

* fix: lambda

* fix: mypy, docstring
2025-05-26 22:29:33 +00:00
rkuo-danswer
23ff3476bc print sanitized api key to help troubleshoot (#4764)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-24 22:27:37 +00:00
Chris Weaver
0c7ba8e2ac Fix/add back search with files (#4767)
* Allow search w/ user files

* more

* More

* Fix

* Improve prompt

* Combine user files + regular uploaded files
2025-05-24 15:44:39 -07:00
Evan Lohn
dad99cbec7 v1 refresh drive creds during perm sync (#4768) 2025-05-23 23:01:26 +00:00
Chris Weaver
3e78c2f087 Fix POSTGRES_IDLE_SESSIONS_TIMEOUT (#4765) 2025-05-23 14:55:23 -07:00
rkuo-danswer
e822afdcfa add probes (#4762)
* add probes

* lint fixes

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-23 02:24:54 +00:00
rkuo-danswer
b824951c89 add probe signals for beat (#4760)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-23 01:41:11 +00:00
Evan Lohn
ca20e527fc fix tool calling for bedrock claude models (#4761)
* fix tool calling for bedrock claude models

* unit test

* fix unit test
2025-05-23 01:13:18 +00:00
rkuo-danswer
c8e65cce1e add k8s probes (#4752)
* add file signals to celery workers

* improve probe script

* cancel tref

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-22 20:21:59 +00:00
rkuo-danswer
6c349687da improve impersonation logging slightly (#4758)
Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-05-22 11:17:27 -07:00
Raunak Bhagat
3b64793d4b Update listener passing (#4751) 2025-05-22 01:31:20 +00:00
137 changed files with 3079 additions and 1796 deletions

View File

@@ -64,7 +64,7 @@ jobs:
- name: Backend Image Docker Build and Push
id: build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile

View File

@@ -54,7 +54,7 @@ jobs:
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile

View File

@@ -80,7 +80,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and Push AMD64
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server
@@ -126,7 +126,7 @@ jobs:
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build and Push ARM64
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ./backend
file: ./backend/Dockerfile.model_server

View File

@@ -70,7 +70,7 @@ jobs:
- name: Build and push by digest
id: build
uses: docker/build-push-action@v5
uses: docker/build-push-action@v6
with:
context: ./web
file: ./web/Dockerfile

View File

@@ -428,6 +428,29 @@
"--filename",
"generated/openapi.json",
]
},
{
// script to debug multi tenant db issues
"name": "Onyx DB Manager (Top Chunks)",
"type": "debugpy",
"request": "launch",
"program": "scripts/debugging/onyx_db.py",
"cwd": "${workspaceFolder}/backend",
"envFile": "${workspaceFolder}/.env",
"env": {
"PYTHONUNBUFFERED": "1",
"PYTHONPATH": "."
},
"args": [
"--password",
"your_password_here",
"--port",
"5433",
"--report",
"top-chunks",
"--filename",
"generated/tenants_by_num_docs.csv"
]
},
{
"name": "Debug React Web App in Chrome",

View File

@@ -6,11 +6,8 @@ Create Date: 2024-04-15 01:36:02.952809
"""
import json
from typing import cast
from alembic import op
import sqlalchemy as sa
from onyx.key_value_store.factory import get_kv_store
# revision identifiers, used by Alembic.
revision = "703313b75876"
@@ -54,27 +51,10 @@ def upgrade() -> None:
sa.PrimaryKeyConstraint("rate_limit_id", "user_group_id"),
)
try:
settings_json = cast(str, get_kv_store().load("token_budget_settings"))
settings = json.loads(settings_json)
is_enabled = settings.get("enable_token_budget", False)
token_budget = settings.get("token_budget", -1)
period_hours = settings.get("period_hours", -1)
if is_enabled and token_budget > 0 and period_hours > 0:
op.execute(
f"INSERT INTO token_rate_limit \
(enabled, token_budget, period_hours, scope) VALUES \
({is_enabled}, {token_budget}, {period_hours}, 'GLOBAL')"
)
# Delete the dynamic config
get_kv_store().delete("token_budget_settings")
except Exception:
# Ignore if the dynamic config is not found
pass
# NOTE: rate limit settings used to be stored in the "token_budget_settings" key in the
# KeyValueStore. This will now be lost. The KV store works differently than it used to
# so the migration is fairly complicated and likely not worth it to support (pretty much
# nobody will have it set)
def downgrade() -> None:

View File

@@ -0,0 +1,128 @@
"""add_cascade_deletes_to_agent_tables
Revision ID: ca04500b9ee8
Revises: 238b84885828
Create Date: 2025-05-30 16:03:51.112263
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "ca04500b9ee8"
down_revision = "238b84885828"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop existing foreign key constraints
op.drop_constraint(
"agent__sub_question_primary_question_id_fkey",
"agent__sub_question",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query_parent_question_id_fkey",
"agent__sub_query",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__standard_answer_chat_message_id_fkey",
"chat_message__standard_answer",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Recreate foreign key constraints with CASCADE delete
op.create_foreign_key(
"agent__sub_question_primary_question_id_fkey",
"agent__sub_question",
"chat_message",
["primary_question_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"agent__sub_query_parent_question_id_fkey",
"agent__sub_query",
"agent__sub_question",
["parent_question_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"chat_message__standard_answer_chat_message_id_fkey",
"chat_message__standard_answer",
"chat_message",
["chat_message_id"],
["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Drop CASCADE foreign key constraints
op.drop_constraint(
"agent__sub_question_primary_question_id_fkey",
"agent__sub_question",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query_parent_question_id_fkey",
"agent__sub_query",
type_="foreignkey",
)
op.drop_constraint(
"chat_message__standard_answer_chat_message_id_fkey",
"chat_message__standard_answer",
type_="foreignkey",
)
op.drop_constraint(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
type_="foreignkey",
)
# Recreate foreign key constraints without CASCADE delete
op.create_foreign_key(
"agent__sub_question_primary_question_id_fkey",
"agent__sub_question",
"chat_message",
["primary_question_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query_parent_question_id_fkey",
"agent__sub_query",
"agent__sub_question",
["parent_question_id"],
["id"],
)
op.create_foreign_key(
"chat_message__standard_answer_chat_message_id_fkey",
"chat_message__standard_answer",
"chat_message",
["chat_message_id"],
["id"],
)
op.create_foreign_key(
"agent__sub_query__search_doc_sub_query_id_fkey",
"agent__sub_query__search_doc",
"agent__sub_query",
["sub_query_id"],
["id"],
)
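
A minimal, self-contained illustration (an assumed setup, not code from this repo) of what these ondelete="CASCADE" constraints change: once the cascade is declared at the database level, hard-deleting a chat message removes its dependent agent rows without any ORM-side cascade handling. Table and column names are simplified stand-ins for the real schema.

import sqlalchemy as sa
from sqlalchemy import event

engine = sa.create_engine("sqlite://")

# SQLite only honors ON DELETE CASCADE when foreign key enforcement is enabled;
# Postgres (the real target) enforces foreign keys by default.
@event.listens_for(engine, "connect")
def _enable_sqlite_fks(dbapi_conn, _record):
    dbapi_conn.execute("PRAGMA foreign_keys=ON")

metadata = sa.MetaData()
chat_message = sa.Table(
    "chat_message", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
)
sub_question = sa.Table(
    "agent__sub_question", metadata,
    sa.Column("id", sa.Integer, primary_key=True),
    sa.Column(
        "primary_question_id", sa.Integer,
        sa.ForeignKey("chat_message.id", ondelete="CASCADE"),
    ),
)
metadata.create_all(engine)

with engine.begin() as conn:
    conn.execute(chat_message.insert().values(id=1))
    conn.execute(sub_question.insert().values(id=10, primary_question_id=1))
    conn.execute(chat_message.delete().where(chat_message.c.id == 1))
    remaining = conn.execute(
        sa.select(sa.func.count()).select_from(sub_question)
    ).scalar()
    print(remaining)  # 0 -- the dependent row was removed by the database itself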

View File

@@ -1,8 +1,12 @@
from collections.abc import Callable
from collections.abc import Generator
from datetime import datetime
from datetime import timezone
from typing import Any
from google.oauth2.credentials import Credentials as OAuthCredentials
from google.oauth2.service_account import Credentials as ServiceAccountCredentials
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from ee.onyx.external_permissions.google_drive.models import PermissionType
from ee.onyx.external_permissions.google_drive.permission_retrieval import (
@@ -13,6 +17,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
@@ -41,6 +46,20 @@ def _get_slim_doc_generator(
)
def _drive_connector_creds_getter(
google_drive_connector: GoogleDriveConnector,
) -> Callable[[], ServiceAccountCredentials | OAuthCredentials]:
def inner() -> ServiceAccountCredentials | OAuthCredentials:
if not google_drive_connector._creds_dict:
raise ValueError(
"Creds dict not found, load_credentials must be called first"
)
google_drive_connector.load_credentials(google_drive_connector._creds_dict)
return google_drive_connector.creds
return inner
def _fetch_permissions_for_permission_ids(
google_drive_connector: GoogleDriveConnector,
permission_info: dict[str, Any],
@@ -54,13 +73,22 @@ def _fetch_permissions_for_permission_ids(
if not permission_ids:
return []
drive_service = get_drive_service(
if not owner_email:
logger.warning(
f"No owner email found for document {doc_id}. Permission info: {permission_info}"
)
refreshable_drive_service = RefreshableDriveObject(
call_stack=lambda creds: get_drive_service(
creds=creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
),
creds=google_drive_connector.creds,
user_email=(owner_email or google_drive_connector.primary_admin_email),
creds_getter=_drive_connector_creds_getter(google_drive_connector),
)
return get_permissions_by_ids(
drive_service=drive_service,
drive_service=refreshable_drive_service,
doc_id=doc_id,
permission_ids=permission_ids,
)
@@ -172,7 +200,9 @@ def gdrive_doc_sync(
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
total_processed = 0
for slim_doc_batch in slim_doc_generator:
logger.info(f"Drive perm sync: Processing {len(slim_doc_batch)} documents")
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
@@ -188,3 +218,5 @@ def gdrive_doc_sync(
external_access=ext_access,
doc_id=slim_doc.id,
)
total_processed += len(slim_doc_batch)
logger.info(f"Drive perm sync: Processed {total_processed} total documents")

View File

@@ -1,14 +1,16 @@
from googleapiclient.discovery import Resource # type: ignore
from retry import retry
from ee.onyx.external_permissions.google_drive.models import GoogleDrivePermission
from onyx.connectors.google_utils.google_utils import execute_paginated_retrieval
from onyx.connectors.google_utils.resources import RefreshableDriveObject
from onyx.utils.logger import setup_logger
logger = setup_logger()
@retry(tries=3, delay=2, backoff=2)
def get_permissions_by_ids(
drive_service: Resource,
drive_service: RefreshableDriveObject,
doc_id: str,
permission_ids: list[str],
) -> list[GoogleDrivePermission]:

View File

@@ -8,7 +8,7 @@ from onyx.access.models import DocExternalAccess
from onyx.access.models import ExternalAccess
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import make_paginated_slack_api_call
from onyx.connectors.slack.connector import SlackConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
@@ -64,7 +64,7 @@ def _fetch_channel_permissions(
for channel_id in private_channel_ids:
# Collect all member ids for the channel pagination calls
member_ids = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
slack_client.conversations_members,
channel=channel_id,
):
@@ -92,7 +92,7 @@ def _fetch_channel_permissions(
external_user_emails=member_emails,
# No group<->document mapping for slack
external_user_group_ids=set(),
# No way to determine if slack is invite only without enterprise liscense
# No way to determine if slack is invite only without enterprise license
is_public=False,
)

View File

@@ -10,8 +10,8 @@ from slack_sdk import WebClient
from ee.onyx.db.external_perm import ExternalUserGroup
from ee.onyx.external_permissions.slack.utils import fetch_user_id_to_email_map
from onyx.connectors.credentials_provider import OnyxDBCredentialsProvider
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackConnector
from onyx.connectors.slack.utils import make_paginated_slack_api_call
from onyx.db.models import ConnectorCredentialPair
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -23,7 +23,7 @@ def _get_slack_group_ids(
slack_client: WebClient,
) -> list[str]:
group_ids = []
for result in make_paginated_slack_api_call_w_retries(slack_client.usergroups_list):
for result in make_paginated_slack_api_call(slack_client.usergroups_list):
for group in result.get("usergroups", []):
group_ids.append(group.get("id"))
return group_ids
@@ -35,7 +35,7 @@ def _get_slack_group_members_email(
user_id_to_email_map: dict[str, str],
) -> list[str]:
group_member_emails = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
slack_client.usergroups_users_list, usergroup=group_name
):
for member_id in result.get("users", []):

View File

@@ -1,13 +1,13 @@
from slack_sdk import WebClient
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_paginated_slack_api_call
def fetch_user_id_to_email_map(
slack_client: WebClient,
) -> dict[str, str]:
user_id_to_email_map = {}
for user_info in make_paginated_slack_api_call_w_retries(
for user_info in make_paginated_slack_api_call(
slack_client.users_list,
):
for user in user_info.get("members", []):

View File

@@ -2,6 +2,7 @@ from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from fastapi_users import exceptions
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
@@ -24,14 +25,24 @@ async def impersonate_user(
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
try:
tenant_id = get_tenant_id_for_email(impersonate_request.email)
except exceptions.UserNotExists:
detail = f"User has no tenant mapping: {impersonate_request.email=}"
logger.warning(detail)
raise HTTPException(status_code=422, detail=detail)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
detail = (
f"User not found in tenant: {impersonate_request.email=} {tenant_id=}"
)
logger.warning(detail)
raise HTTPException(status_code=422, detail=detail)
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)

View File

@@ -47,10 +47,10 @@ def get_tenant_id_for_email(email: str) -> str:
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id

View File

@@ -92,6 +92,7 @@ def format_embedding_error(
service_name: str,
model: str | None,
provider: EmbeddingProvider,
sanitized_api_key: str | None = None,
status_code: int | None = None,
) -> str:
"""
@@ -103,6 +104,7 @@ def format_embedding_error(
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"API Key: {sanitized_api_key} "
f"Exception: {error}"
)
@@ -133,6 +135,7 @@ class CloudEmbedding:
self.timeout = timeout
self.http_client = httpx.AsyncClient(timeout=timeout)
self._closed = False
self.sanitized_api_key = api_key[:4] + "********" + api_key[-4:]
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
@@ -306,6 +309,7 @@ class CloudEmbedding:
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
status_code=e.response.status_code,
)
logger.error(error_string)
@@ -317,7 +321,11 @@ class CloudEmbedding:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
e,
str(self.provider),
model_name or deployment_name,
self.provider,
sanitized_api_key=self.sanitized_api_key,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")

View File

@@ -11,7 +11,7 @@ class ExternalAccess:
# arbitrary limit to prevent excessively large permissions sets
# not internally enforced ... the caller can check this before using the instance
MAX_NUM_ENTRIES = 1000
MAX_NUM_ENTRIES = 5000
# Emails of external users with access to the doc externally
external_user_emails: set[str]
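
As the comment notes, MAX_NUM_ENTRIES is advisory: the dataclass does not enforce it, and callers are expected to check the size themselves before using the instance. A rough sketch of such a caller-side guard (the function and its message are illustrative, not from the repo):

def check_external_access_size(emails: set[str], max_entries: int = 5000) -> None:
    # mirrors the MAX_NUM_ENTRIES intent without importing the real class
    if len(emails) > max_entries:
        raise ValueError(f"{len(emails)} external entries exceeds the {max_entries} cap")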

View File

@@ -1,4 +1,3 @@
from datetime import datetime
from typing import cast
from langchain_core.messages import HumanMessage
@@ -12,6 +11,7 @@ from onyx.agents.agent_search.models import GraphConfig
from onyx.agents.agent_search.shared_graph_utils.agent_prompt_ops import (
trim_prompt_piece,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.prompts.agents.dc_prompts import DC_FORMATTING_NO_BASE_DATA_PROMPT
@@ -113,42 +113,20 @@ def consolidate_research(
)
]
dispatch_timings: list[float] = []
primary_model = graph_config.tooling.primary_llm
def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in primary_model.stream(msg, timeout_override=30, max_tokens=None):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
_ = run_with_timeout(
60,
stream_initial_answer,
lambda: stream_llm_answer(
llm=graph_config.tooling.primary_llm,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=30,
max_tokens=None,
),
)
except Exception as e:

View File

@@ -30,6 +30,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
LLM_ANSWER_ERROR_MESSAGE,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.utils import get_answer_citation_ids
@@ -112,44 +113,23 @@ def generate_sub_answer(
config=fast_llm.config,
)
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
response: list[str] = []
def stream_sub_answer() -> list[str]:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"sub_answers",
AgentAnswerPiece(
answer_piece=content,
level=level,
level_question_num=question_num,
answer_type="agent_sub_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
response = run_with_timeout(
response, _ = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
stream_sub_answer,
lambda: stream_llm_answer(
llm=fast_llm,
prompt=msg,
event_name="sub_answers",
writer=writer,
agent_answer_level=level,
agent_answer_question_num=question_num,
agent_answer_type="agent_sub_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
),
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -37,6 +37,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import InitialAgentResultStats
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
@@ -275,46 +276,24 @@ def generate_initial_answer(
agent_error: AgentErrorLog | None = None
def stream_initial_answer() -> list[str]:
response: list[str] = []
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"initial_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=0,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
streamed_tokens = run_with_timeout(
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
stream_initial_answer,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="initial_agent_answer",
writer=writer,
agent_answer_level=0,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -40,6 +40,7 @@ from onyx.agents.agent_search.shared_graph_utils.constants import (
from onyx.agents.agent_search.shared_graph_utils.constants import (
AgentLLMErrorType,
)
from onyx.agents.agent_search.shared_graph_utils.llm import stream_llm_answer
from onyx.agents.agent_search.shared_graph_utils.models import AgentErrorLog
from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
@@ -63,7 +64,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
remove_document_citations,
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
@@ -301,45 +301,24 @@ def generate_validate_refined_answer(
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
"refined_agent_answer",
AgentAnswerPiece(
answer_piece=content,
level=1,
level_question_num=0,
answer_type="agent_level_answer",
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
return streamed_tokens
try:
streamed_tokens = run_with_timeout(
streamed_tokens, dispatch_timings = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
stream_refined_answer,
lambda: stream_llm_answer(
llm=model,
prompt=msg,
event_name="refined_agent_answer",
writer=writer,
agent_answer_level=1,
agent_answer_question_num=0,
agent_answer_type="agent_level_answer",
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=(
AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None
),
),
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -0,0 +1,68 @@
from datetime import datetime
from typing import Literal
from langchain.schema.language_model import LanguageModelInput
from langgraph.types import StreamWriter
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.llm.interfaces import LLM
def stream_llm_answer(
llm: LLM,
prompt: LanguageModelInput,
event_name: str,
writer: StreamWriter,
agent_answer_level: int,
agent_answer_question_num: int,
agent_answer_type: Literal["agent_level_answer", "agent_sub_answer"],
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> tuple[list[str], list[float]]:
"""Stream the initial answer from the LLM.
Args:
llm: The LLM to use.
prompt: The prompt to use.
event_name: The name of the event to write.
writer: The writer to write to.
agent_answer_level: The level of the agent answer.
agent_answer_question_num: The question number within the level.
agent_answer_type: The type of answer ("agent_level_answer" or "agent_sub_answer").
timeout_override: The LLM timeout to use.
max_tokens: The LLM max tokens to use.
Returns:
A tuple of the response and the dispatch timings.
"""
response: list[str] = []
dispatch_timings: list[float] = []
for message in llm.stream(
prompt, timeout_override=timeout_override, max_tokens=max_tokens
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
if not isinstance(content, str):
raise ValueError(
f"Expected content to be a string, but got {type(content)}"
)
start_stream_token = datetime.now()
write_custom_event(
event_name,
AgentAnswerPiece(
answer_piece=content,
level=agent_answer_level,
level_question_num=agent_answer_question_num,
answer_type=agent_answer_type,
),
writer,
)
end_stream_token = datetime.now()
dispatch_timings.append((end_stream_token - start_stream_token).microseconds)
response.append(content)
return response, dispatch_timings

View File

@@ -76,10 +76,11 @@ def hash_api_key(api_key: str) -> str:
# and overlaps are impossible
if api_key.startswith(_API_KEY_PREFIX):
return hashlib.sha256(api_key.encode("utf-8")).hexdigest()
elif api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
if api_key.startswith(_DEPRECATED_API_KEY_PREFIX):
return _deprecated_hash_api_key(api_key)
else:
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
raise ValueError(f"Invalid API key prefix: {api_key[:3]}")
def build_displayable_api_key(api_key: str) -> str:

View File

@@ -6,6 +6,7 @@ from typing import Any
from typing import cast
import sentry_sdk
from celery import bootsteps # type: ignore
from celery import Task
from celery.app import trace
from celery.exceptions import WorkerShutdown
@@ -22,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.task_formatters import CeleryTaskColoredFormatter
from onyx.background.celery.apps.task_formatters import CeleryTaskPlainFormatter
from onyx.background.celery.celery_utils import celery_is_worker_primary
from onyx.background.celery.celery_utils import make_probe_path
from onyx.configs.constants import ONYX_CLOUD_CELERY_TASK_PREFIX
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_sqlalchemy_engine
@@ -340,10 +342,23 @@ def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
task_logger.info("worker_ready signal received.")
# file based way to do readiness/liveness probes
# https://medium.com/ambient-innovation/health-checks-for-celery-in-kubernetes-cf3274a3e106
# https://github.com/celery/celery/issues/4079#issuecomment-1270085680
hostname: str = cast(str, sender.hostname)
path = make_probe_path("readiness", hostname)
path.touch()
logger.info(f"Readiness signal touched at {path}.")
def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
HttpxPool.close_all()
hostname: str = cast(str, sender.hostname)
path = make_probe_path("readiness", hostname)
path.unlink(missing_ok=True)
if not celery_is_worker_primary(sender):
return
@@ -483,3 +498,34 @@ def wait_for_vespa_or_shutdown(sender: Any, **kwargs: Any) -> None:
msg = "Vespa: Readiness probe did not succeed within the timeout. Exiting..."
logger.error(msg)
raise WorkerShutdown(msg)
# File for validating worker liveness
class LivenessProbe(bootsteps.StartStopStep):
requires = {"celery.worker.components:Timer"}
def __init__(self, worker: Any, **kwargs: Any) -> None:
super().__init__(worker, **kwargs)
self.requests: list[Any] = []
self.task_tref = None
self.path = make_probe_path("liveness", worker.hostname)
def start(self, worker: Any) -> None:
self.task_tref = worker.timer.call_repeatedly(
15.0,
self.update_liveness_file,
(worker,),
priority=10,
)
def stop(self, worker: Any) -> None:
self.path.unlink(missing_ok=True)
if self.task_tref:
self.task_tref.cancel()
def update_liveness_file(self, worker: Any) -> None:
self.path.touch()
def get_bootsteps() -> list[type]:
return [LivenessProbe]

View File

@@ -8,6 +8,7 @@ from celery.signals import beat_init
from celery.utils.log import get_task_logger
import onyx.background.celery.apps.app_base as app_base
from onyx.background.celery.celery_utils import make_probe_path
from onyx.background.celery.tasks.beat_schedule import CLOUD_BEAT_MULTIPLIER_DEFAULT
from onyx.configs.constants import POSTGRES_CELERY_BEAT_APP_NAME
from onyx.db.engine import get_all_tenant_ids
@@ -45,6 +46,8 @@ class DynamicTenantScheduler(PersistentScheduler):
f"DynamicTenantScheduler initialized: reload_interval={self._reload_interval}"
)
self._liveness_probe_path = make_probe_path("liveness", "beat@hostname")
# do not set the initial schedule here because we don't have db access yet.
# do it in beat_init after the db engine is initialized
@@ -62,6 +65,8 @@ class DynamicTenantScheduler(PersistentScheduler):
or (now - self._last_reload) > self._reload_interval
):
task_logger.debug("Reload interval reached, initiating task update")
self._liveness_probe_path.touch()
try:
self._try_updating_schedule()
except (AttributeError, KeyError):
@@ -241,6 +246,9 @@ def on_beat_init(sender: Any, **kwargs: Any) -> None:
SqlEngine.init_engine(pool_size=2, max_overflow=0)
app_base.wait_for_redis(sender, **kwargs)
path = make_probe_path("readiness", "beat@hostname")
path.touch()
task_logger.info(f"Readiness signal touched at {path}.")
# first time init of the scheduler after db has been init'ed
scheduler: DynamicTenantScheduler = sender.scheduler

View File

@@ -91,6 +91,10 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.pruning",

View File

@@ -102,6 +102,10 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.indexing",

View File

@@ -105,6 +105,10 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.shared",

View File

@@ -89,6 +89,10 @@ def on_setup_logging(
app_base.on_setup_logging(loglevel, logfile, format, colorize, **kwargs)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",

View File

@@ -284,6 +284,10 @@ class HubPeriodicTask(bootsteps.StartStopStep):
celery_app.steps["worker"].add(HubPeriodicTask)
base_bootsteps = app_base.get_bootsteps()
for bootstep in base_bootsteps:
celery_app.steps["worker"].add(bootstep)
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.connector_deletion",

View File

@@ -0,0 +1,55 @@
# script to use as a kubernetes readiness / liveness probe
import argparse
import sys
import time
from pathlib import Path
def main_readiness(filename: str) -> int:
"""Checks if the file exists."""
path = Path(filename)
if not path.is_file():
return 1
return 0
def main_liveness(filename: str) -> int:
"""Checks if the file exists AND was recently modified."""
path = Path(filename)
if not path.is_file():
return 1
stats = path.stat()
liveness_timestamp = stats.st_mtime
current_timestamp = time.time()
time_diff = current_timestamp - liveness_timestamp
if time_diff > 60:
return 1
return 0
if __name__ == "__main__":
exit_code: int
parser = argparse.ArgumentParser(description="k8s readiness/liveness probe")
parser.add_argument(
"--probe",
type=str,
choices=["readiness", "liveness"],
help="The type of probe",
required=True,
)
parser.add_argument("--filename", help="The filename to watch", required=True)
args = parser.parse_args()
if args.probe == "readiness":
exit_code = main_readiness(args.filename)
elif args.probe == "liveness":
exit_code = main_liveness(args.filename)
else:
raise ValueError(f"Unknown probe type: {args.probe}")
sys.exit(exit_code)

View File

@@ -1,5 +1,6 @@
from datetime import datetime
from datetime import timezone
from pathlib import Path
from typing import Any
from typing import cast
@@ -121,3 +122,20 @@ def httpx_init_vespa_pool(
http2=False,
limits=httpx.Limits(max_keepalive_connections=max_keepalive_connections),
)
def make_probe_path(probe: str, hostname: str) -> Path:
"""templates the path for a k8s probe file.
e.g. /tmp/onyx_k8s_indexing_readiness.txt
"""
hostname_parts = hostname.split("@")
if len(hostname_parts) != 2:
raise ValueError(f"hostname could not be split! {hostname=}")
name = hostname_parts[0]
if not name:
raise ValueError(f"name cannot be empty! {name=}")
safe_name = "".join(c for c in name if c.isalnum()).rstrip()
return Path(f"/tmp/onyx_k8s_{safe_name}_{probe}.txt")

View File

@@ -43,6 +43,7 @@ from onyx.chat.models import UserKnowledgeFilePacket
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.chat.user_files.parse_user_files import parse_user_files
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -52,11 +53,9 @@ from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NO_AUTH_USER_ID
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.retrieval.search_runner import (
@@ -95,9 +94,7 @@ from onyx.document_index.factory import get_default_document_index
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files
from onyx.file_store.utils import load_all_chat_files
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.file_store.utils import save_files
from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.factory import get_llms_for_persona
@@ -312,8 +309,7 @@ def _handle_internet_search_tool_response_summary(
def _get_force_search_settings(
new_msg_req: CreateChatMessageRequest,
tools: list[Tool],
user_file_ids: list[int],
user_folder_ids: list[int],
search_tool_override_kwargs: SearchToolOverrideKwargs | None,
) -> ForceUseTool:
internet_search_available = any(
isinstance(tool, InternetSearchTool) for tool in tools
@@ -321,45 +317,24 @@ def _get_force_search_settings(
search_tool_available = any(isinstance(tool, SearchTool) for tool in tools)
if not internet_search_available and not search_tool_available:
if new_msg_req.force_user_file_search:
return ForceUseTool(force_use=True, tool_name=SearchTool._NAME)
else:
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
tool_name = SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
# Does not matter much which tool is set here as force is false and neither tool is available
return ForceUseTool(force_use=False, tool_name=SearchTool._NAME)
# Currently, the internet search tool does not support query override
args = (
{"query": new_msg_req.query_override}
if new_msg_req.query_override and tool_name == SearchTool._NAME
if new_msg_req.query_override and search_tool_available
else None
)
# Create override_kwargs for the search tool if user_file_ids are provided
override_kwargs = None
if (user_file_ids or user_folder_ids) and tool_name == SearchTool._NAME:
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=False,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=False,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)
if new_msg_req.file_descriptors:
# If user has uploaded files they're using, don't run any of the search tools
return ForceUseTool(force_use=False, tool_name=tool_name)
should_force_search = any(
[
new_msg_req.force_user_file_search,
new_msg_req.retrieval_options
and new_msg_req.retrieval_options.run_search
== OptionalSearchSetting.ALWAYS,
new_msg_req.search_doc_ids,
new_msg_req.query_override is not None,
DISABLE_LLM_CHOOSE_SEARCH,
search_tool_override_kwargs is not None,
]
)
@@ -369,62 +344,18 @@ def _get_force_search_settings(
return ForceUseTool(
force_use=True,
tool_name=tool_name,
tool_name=SearchTool._NAME,
args=args,
override_kwargs=override_kwargs,
override_kwargs=search_tool_override_kwargs,
)
return ForceUseTool(
force_use=False, tool_name=tool_name, args=args, override_kwargs=override_kwargs
)
def _get_user_knowledge_files(
info: AnswerPostInfo,
user_files: list[InMemoryChatFile],
file_id_to_user_file: dict[str, InMemoryChatFile],
) -> Generator[UserKnowledgeFilePacket, None, None]:
if not info.qa_docs_response:
return
logger.info(
f"ORDERING: Processing search results for ordering {len(user_files)} user files"
)
# Extract document order from search results
doc_order = []
for doc in info.qa_docs_response.top_documents:
doc_id = doc.document_id
if str(doc_id).startswith("USER_FILE_CONNECTOR__"):
file_id = doc_id.replace("USER_FILE_CONNECTOR__", "")
if file_id in file_id_to_user_file:
doc_order.append(file_id)
logger.info(f"ORDERING: Found {len(doc_order)} files from search results")
# Add any files that weren't in search results at the end
missing_files = [
f_id for f_id in file_id_to_user_file.keys() if f_id not in doc_order
]
missing_files.extend(doc_order)
doc_order = missing_files
logger.info(f"ORDERING: Added {len(missing_files)} missing files to the end")
# Reorder user files based on search results
ordered_user_files = [
file_id_to_user_file[f_id] for f_id in doc_order if f_id in file_id_to_user_file
]
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id),
type=ChatFileType.USER_KNOWLEDGE,
)
for file in ordered_user_files
]
force_use=False,
tool_name=(
SearchTool._NAME if search_tool_available else InternetSearchTool._NAME
),
args=args,
override_kwargs=None,
)
@@ -488,8 +419,6 @@ def _process_tool_response(
retrieval_options: RetrievalDetails | None,
user_file_files: list[UserFile] | None,
user_files: list[InMemoryChatFile] | None,
file_id_to_user_file: dict[str, InMemoryChatFile],
search_for_ordering_only: bool,
) -> Generator[ChatPacket, None, dict[SubQuestionKey, AnswerPostInfo]]:
level, level_question_num = (
(packet.level, packet.level_question_num)
@@ -501,21 +430,8 @@ def _process_tool_response(
assert level_question_num is not None
info = info_by_subq[SubQuestionKey(level=level, question_num=level_question_num)]
# Skip LLM relevance processing entirely for ordering-only mode
if search_for_ordering_only and packet.id == SECTION_RELEVANCE_LIST_ID:
logger.info(
"Fast path: Completely bypassing section relevance processing for ordering-only mode"
)
# Skip this packet entirely since it would trigger LLM processing
return info_by_subq
# TODO: don't need to dedupe here when we do it in agent flow
if packet.id == SEARCH_RESPONSE_SUMMARY_ID:
if search_for_ordering_only:
logger.info(
"Fast path: Skipping document deduplication for ordering-only mode"
)
(
info.qa_docs_response,
info.reference_db_search_docs,
@@ -525,34 +441,15 @@ def _process_tool_response(
db_session=db_session,
selected_search_docs=selected_db_search_docs,
# Deduping happens at the last step to avoid harming quality by dropping content early on
# Skip deduping completely for ordering-only mode to save time
dedupe_docs=bool(
not search_for_ordering_only
and retrieval_options
and retrieval_options.dedupe_docs
),
user_files=user_file_files if search_for_ordering_only else [],
loaded_user_files=(user_files if search_for_ordering_only else []),
dedupe_docs=bool(retrieval_options and retrieval_options.dedupe_docs),
user_files=[],
loaded_user_files=[],
)
# If we're using search just for ordering user files
if search_for_ordering_only and user_files:
yield from _get_user_knowledge_files(
info=info,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
)
yield info.qa_docs_response
elif packet.id == SECTION_RELEVANCE_LIST_ID:
relevance_sections = packet.response
if search_for_ordering_only:
logger.info(
"Performance: Skipping relevance filtering for ordering-only mode"
)
return info_by_subq
if info.reference_db_search_docs is None:
logger.warning("No reference docs found for relevance filtering")
return info_by_subq
@@ -665,8 +562,6 @@ def stream_chat_message_objects(
try:
# Move these variables inside the try block
file_id_to_user_file = {}
user_id = user.id if user is not None else None
chat_session = get_chat_session_by_id(
@@ -840,60 +735,23 @@ def stream_chat_message_objects(
for folder in persona.user_folders:
user_folder_ids.append(folder.id)
# Initialize flag for user file search
use_search_for_user_files = False
user_files: list[InMemoryChatFile] | None = None
search_for_ordering_only = False
user_file_files: list[UserFile] | None = None
if user_file_ids or user_folder_ids:
# Load user files
user_files = load_in_memory_chat_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
user_file_files = get_user_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Store mapping of file_id to file for later reordering
if user_files:
file_id_to_user_file = {file.file_id: file for file in user_files}
# Calculate token count for the files
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Calculate available tokens for documents based on prompt, user input, etc.
available_tokens = compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
actual_user_input=message_text, # Use the actual user message
)
logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)
# ALWAYS use search for user files, but track if we need it for context or just ordering
use_search_for_user_files = True
# If files are small enough for context, we'll just use search for ordering
search_for_ordering_only = total_tokens <= available_tokens
if search_for_ordering_only:
# Add original user files to context since they fit
if user_files:
latest_query_files.extend(user_files)
# Load in user files into memory and create search tool override kwargs if needed
# if we have enough tokens and no folders, we don't need to use search
# we can just pass them into the prompt directly
(
in_memory_user_files,
user_file_models,
search_tool_override_kwargs_for_user_files,
) = parse_user_files(
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
db_session=db_session,
persona=persona,
actual_user_input=message_text,
user_id=user_id,
)
if not search_tool_override_kwargs_for_user_files:
latest_query_files.extend(in_memory_user_files)
if user_message:
attach_files_to_chat_message(
@@ -1052,10 +910,13 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
db_session=db_session,
user=user,
user_knowledge_present=bool(user_files or user_folder_ids),
llm=llm,
fast_llm=fast_llm,
use_file_search=new_msg_req.force_user_file_search,
run_search_setting=(
retrieval_options.run_search
if retrieval_options
else OptionalSearchSetting.AUTO
),
search_tool_config=SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
@@ -1086,128 +947,23 @@ def stream_chat_message_objects(
tools.extend(tool_list)
force_use_tool = _get_force_search_settings(
new_msg_req, tools, user_file_ids, user_folder_ids
new_msg_req, tools, search_tool_override_kwargs_for_user_files
)
# Set force_use if user files exceed token limit
if use_search_for_user_files:
try:
# Check if search tool is available in the tools list
search_tool_available = any(
isinstance(tool, SearchTool) for tool in tools
)
# If no search tool is available, add one
if not search_tool_available:
logger.info("No search tool available, creating one for user files")
# Create a basic search tool config
search_tool_config = SearchToolConfig(
answer_style_config=answer_style_config,
document_pruning_config=document_pruning_config,
retrieval_options=retrieval_options or RetrievalDetails(),
)
# Create and add the search tool
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
bypass_acl=bypass_acl,
)
# Add the search tool to the tools list
tools.append(search_tool)
logger.info(
"Added search tool for user files that exceed token limit"
)
# Now set force_use_tool.force_use to True
force_use_tool.force_use = True
force_use_tool.tool_name = SearchTool._NAME
# Set query argument if not already set
if not force_use_tool.args:
force_use_tool.args = {"query": final_msg.message}
# Pass the user file IDs to the search tool
if user_file_ids or user_folder_ids:
# Create a BaseFilters object with user_file_ids
if not retrieval_options:
retrieval_options = RetrievalDetails()
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()
# Set user file and folder IDs in the filters
retrieval_options.filters.user_file_ids = user_file_ids
retrieval_options.filters.user_folder_ids = user_folder_ids
# Create override kwargs for the search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=search_for_ordering_only, # Skip reranking for ordering-only
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=search_for_ordering_only, # Skip query analysis for ordering-only
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
ordering_only=search_for_ordering_only, # Set ordering_only flag for fast path
)
# Set the override kwargs in the force_use_tool
force_use_tool.override_kwargs = override_kwargs
if search_for_ordering_only:
logger.info(
"Fast path: Configured search tool with optimized settings for ordering-only"
)
logger.info(
"Fast path: Skipping reranking and query analysis for ordering-only mode"
)
logger.info(
f"Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
else:
logger.info(
"Configured search tool to use ",
f"{len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders",
)
except Exception as e:
logger.exception(
f"Error configuring search tool for user files: {str(e)}"
)
use_search_for_user_files = False
# TODO: unify message history with single message history
message_history = [
PreviousMessage.from_chat_message(msg, files) for msg in history_msgs
]
if not use_search_for_user_files and user_files:
if not search_tool_override_kwargs_for_user_files and in_memory_user_files:
yield UserKnowledgeFilePacket(
user_files=[
FileDescriptor(
id=str(file.file_id), type=ChatFileType.USER_KNOWLEDGE
id=str(file.file_id), type=file.file_type, name=file.filename
)
for file in user_files
for file in in_memory_user_files
]
)
if search_for_ordering_only:
logger.info(
"Performance: Forcing LLMEvaluationType.SKIP to prevent chunk evaluation for ordering-only search"
)
prompt_builder = AnswerPromptBuilder(
user_message=default_build_user_message(
user_query=final_msg.message,
@@ -1265,10 +1021,8 @@ def stream_chat_message_objects(
selected_db_search_docs=selected_db_search_docs,
info_by_subq=info_by_subq,
retrieval_options=retrieval_options,
user_file_files=user_file_files,
user_files=user_files,
file_id_to_user_file=file_id_to_user_file,
search_for_ordering_only=search_for_ordering_only,
user_file_files=user_file_models,
user_files=in_memory_user_files,
)
elif isinstance(packet, StreamStopInfo):

View File

@@ -9,12 +9,12 @@ from onyx.context.search.models import InferenceChunk
from onyx.db.models import Persona
from onyx.db.prompts import get_default_prompt
from onyx.db.search_settings import get_multilingual_expansion
from onyx.file_store.models import InMemoryChatFile
from onyx.llm.factory import get_llms_for_persona
from onyx.llm.factory import get_main_llm_from_tuple
from onyx.llm.interfaces import LLMConfig
from onyx.llm.utils import build_content_with_imgs
from onyx.llm.utils import check_number_of_tokens
from onyx.llm.utils import message_to_prompt_and_imgs
from onyx.prompts.chat_prompts import REQUIRE_CITATION_STATEMENT
from onyx.prompts.constants import DEFAULT_IGNORE_STATEMENT
from onyx.prompts.direct_qa_prompts import CITATIONS_PROMPT
@@ -120,7 +120,8 @@ def build_citations_system_message(
def build_citations_user_message(
message: HumanMessage,
user_query: str,
files: list[InMemoryChatFile],
prompt_config: PromptConfig,
context_docs: list[LlmDoc] | list[InferenceChunk],
all_doc_useful: bool,
@@ -135,7 +136,6 @@ def build_citations_user_message(
history_block = (
HISTORY_BLOCK.format(history_str=history_message) if history_message else ""
)
query, img_urls = message_to_prompt_and_imgs(message)
if context_docs:
context_docs_str = build_complete_context_str(context_docs)
@@ -146,7 +146,7 @@ def build_citations_user_message(
optional_ignore_statement=optional_ignore,
context_docs_str=context_docs_str,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=user_query,
history_block=history_block,
)
else:
@@ -154,16 +154,17 @@ def build_citations_user_message(
user_prompt = CITATIONS_PROMPT_FOR_TOOL_CALLING.format(
context_type=context_type,
task_prompt=task_prompt_with_reminder,
user_query=query,
user_query=user_query,
history_block=history_block,
)
user_prompt = user_prompt.strip()
tag_handled_prompt = handle_onyx_date_awareness(user_prompt, prompt_config)
user_msg = HumanMessage(
content=(
build_content_with_imgs(user_prompt, img_urls=img_urls)
if img_urls
else user_prompt
build_content_with_imgs(tag_handled_prompt, files)
if files
else tag_handled_prompt
)
)

View File

@@ -0,0 +1,106 @@
from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.models import Persona
from onyx.db.models import UserFile
from onyx.file_store.models import InMemoryChatFile
from onyx.file_store.utils import get_user_files_as_user
from onyx.file_store.utils import load_in_memory_chat_files
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.utils.logger import setup_logger
logger = setup_logger()
def parse_user_files(
user_file_ids: list[int],
user_folder_ids: list[int],
db_session: Session,
persona: Persona,
actual_user_input: str,
# should only be None if auth is disabled
user_id: UUID | None,
) -> tuple[list[InMemoryChatFile], list[UserFile], SearchToolOverrideKwargs | None]:
"""
Parse user files and folders into in-memory chat files and create search tool override kwargs.
Only creates SearchToolOverrideKwargs if token overflow occurs or folders are present.
Args:
user_file_ids: List of user file IDs to load
user_folder_ids: List of user folder IDs to load
db_session: Database session
persona: Persona to calculate available tokens
actual_user_input: User's input message for token calculation
user_id: User ID to validate file ownership
Returns:
Tuple of (
loaded user files,
user file models,
search tool override kwargs if token
overflow or folders present
)
"""
# Return empty results if no files or folders specified
if not user_file_ids and not user_folder_ids:
return [], [], None
# Load user files from the database into memory
user_files = load_in_memory_chat_files(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
user_file_models = get_user_files_as_user(
user_file_ids or [],
user_folder_ids or [],
user_id,
db_session,
)
# Calculate token count for the files, need to import here to avoid circular import
# TODO: fix this
from onyx.db.user_documents import calculate_user_files_token_count
from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
total_tokens = calculate_user_files_token_count(
user_file_ids or [],
user_folder_ids or [],
db_session,
)
# Calculate available tokens for documents based on prompt, user input, etc.
available_tokens = compute_max_document_tokens_for_persona(
db_session=db_session,
persona=persona,
actual_user_input=actual_user_input,
)
logger.debug(
f"Total file tokens: {total_tokens}, Available tokens: {available_tokens}"
)
have_enough_tokens = total_tokens <= available_tokens
# If we have enough tokens and no folders, we don't need search
# we can just pass them into the prompt directly
if have_enough_tokens and not user_folder_ids:
# No search tool override needed - files can be passed directly
return user_files, user_file_models, None
# Token overflow or folders present - need to use search tool
override_kwargs = SearchToolOverrideKwargs(
force_no_rerank=have_enough_tokens,
alternate_db_session=None,
retrieved_sections_callback=None,
skip_query_analysis=have_enough_tokens,
user_file_ids=user_file_ids,
user_folder_ids=user_folder_ids,
)
return user_files, user_file_models, override_kwargs
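# Hypothetical sketch (names are illustrative, not from the codebase) of the
# decision parse_user_files makes above: file contents go straight into the
# prompt only when their combined token count fits the available document
# budget and no folders were selected; otherwise retrieval falls back to the
# search tool via SearchToolOverrideKwargs.
from enum import Enum

class FileHandling(Enum):
    DIRECT_PROMPT = "direct_prompt"
    SEARCH_TOOL = "search_tool"

def decide_file_handling(
    total_file_tokens: int, available_tokens: int, has_folders: bool
) -> FileHandling:
    if total_file_tokens <= available_tokens and not has_folders:
        return FileHandling.DIRECT_PROMPT
    return FileHandling.SEARCH_TOOL

# Example: 12k tokens of files against an 8k budget forces the search path.
assert decide_file_handling(12_000, 8_000, has_folders=False) is FileHandling.SEARCH_TOOL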

View File

@@ -170,169 +170,169 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
) # 2000
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 45 # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 8 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 8 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 10 # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 9 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 40 # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 20 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 60 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 12 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 12 # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 2 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 6 # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 8 # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 12 # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 12 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION

View File

@@ -21,6 +21,9 @@ from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import process_attachment
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
is_atlassian_date_error,
)
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
@@ -76,10 +79,6 @@ ONE_DAY = ONE_HOUR * 24
MAX_CACHED_IDS = 100
def _should_propagate_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)
class ConfluenceCheckpoint(ConnectorCheckpoint):
next_page_url: str | None
@@ -367,7 +366,7 @@ class ConfluenceConnector(
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
if _should_propagate_error(e):
if is_atlassian_date_error(e): # propagate error to be caught and retried
raise
return ConnectorFailure(
failed_document=DocumentFailure(
@@ -446,7 +445,9 @@ class ConfluenceConnector(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if _should_propagate_error(e):
if is_atlassian_date_error(
e
): # propagate error to be caught and retried
raise
return ConnectorFailure(
failed_document=DocumentFailure(
@@ -536,7 +537,7 @@ class ConfluenceConnector(
try:
return self._fetch_document_batches(checkpoint, start, end)
except Exception as e:
if _should_propagate_error(e) and start is not None:
if is_atlassian_date_error(e) and start is not None:
logger.warning(
"Confluence says we provided an invalid 'updated' field. This may indicate"
"a real issue, but can also appear during edge cases like daylight"

View File

@@ -86,3 +86,7 @@ def get_oauth_callback_uri(base_domain: str, connector_id: str) -> str:
# Used for development
base_domain = CONNECTOR_LOCALHOST_OVERRIDE
return f"{base_domain.strip('/')}/connector/oauth/callback/{connector_id}"
def is_atlassian_date_error(e: Exception) -> bool:
return "field 'updated' is invalid" in str(e)

View File

@@ -14,6 +14,7 @@ from github import RateLimitExceededException
from github import Repository
from github.GithubException import GithubException
from github.Issue import Issue
from github.NamedUser import NamedUser
from github.PaginatedList import PaginatedList
from github.PullRequest import PullRequest
from github.Requester import Requester
@@ -219,6 +220,18 @@ def _get_batch_rate_limited(
)
def _get_userinfo(user: NamedUser) -> dict[str, str]:
return {
k: v
for k, v in {
"login": user.login,
"name": user.name,
"email": user.email,
}.items()
if v is not None
}
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
@@ -226,7 +239,7 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
semantic_identifier=f"{pull_request.number}: {pull_request.title}",
# updated_at is UTC time but is timezone unaware, explicitly add UTC
# as there is logic in indexing to prevent wrong timestamped docs
# due to local time discrepancies with UTC
@@ -236,8 +249,49 @@ def _convert_pr_to_document(pull_request: PullRequest) -> Document:
else None
),
metadata={
"merged": str(pull_request.merged),
"state": pull_request.state,
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "PullRequest",
"id": pull_request.number,
"merged": pull_request.merged,
"state": pull_request.state,
"user": _get_userinfo(pull_request.user) if pull_request.user else None,
"assignees": [
_get_userinfo(assignee) for assignee in pull_request.assignees
],
"repo": (
pull_request.base.repo.full_name if pull_request.base else None
),
"num_commits": str(pull_request.commits),
"num_files_changed": str(pull_request.changed_files),
"labels": [label.name for label in pull_request.labels],
"created_at": (
pull_request.created_at.replace(tzinfo=timezone.utc)
if pull_request.created_at
else None
),
"updated_at": (
pull_request.updated_at.replace(tzinfo=timezone.utc)
if pull_request.updated_at
else None
),
"closed_at": (
pull_request.closed_at.replace(tzinfo=timezone.utc)
if pull_request.closed_at
else None
),
"merged_at": (
pull_request.merged_at.replace(tzinfo=timezone.utc)
if pull_request.merged_at
else None
),
"merged_by": (
_get_userinfo(pull_request.merged_by)
if pull_request.merged_by
else None
),
}.items()
if v is not None
},
)
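# Standalone illustration (sample values, not from the connector) of the
# metadata comprehension used above: None-valued fields are dropped, lists
# become lists of strings, and everything else is coerced to str so the
# metadata stays uniformly typed.
from datetime import datetime, timezone

raw_metadata = {
    "object_type": "PullRequest",
    "id": 4769,
    "merged": True,
    "labels": ["connector", "metadata"],
    "merged_by": None,  # dropped by the `if v is not None` filter
    "created_at": datetime(2025, 6, 4, tzinfo=timezone.utc),
}
cleaned = {
    k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
    for k, v in raw_metadata.items()
    if v is not None
}
# cleaned == {"object_type": "PullRequest", "id": "4769", "merged": "True",
#             "labels": ["connector", "metadata"],
#             "created_at": "2025-06-04 00:00:00+00:00"}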
@@ -252,11 +306,39 @@ def _convert_issue_to_document(issue: Issue) -> Document:
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
semantic_identifier=f"{issue.number}: {issue.title}",
# updated_at is UTC time but is timezone unaware
doc_updated_at=issue.updated_at.replace(tzinfo=timezone.utc),
metadata={
"state": issue.state,
k: [str(vi) for vi in v] if isinstance(v, list) else str(v)
for k, v in {
"object_type": "Issue",
"id": issue.number,
"state": issue.state,
"user": _get_userinfo(issue.user) if issue.user else None,
"assignees": [_get_userinfo(assignee) for assignee in issue.assignees],
"repo": issue.repository.full_name if issue.repository else None,
"labels": [label.name for label in issue.labels],
"created_at": (
issue.created_at.replace(tzinfo=timezone.utc)
if issue.created_at
else None
),
"updated_at": (
issue.updated_at.replace(tzinfo=timezone.utc)
if issue.updated_at
else None
),
"closed_at": (
issue.closed_at.replace(tzinfo=timezone.utc)
if issue.closed_at
else None
),
"closed_by": (
_get_userinfo(issue.closed_by) if issue.closed_by else None
),
}.items()
if v is not None
},
)

View File

@@ -27,6 +27,7 @@ from onyx.connectors.google_drive.doc_conversion import build_slim_document
from onyx.connectors.google_drive.doc_conversion import (
convert_drive_item_to_document,
)
from onyx.connectors.google_drive.doc_conversion import onyx_document_id_from_drive_file
from onyx.connectors.google_drive.file_retrieval import crawl_folders_for_files
from onyx.connectors.google_drive.file_retrieval import get_all_files_for_oauth
from onyx.connectors.google_drive.file_retrieval import (
@@ -220,6 +221,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
self._primary_admin_email: str | None = None
self._creds: OAuthCredentials | ServiceAccountCredentials | None = None
self._creds_dict: dict[str, Any] | None = None
# ids of folders and shared drives that have been traversed
self._retrieved_folder_and_drive_ids: set[str] = set()
@@ -273,6 +275,8 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
source=DocumentSource.GOOGLE_DRIVE,
)
self._creds_dict = new_creds_dict
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
@@ -919,8 +923,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
).timestamp(),
current_folder_or_drive_id=file.parent_id,
)
if file.drive_file["id"] not in checkpoint.all_retrieved_file_ids:
checkpoint.all_retrieved_file_ids.add(file.drive_file["id"])
document_id = onyx_document_id_from_drive_file(file.drive_file)
if document_id not in checkpoint.all_retrieved_file_ids:
checkpoint.all_retrieved_file_ids.add(document_id)
yield file
def _manage_oauth_retrieval(
@@ -1135,6 +1140,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
raise PermissionError(ONYX_SCOPE_INSTRUCTIONS) from e
raise e
checkpoint.retrieved_folder_and_drive_ids = self._retrieved_folder_and_drive_ids
logger.info(
f"num drive files retrieved: {len(checkpoint.all_retrieved_file_ids)}"
)
if checkpoint.completion_stage == DriveRetrievalStage.DONE:
checkpoint.has_more = False
return checkpoint
@@ -1183,6 +1192,7 @@ class GoogleDriveConnector(SlimConnector, CheckpointedConnector[GoogleDriveCheck
end=end,
callback=callback,
)
logger.info("Drive perm sync: Slim doc retrieval complete")
except Exception as e:
if MISSING_SCOPES_ERROR_STR in str(e):

View File

@@ -62,6 +62,10 @@ GOOGLE_MIME_TYPES = {
}
def onyx_document_id_from_drive_file(file: GoogleDriveFileType) -> str:
return file[WEB_VIEW_LINK_KEY]
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
) -> str:
@@ -380,7 +384,6 @@ def _convert_drive_item_to_document(
"""
Main entry point for converting a Google Drive file => Document object.
"""
doc_id = file.get(WEB_VIEW_LINK_KEY, "")
sections: list[TextSection | ImageSection] = []
# Only construct these services when needed
drive_service = lazy_eval(
@@ -389,6 +392,7 @@ def _convert_drive_item_to_document(
docs_service = lazy_eval(
lambda: get_google_docs_service(creds, user_email=retriever_email)
)
doc_id = "unknown"
try:
# skip shortcuts or folders
@@ -441,7 +445,7 @@ def _convert_drive_item_to_document(
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None
doc_id = file[WEB_VIEW_LINK_KEY]
doc_id = onyx_document_id_from_drive_file(file)
# Create the document
return Document(
@@ -488,7 +492,7 @@ def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file[WEB_VIEW_LINK_KEY],
id=onyx_document_id_from_drive_file(file),
perm_sync_data={
"doc_id": file.get("id"),
"drive_id": file.get("driveId"),

View File

@@ -1,8 +1,16 @@
from collections.abc import Callable
from typing import Any
from google.auth.exceptions import RefreshError # type: ignore
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.discovery import Resource # type: ignore
from onyx.utils.logger import setup_logger
logger = setup_logger()
class GoogleDriveService(Resource):
pass
@@ -20,6 +28,56 @@ class GmailService(Resource):
pass
class RefreshableDriveObject:
"""
Running Google Drive service retrieval functions
involves accessing methods of the service object (i.e. files().list())
which can raise a RefreshError if the access token is expired.
This class is a wrapper that propagates the ability to refresh the access token
and retry the final retrieval function until execute() is called.
"""
def __init__(
self,
call_stack: Callable[[ServiceAccountCredentials | OAuthCredentials], Any],
creds: ServiceAccountCredentials | OAuthCredentials,
creds_getter: Callable[..., ServiceAccountCredentials | OAuthCredentials],
):
self.call_stack = call_stack
self.creds = creds
self.creds_getter = creds_getter
def __getattr__(self, name: str) -> Any:
if name == "execute":
return self.make_refreshable_execute()
return RefreshableDriveObject(
lambda creds: getattr(self.call_stack(creds), name),
self.creds,
self.creds_getter,
)
def __call__(self, *args: Any, **kwargs: Any) -> Any:
return RefreshableDriveObject(
lambda creds: self.call_stack(creds)(*args, **kwargs),
self.creds,
self.creds_getter,
)
def make_refreshable_execute(self) -> Callable:
def execute(*args: Any, **kwargs: Any) -> Any:
try:
return self.call_stack(self.creds).execute(*args, **kwargs)
except RefreshError as e:
logger.warning(
f"RefreshError, going to attempt a creds refresh and retry: {e}"
)
# Refresh the access token
self.creds = self.creds_getter()
return self.call_stack(self.creds).execute(*args, **kwargs)
return execute
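# A toy, self-contained sketch (FakeService/FakeRequest are stand-ins, not the
# Drive client) of the deferral idea described in the docstring above:
# attribute accesses and calls are captured as a closure over the credentials,
# and only execute() touches the real object, so expired credentials can be
# refreshed and the whole chain replayed once.
from typing import Any, Callable

class ExpiredCreds(Exception):
    pass

class FakeRequest:
    def __init__(self, creds: str) -> None:
        self.creds = creds

    def execute(self) -> str:
        if self.creds == "expired":
            raise ExpiredCreds()
        return f"ok with {self.creds}"

class FakeService:
    def __init__(self, creds: str) -> None:
        self.creds = creds

    def files(self) -> "FakeService":
        return self

    def list(self) -> FakeRequest:
        return FakeRequest(self.creds)

class RetryingChain:
    def __init__(
        self,
        call_stack: Callable[[str], Any],
        creds: str,
        creds_getter: Callable[[], str],
    ) -> None:
        self.call_stack = call_stack
        self.creds = creds
        self.creds_getter = creds_getter

    def __getattr__(self, name: str) -> Any:
        if name == "execute":
            return self._run
        return RetryingChain(
            lambda c: getattr(self.call_stack(c), name), self.creds, self.creds_getter
        )

    def __call__(self, *args: Any, **kwargs: Any) -> "RetryingChain":
        return RetryingChain(
            lambda c: self.call_stack(c)(*args, **kwargs), self.creds, self.creds_getter
        )

    def _run(self) -> Any:
        try:
            return self.call_stack(self.creds).execute()
        except ExpiredCreds:
            self.creds = self.creds_getter()  # refresh, then replay the chain once
            return self.call_stack(self.creds).execute()

chain = RetryingChain(lambda c: FakeService(c), "expired", lambda: "fresh")
assert chain.files().list().execute() == "ok with fresh"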
def _get_google_service(
service_name: str,
service_version: str,

View File

@@ -87,6 +87,9 @@ class BasicExpertInfo(BaseModel):
return "Unknown"
def get_email(self) -> str | None:
return self.email or None
def __eq__(self, other: Any) -> bool:
if not isinstance(other, BasicExpertInfo):
return False

View File

@@ -12,6 +12,9 @@ from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import JIRA_CONNECTOR_LABELS_TO_SKIP
from onyx.configs.app_configs import JIRA_CONNECTOR_MAX_TICKET_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
is_atlassian_date_error,
)
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -40,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ONE_HOUR = 3600
JIRA_API_VERSION = os.environ.get("JIRA_API_VERSION") or "2"
_JIRA_SLIM_PAGE_SIZE = 500
_JIRA_FULL_PAGE_SIZE = 50
@@ -55,6 +60,14 @@ _FIELD_KEY = "key"
_FIELD_CREATED = "created"
_FIELD_DUEDATE = "duedate"
_FIELD_ISSUETYPE = "issuetype"
_FIELD_PARENT = "parent"
_FIELD_ASSIGNEE_EMAIL = "assignee_email"
_FIELD_REPORTER_EMAIL = "reporter_email"
_FIELD_PROJECT = "project"
_FIELD_PROJECT_NAME = "project_name"
_FIELD_UPDATED = "updated"
_FIELD_RESOLUTION_DATE = "resolutiondate"
_FIELD_RESOLUTION_DATE_KEY = "resolution_date"
def _perform_jql_search(
@@ -126,6 +139,9 @@ def process_jira_issue(
if basic_expert_info := best_effort_basic_expert_info(creator):
people.add(basic_expert_info)
metadata_dict[_FIELD_REPORTER] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_REPORTER_EMAIL] = email
except Exception:
# Author should exist but if not, doesn't matter
pass
@@ -135,6 +151,8 @@ def process_jira_issue(
if basic_expert_info := best_effort_basic_expert_info(assignee):
people.add(basic_expert_info)
metadata_dict[_FIELD_ASSIGNEE] = basic_expert_info.get_semantic_name()
if email := basic_expert_info.get_email():
metadata_dict[_FIELD_ASSIGNEE_EMAIL] = email
except Exception:
# Author should exist but if not, doesn't matter
pass
@@ -149,10 +167,32 @@ def process_jira_issue(
metadata_dict[_FIELD_LABELS] = labels
if created := best_effort_get_field_from_issue(issue, _FIELD_CREATED):
metadata_dict[_FIELD_CREATED] = created
if updated := best_effort_get_field_from_issue(issue, _FIELD_UPDATED):
metadata_dict[_FIELD_UPDATED] = updated
if duedate := best_effort_get_field_from_issue(issue, _FIELD_DUEDATE):
metadata_dict[_FIELD_DUEDATE] = duedate
if issuetype := best_effort_get_field_from_issue(issue, _FIELD_ISSUETYPE):
metadata_dict[_FIELD_ISSUETYPE] = issuetype.name
if resolutiondate := best_effort_get_field_from_issue(
issue, _FIELD_RESOLUTION_DATE
):
metadata_dict[_FIELD_RESOLUTION_DATE_KEY] = resolutiondate
try:
parent = best_effort_get_field_from_issue(issue, _FIELD_PARENT)
if parent:
metadata_dict[_FIELD_PARENT] = parent.key
except Exception:
# Parent should exist but if not, doesn't matter
pass
try:
project = best_effort_get_field_from_issue(issue, _FIELD_PROJECT)
if project:
metadata_dict[_FIELD_PROJECT_NAME] = project.name
metadata_dict[_FIELD_PROJECT] = project.key
except Exception:
# Project should exist.
logger.error(f"Project should exist but does not for {issue.key}")
return Document(
id=page_url,
@@ -240,7 +280,17 @@ class JiraConnector(CheckpointedConnector[JiraConnectorCheckpoint], SlimConnecto
checkpoint: JiraConnectorCheckpoint,
) -> CheckpointOutput[JiraConnectorCheckpoint]:
jql = self._get_jql_query(start, end)
try:
return self._load_from_checkpoint(jql, checkpoint)
except Exception as e:
if is_atlassian_date_error(e):
jql = self._get_jql_query(start - ONE_HOUR, end)
return self._load_from_checkpoint(jql, checkpoint)
raise e
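# Hedged standalone sketch (toy callables, not the connector) of the retry
# shown above: when Atlassian rejects the 'updated' bound, rebuild the JQL
# with the start pushed back one hour and try exactly once more.
from typing import Callable

TOY_ONE_HOUR = 3600

def run_with_date_retry(
    run_jql: Callable[[str], str],
    build_jql: Callable[[float, float], str],
    start: float,
    end: float,
) -> str:
    try:
        return run_jql(build_jql(start, end))
    except Exception as e:
        if "field 'updated' is invalid" in str(e):
            return run_jql(build_jql(start - TOY_ONE_HOUR, end))
        raise

# Fake executor that fails only on the first (unwidened) window.
calls: list[str] = []

def fake_run(jql: str) -> str:
    calls.append(jql)
    if len(calls) == 1:
        raise ValueError("field 'updated' is invalid")
    return jql

widened = run_with_date_retry(
    fake_run, lambda s, e: f"updated >= {s} AND updated <= {e}", 1000.0, 2000.0
)
assert widened == "updated >= -2600.0 AND updated <= 2000.0"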
def _load_from_checkpoint(
self, jql: str, checkpoint: JiraConnectorCheckpoint
) -> CheckpointOutput[JiraConnectorCheckpoint]:
# Get the current offset from checkpoint or start at 0
starting_offset = checkpoint.offset or 0
current_offset = starting_offset

View File

@@ -9,8 +9,11 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from datetime import datetime
from datetime import timezone
from http.client import IncompleteRead
from http.client import RemoteDisconnected
from typing import Any
from typing import cast
from urllib.error import URLError
from pydantic import BaseModel
from redis import Redis
@@ -18,6 +21,9 @@ from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
from slack_sdk.http_retry import RetryHandler
from slack_sdk.http_retry.builtin_interval_calculators import (
FixedValueRetryIntervalCalculator,
)
from typing_extensions import override
from onyx.configs.app_configs import ENABLE_EXPENSIVE_EXPERT_CALLS
@@ -45,10 +51,10 @@ from onyx.connectors.models import EntityFailure
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.onyx_retry_handler import OnyxRedisSlackRetryHandler
from onyx.connectors.slack.onyx_slack_web_client import OnyxSlackWebClient
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_slack_api_call_w_retries
from onyx.connectors.slack.utils import make_paginated_slack_api_call
from onyx.connectors.slack.utils import SlackTextCleaner
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_pool import get_redis_client
@@ -78,7 +84,7 @@ def _collect_paginated_channels(
channel_types: list[str],
) -> list[ChannelType]:
channels: list[dict[str, Any]] = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_list,
exclude_archived=exclude_archived,
# also get private channels the bot is added to
@@ -135,14 +141,13 @@ def get_channel_messages(
"""Get all messages in a channel"""
# join so that the bot can access messages
if not channel["is_member"]:
make_slack_api_call_w_retries(
client.conversations_join,
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
logger.info(f"Successfully joined '{channel['name']}'")
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_history,
channel=channel["id"],
oldest=oldest,
@@ -159,7 +164,7 @@ def get_channel_messages(
def get_thread(client: WebClient, channel_id: str, thread_id: str) -> ThreadType:
"""Get all messages in a thread"""
threads: list[MessageType] = []
for result in make_paginated_slack_api_call_w_retries(
for result in make_paginated_slack_api_call(
client.conversations_replies, channel=channel_id, ts=thread_id
):
threads.extend(result["messages"])
@@ -317,8 +322,7 @@ def _get_channel_by_id(client: WebClient, channel_id: str) -> ChannelType:
Raises:
SlackApiError: If the channel cannot be fetched
"""
response = make_slack_api_call_w_retries(
client.conversations_info,
response = client.conversations_info(
channel=channel_id,
)
return cast(ChannelType, response["channel"])
@@ -335,8 +339,7 @@ def _get_messages(
# have to be in the channel in order to read messages
if not channel["is_member"]:
try:
make_slack_api_call_w_retries(
client.conversations_join,
client.conversations_join(
channel=channel["id"],
is_private=channel["is_private"],
)
@@ -349,8 +352,7 @@ def _get_messages(
raise
logger.info(f"Successfully joined '{channel['name']}'")
response = make_slack_api_call_w_retries(
client.conversations_history,
response = client.conversations_history(
channel=channel["id"],
oldest=oldest,
latest=latest,
@@ -379,6 +381,9 @@ def _message_to_doc(
filtered_thread: ThreadType | None = None
thread_ts = message.get("thread_ts")
if thread_ts:
# NOTE: if thread_ts is present, there's a thread we need to process
# ... otherwise, we can skip it
# skip threads we've already seen, since we've already processed all
# messages in that thread
if thread_ts in seen_thread_ts:
@@ -527,6 +532,7 @@ class SlackConnector(
channel_regex_enabled: bool = False,
batch_size: int = INDEX_BATCH_SIZE,
num_threads: int = SLACK_NUM_THREADS,
use_redis: bool = True,
) -> None:
self.channels = channels
self.channel_regex_enabled = channel_regex_enabled
@@ -539,6 +545,7 @@ class SlackConnector(
self.user_cache: dict[str, BasicExpertInfo | None] = {}
self.credentials_provider: CredentialsProviderInterface | None = None
self.credential_prefix: str | None = None
self.use_redis: bool = use_redis
# self.delay_lock: str | None = None # the redis key for the shared lock
# self.delay_key: str | None = None # the redis key for the shared delay
@@ -563,10 +570,19 @@ class SlackConnector(
# NOTE: slack has a built in RateLimitErrorRetryHandler, but it isn't designed
# for concurrent workers. We've extended it with OnyxRedisSlackRetryHandler.
connection_error_retry_handler = ConnectionErrorRetryHandler()
connection_error_retry_handler = ConnectionErrorRetryHandler(
max_retry_count=max_retry_count,
interval_calculator=FixedValueRetryIntervalCalculator(),
error_types=[
URLError,
ConnectionResetError,
RemoteDisconnected,
IncompleteRead,
],
)
onyx_rate_limit_error_retry_handler = OnyxRedisSlackRetryHandler(
max_retry_count=max_retry_count,
delay_lock=delay_lock,
delay_key=delay_key,
r=r,
)
@@ -575,7 +591,13 @@ class SlackConnector(
onyx_rate_limit_error_retry_handler,
]
client = WebClient(token=token, retry_handlers=custom_retry_handlers)
client = OnyxSlackWebClient(
delay_lock=delay_lock,
delay_key=delay_key,
r=r,
token=token,
retry_handlers=custom_retry_handlers,
)
return client
@property
@@ -599,16 +621,32 @@ class SlackConnector(
if not tenant_id:
raise ValueError("tenant_id cannot be None!")
self.redis = get_redis_client(tenant_id=tenant_id)
self.credential_prefix = SlackConnector.make_credential_prefix(
credentials_provider.get_provider_key()
)
bot_token = credentials["slack_bot_token"]
self.client = SlackConnector.make_slack_web_client(
self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
)
if self.use_redis:
self.redis = get_redis_client(tenant_id=tenant_id)
self.credential_prefix = SlackConnector.make_credential_prefix(
credentials_provider.get_provider_key()
)
self.client = SlackConnector.make_slack_web_client(
self.credential_prefix, bot_token, self.MAX_RETRIES, self.redis
)
else:
connection_error_retry_handler = ConnectionErrorRetryHandler(
max_retry_count=self.MAX_RETRIES,
interval_calculator=FixedValueRetryIntervalCalculator(),
error_types=[
URLError,
ConnectionResetError,
RemoteDisconnected,
IncompleteRead,
],
)
self.client = WebClient(
token=bot_token, retry_handlers=[connection_error_retry_handler]
)
# use for requests that must return quickly (e.g. realtime flows where user is waiting)
self.fast_client = WebClient(
@@ -651,6 +689,8 @@ class SlackConnector(
Step 2.4: If there are no more messages in the channel, switch the current
channel to the next channel.
"""
num_channels_remaining = 0
if self.client is None or self.text_cleaner is None:
raise ConnectorMissingCredentialError("Slack")
@@ -664,7 +704,9 @@ class SlackConnector(
raw_channels, self.channels, self.channel_regex_enabled
)
logger.info(
f"Channels: all={len(raw_channels)} post_filtering={len(filtered_channels)}"
f"Channels - initial checkpoint: "
f"all={len(raw_channels)} "
f"post_filtering={len(filtered_channels)}"
)
checkpoint.channel_ids = [c["id"] for c in filtered_channels]
@@ -677,6 +719,17 @@ class SlackConnector(
return checkpoint
final_channel_ids = checkpoint.channel_ids
for channel_id in final_channel_ids:
if channel_id not in checkpoint.channel_completion_map:
num_channels_remaining += 1
logger.info(
f"Channels - current status: "
f"processed={len(final_channel_ids) - num_channels_remaining} "
f"remaining={num_channels_remaining=} "
f"total={len(final_channel_ids)}"
)
channel = checkpoint.current_channel
if channel is None:
raise ValueError("current_channel key not set in checkpoint")
@@ -688,18 +741,32 @@ class SlackConnector(
oldest = str(start) if start else None
latest = checkpoint.channel_completion_map.get(channel_id, str(end))
seen_thread_ts = set(checkpoint.seen_thread_ts)
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
try:
logger.debug(
f"Getting messages for channel {channel} within range {oldest} - {latest}"
)
message_batch, has_more_in_channel = _get_messages(
channel, self.client, oldest, latest
)
logger.info(
f"Retrieved messages: "
f"{len(message_batch)=} "
f"{channel=} "
f"{oldest=} "
f"{latest=}"
)
new_latest = message_batch[-1]["ts"] if message_batch else latest
num_threads_start = len(seen_thread_ts)
# Process messages in parallel using ThreadPoolExecutor
with ThreadPoolExecutor(max_workers=self.num_threads) as executor:
# NOTE(rkuo): this seems to be assuming the slack sdk is thread safe.
# That's a very bold assumption! Likely not correct.
futures: list[Future[ProcessedSlackMessage]] = []
for message in message_batch:
# Capture the current context so that the thread gets the current tenant ID
@@ -736,7 +803,12 @@ class SlackConnector(
yield failure
num_threads_processed = len(seen_thread_ts) - num_threads_start
logger.info(f"Processed {num_threads_processed} threads.")
logger.info(
f"Message processing stats: "
f"batch_len={len(message_batch)} "
f"batch_yielded={num_threads_processed} "
f"total_threads_seen={len(seen_thread_ts)}"
)
checkpoint.seen_thread_ts = list(seen_thread_ts)
checkpoint.channel_completion_map[channel["id"]] = new_latest
@@ -751,6 +823,7 @@ class SlackConnector(
),
None,
)
if new_channel_id:
new_channel = _get_channel_by_id(self.client, new_channel_id)
checkpoint.current_channel = new_channel
@@ -758,8 +831,6 @@ class SlackConnector(
checkpoint.current_channel = None
checkpoint.has_more = checkpoint.current_channel is not None
return checkpoint
except Exception as e:
logger.exception(f"Error processing channel {channel['name']}")
yield ConnectorFailure(
@@ -773,7 +844,8 @@ class SlackConnector(
failure_message=str(e),
exception=e,
)
return checkpoint
return checkpoint
def validate_connector_settings(self) -> None:
"""

View File

@@ -1,11 +1,8 @@
import math
import random
import time
from typing import cast
from typing import Optional
from redis import Redis
from redis.lock import Lock as RedisLock
from slack_sdk.http_retry.handler import RetryHandler
from slack_sdk.http_retry.request import HttpRequest
from slack_sdk.http_retry.response import HttpResponse
@@ -20,28 +17,23 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
"""
This class uses Redis to share a rate limit among multiple threads.
Threads that encounter a rate limit will observe the shared delay, increment the
shared delay with the retry value, and use the new shared value as a wait interval.
As currently implemented, this code is already surrounded by a lock in Redis
via an override of _perform_urllib_http_request in OnyxSlackWebClient.
This has the effect of serializing calls when a rate limit is hit, which is what
needs to happen if the server punishes us with additional limiting when we make
a call too early. We believe this is what Slack is doing based on empirical
observation, meaning we see indefinite hangs if we're too aggressive.
This just sets the desired retry delay with TTL in redis. In conjunction with
a custom subclass of the client, the value is read and obeyed prior to an API call
and also serialized.
Another way to do this is just to do exponential backoff. Might be easier?
Adapted from slack's RateLimitErrorRetryHandler.
"""
LOCK_TTL = 60 # used to serialize access to the retry TTL
LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock
"""RetryHandler that does retries for rate limited errors."""
def __init__(
self,
max_retry_count: int,
delay_lock: str,
delay_key: str,
r: Redis,
):
@@ -51,7 +43,6 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
"""
super().__init__(max_retry_count=max_retry_count)
self._redis: Redis = r
self._delay_lock = delay_lock
self._delay_key = delay_key
def _can_retry(
@@ -72,8 +63,18 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
response: Optional[HttpResponse] = None,
error: Optional[Exception] = None,
) -> None:
"""It seems this function is responsible for the wait to retry ... aka we
actually sleep in this function."""
"""As initially designed by the SDK authors, this function is responsible for
the wait to retry ... aka we actually sleep in this function.
This doesn't work well with multiple clients because every thread is unaware
of the current retry value until it actually calls the endpoint.
We're combining this with an actual subclass of the slack web client so
that the delay is used BEFORE calling an API endpoint. The subclassed client
has already taken the lock in redis when this method is called.
"""
ttl_ms: int | None = None
retry_after_value: list[str] | None = None
retry_after_header_name: Optional[str] = None
duration_s: float = 1.0 # seconds
@@ -112,48 +113,22 @@ class OnyxRedisSlackRetryHandler(RetryHandler):
retry_after_value[0]
) # will raise ValueError if somehow we can't convert to int
jitter = retry_after_value_int * 0.25 * random.random()
duration_s = math.ceil(retry_after_value_int + jitter)
duration_s = retry_after_value_int + jitter
except ValueError:
duration_s += random.random()
# lock and extend the ttl
lock: RedisLock = self._redis.lock(
self._delay_lock,
timeout=OnyxRedisSlackRetryHandler.LOCK_TTL,
thread_local=False,
)
acquired = lock.acquire(
blocking_timeout=OnyxRedisSlackRetryHandler.LOCK_BLOCKING_TIMEOUT / 2
)
ttl_ms: int | None = None
try:
if acquired:
# if we can get the lock, then read and extend the ttl
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
if ttl_ms < 0: # negative values are error status codes ... see docs
ttl_ms = 0
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
else:
# if we can't get the lock, just go ahead.
# TODO: if we know our actual parallelism, multiplying by that
# would be a pretty good idea
ttl_ms_new = int(duration_s * 1000.0)
finally:
if acquired:
lock.release()
# Read and extend the ttl
ttl_ms = cast(int, self._redis.pttl(self._delay_key))
if ttl_ms < 0: # negative values are error status codes ... see docs
ttl_ms = 0
ttl_ms_new = ttl_ms + int(duration_s * 1000.0)
self._redis.set(self._delay_key, "1", px=ttl_ms_new)
logger.warning(
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt wait: "
f"OnyxRedisSlackRetryHandler.prepare_for_next_attempt setting delay: "
f"current_attempt={state.current_attempt} "
f"retry-after={retry_after_value} "
f"shared_delay_ms={ttl_ms} new_shared_delay_ms={ttl_ms_new}"
f"{ttl_ms_new=}"
)
# TODO: would be good to take an event var and sleep in short increments to
# allow for a clean exit / exception
time.sleep(ttl_ms_new / 1000.0)
state.increment_current_attempt()
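# Minimal sketch of the shared-delay write described above, assuming the
# redis-py package and a reachable Redis; the key name below is hypothetical.
# The remaining TTL on the delay key is the wait every worker must observe,
# and a rate-limited worker extends it by the Retry-After duration instead of
# only sleeping locally.
from redis import Redis

def extend_shared_delay(r: Redis, delay_key: str, retry_after_s: float) -> int:
    ttl_ms = r.pttl(delay_key)
    if ttl_ms < 0:  # -1 (no expiry) and -2 (missing key) mean no pending delay
        ttl_ms = 0
    new_ttl_ms = ttl_ms + int(retry_after_s * 1000)
    r.set(delay_key, "1", px=new_ttl_ms)
    return new_ttl_ms

# Usage (hypothetical key):
# r = Redis(host="localhost", port=6379)
# extend_shared_delay(r, "slack:connector:delay", retry_after_s=30.0)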

View File

@@ -0,0 +1,116 @@
import threading
import time
from typing import Any
from typing import cast
from typing import Dict
from urllib.request import Request
from redis import Redis
from redis.lock import Lock as RedisLock
from slack_sdk import WebClient
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_BLOCKING_TIMEOUT
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT
from onyx.connectors.slack.utils import ONYX_SLACK_LOCK_TTL
from onyx.utils.logger import setup_logger
logger = setup_logger()
class OnyxSlackWebClient(WebClient):
"""Use in combination with the Onyx Retry Handler.
This client wrapper enforces a proper retry delay through redis BEFORE the api call
so that multiple clients can synchronize and rate limit properly.
The retry handler writes the correct delay value to redis so that it can be used
by this wrapper.
"""
def __init__(
self, delay_lock: str, delay_key: str, r: Redis, *args: Any, **kwargs: Any
) -> None:
super().__init__(*args, **kwargs)
self._delay_key = delay_key
self._delay_lock = delay_lock
self._redis: Redis = r
self.num_requests: int = 0
self._lock = threading.Lock()
def _perform_urllib_http_request(
self, *, url: str, args: Dict[str, Dict[str, Any]]
) -> Dict[str, Any]:
"""By locking around the base class method, we ensure that both the delay from
Redis and parsing/writing of retry values to Redis are handled properly in
one place"""
# lock and extend the ttl
lock: RedisLock = self._redis.lock(
self._delay_lock,
timeout=ONYX_SLACK_LOCK_TTL,
)
# try to acquire the lock
start = time.monotonic()
while True:
acquired = lock.acquire(blocking_timeout=ONYX_SLACK_LOCK_BLOCKING_TIMEOUT)
if acquired:
break
# if we couldn't acquire the lock but it exists, there's at least some activity
# so keep trying...
if self._redis.exists(self._delay_lock):
continue
if time.monotonic() - start > ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT:
raise RuntimeError(
f"OnyxSlackWebClient._perform_urllib_http_request - "
f"timed out waiting for lock: {ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT=}"
)
try:
result = super()._perform_urllib_http_request(url=url, args=args)
finally:
if lock.owned():
lock.release()
else:
logger.warning(
"OnyxSlackWebClient._perform_urllib_http_request lock not owned on release"
)
elapsed = time.monotonic() - start  # retained for the commented-out timing log below
# logger.info(
# f"OnyxSlackWebClient._perform_urllib_http_request: Releasing lock: {elapsed=}"
# )
return result
def _perform_urllib_http_request_internal(
self,
url: str,
req: Request,
) -> Dict[str, Any]:
"""Overrides the internal method which is mostly the direct call to
urllib/urlopen ... so this is a good place to perform our delay."""
# read and execute the delay
delay_ms = cast(int, self._redis.pttl(self._delay_key))
if delay_ms < 0: # negative values are error status codes ... see docs
delay_ms = 0
if delay_ms > 0:
logger.warning(
f"OnyxSlackWebClient._perform_urllib_http_request_internal delay: "
f"{delay_ms=} "
f"{self.num_requests=}"
)
time.sleep(delay_ms / 1000.0)
result = super()._perform_urllib_http_request_internal(url, req)
with self._lock:
self.num_requests += 1
# the delay key should have naturally expired by this point
return result
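# Complementary reader-side sketch (redis-py assumed, hypothetical key name):
# before issuing an API request, a worker reads the remaining TTL on the shared
# delay key and sleeps that long, so the delay written by the retry handler is
# honored by every client, not just the one that was rate limited.
import time
from redis import Redis

def wait_for_shared_delay(r: Redis, delay_key: str) -> None:
    delay_ms = r.pttl(delay_key)  # negative values mean no delay is pending
    if delay_ms > 0:
        time.sleep(delay_ms / 1000.0)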

View File

@@ -21,6 +21,11 @@ basic_retry_wrapper = retry_builder(tries=7)
# number of messages we request per page when fetching paginated slack messages
_SLACK_LIMIT = 900
# used to serialize access to the retry TTL
ONYX_SLACK_LOCK_TTL = 1800 # how long the lock is allowed to idle before it expires
ONYX_SLACK_LOCK_BLOCKING_TIMEOUT = 60 # how long to wait for the lock per wait attempt
ONYX_SLACK_LOCK_TOTAL_BLOCKING_TIMEOUT = 3600 # how long to wait for the lock in total
@lru_cache()
def get_base_url(token: str) -> str:
@@ -44,6 +49,18 @@ def get_message_link(
return link
def make_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return call(**kwargs)
def make_paginated_slack_api_call(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(call)(**kwargs)
def _make_slack_api_call_paginated(
call: Callable[..., SlackResponse],
) -> Callable[..., Generator[dict[str, Any], None, None]]:
@@ -119,17 +136,18 @@ def _make_slack_api_call_paginated(
# return rate_limited_call
def make_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> SlackResponse:
return basic_retry_wrapper(call)(**kwargs)
# temporarily disabling due to using a different retry approach
# might be permanent if everything works out
# def make_slack_api_call_w_retries(
# call: Callable[..., SlackResponse], **kwargs: Any
# ) -> SlackResponse:
# return basic_retry_wrapper(call)(**kwargs)
def make_paginated_slack_api_call_w_retries(
call: Callable[..., SlackResponse], **kwargs: Any
) -> Generator[dict[str, Any], None, None]:
return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
# def make_paginated_slack_api_call_w_retries(
# call: Callable[..., SlackResponse], **kwargs: Any
# ) -> Generator[dict[str, Any], None, None]:
# return _make_slack_api_call_paginated(basic_retry_wrapper(call))(**kwargs)
def expert_info_from_slack_id(

View File

@@ -111,11 +111,14 @@ class BaseFilters(BaseModel):
document_set: list[str] | None = None
time_cutoff: datetime | None = None
tags: list[Tag] | None = None
class UserFileFilters(BaseModel):
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
class IndexFilters(BaseFilters):
class IndexFilters(BaseFilters, UserFileFilters):
access_control_list: list[str] | None
tenant_id: str | None = None
@@ -150,6 +153,7 @@ class SearchRequest(ChunkContext):
search_type: SearchType = SearchType.SEMANTIC
human_selected_filters: BaseFilters | None = None
user_file_filters: UserFileFilters | None = None
enable_auto_detect_filters: bool | None = None
persona: Persona | None = None
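# Hypothetical, trimmed-down illustration (pydantic v2 assumed, toy field set)
# of how the filter models above compose: IndexFilters inherits fields from
# both BaseFilters and UserFileFilters, so user file/folder scoping travels
# with the rest of the index-level filters.
from pydantic import BaseModel

class ToyBaseFilters(BaseModel):
    source_type: list[str] | None = None
    tags: list[str] | None = None

class ToyUserFileFilters(BaseModel):
    user_file_ids: list[int] | None = None
    user_folder_ids: list[int] | None = None

class ToyIndexFilters(ToyBaseFilters, ToyUserFileFilters):
    access_control_list: list[str] | None = None
    tenant_id: str | None = None

toy_filters = ToyIndexFilters(user_file_ids=[1, 2], access_control_list=None)
assert toy_filters.user_file_ids == [1, 2] and toy_filters.tags is None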

View File

@@ -165,47 +165,6 @@ class SearchPipeline:
return cast(list[InferenceChunk], self._retrieved_chunks)
def get_ordering_only_chunks(
self,
query: str,
user_file_ids: list[int] | None = None,
user_folder_ids: list[int] | None = None,
) -> list[InferenceChunk]:
"""Optimized method that only retrieves chunks for ordering purposes.
Skips all extra processing and uses minimal configuration to speed up retrieval.
"""
logger.info("Fast path: Using optimized chunk retrieval for ordering-only mode")
# Create minimal filters with just user file/folder IDs
filters = IndexFilters(
user_file_ids=user_file_ids or [],
user_folder_ids=user_folder_ids or [],
access_control_list=None,
)
# Use a simplified query that skips all unnecessary processing
minimal_query = SearchQuery(
query=query,
search_type=SearchType.SEMANTIC,
filters=filters,
# Set minimal options needed for retrieval
evaluation_type=LLMEvaluationType.SKIP,
recency_bias_multiplier=1.0,
chunks_above=0, # No need for surrounding context
chunks_below=0, # No need for surrounding context
processed_keywords=[], # Empty list instead of None
rerank_settings=None,
hybrid_alpha=0.0,
max_llm_filter_sections=0,
)
# Retrieve chunks using the minimal configuration
return retrieve_chunks(
query=minimal_query,
document_index=self.document_index,
db_session=self.db_session,
)
@log_function_time(print_only=True)
def _get_sections(self) -> list[InferenceSection]:
"""Returns an expanded section from each of the chunks.
@@ -458,10 +417,6 @@ class SearchPipeline:
self.search_query.evaluation_type == LLMEvaluationType.SKIP
or DISABLE_LLM_DOC_RELEVANCE
):
if self.search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info(
"Fast path: Skipping section relevance evaluation for ordering-only mode"
)
return None
if self.search_query.evaluation_type == LLMEvaluationType.UNSPECIFIED:

View File

@@ -372,7 +372,6 @@ def filter_sections(
# Log evaluation type to help with debugging
logger.info(f"filter_sections called with evaluation_type={query.evaluation_type}")
# Fast path: immediately return empty list for SKIP evaluation type (ordering-only mode)
if query.evaluation_type == LLMEvaluationType.SKIP:
return []
@@ -408,16 +407,6 @@ def search_postprocessing(
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Iterator[list[InferenceSection] | list[SectionRelevancePiece]]:
# Fast path for ordering-only: detect it by checking if evaluation_type is SKIP
if search_query.evaluation_type == LLMEvaluationType.SKIP:
logger.info(
"Fast path: Detected ordering-only mode, bypassing all post-processing"
)
# Immediately yield the sections without any processing and an empty relevance list
yield retrieved_sections
yield cast(list[SectionRelevancePiece], [])
return
post_processing_tasks: list[FunctionCall] = []
if not retrieved_sections:

View File

@@ -164,14 +164,15 @@ def retrieval_preprocessing(
user_acl_filters = (
None if bypass_acl else build_access_filters_for_user(user, db_session)
)
user_file_ids = preset_filters.user_file_ids or []
user_folder_ids = preset_filters.user_folder_ids or []
user_file_filters = search_request.user_file_filters
user_file_ids = (user_file_filters.user_file_ids or []) if user_file_filters else []
user_folder_ids = (
(user_file_filters.user_folder_ids or []) if user_file_filters else []
)
if persona and persona.user_files:
user_file_ids = user_file_ids + [
file.id
for file in persona.user_files
if file.id not in (preset_filters.user_file_ids or [])
]
user_file_ids = list(
set(user_file_ids) | set([file.id for file in persona.user_files])
)
final_filters = IndexFilters(
user_file_ids=user_file_ids,

View File

@@ -62,7 +62,7 @@ def download_nltk_data() -> None:
resources = {
"stopwords": "corpora/stopwords",
# "wordnet": "corpora/wordnet", # Not in use
"punkt": "tokenizers/punkt",
"punkt_tab": "tokenizers/punkt_tab",
}
for resource_name, resource_path in resources.items():

View File

@@ -234,6 +234,10 @@ def delete_messages_and_files_from_chat_session(
logger.info(f"Deleting file with name: {lobj_name}")
delete_lobj_by_name(lobj_name, db_session)
# Delete ChatMessage records - CASCADE constraints will automatically handle:
# - AgentSubQuery records (via AgentSubQuestion)
# - AgentSubQuestion records
# - ChatMessage__StandardAnswer relationship records
db_session.execute(
delete(ChatMessage).where(ChatMessage.chat_session_id == chat_session_id)
)
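# Self-contained sketch (SQLite in-memory, toy tables, not the Onyx schema) of
# what the comment above relies on: with ondelete="CASCADE" on the foreign key,
# deleting the parent row removes dependent rows at the database level, no ORM
# cascade needed. SQLite enforces this only with foreign_keys=ON.
from sqlalchemy import Column, ForeignKey, Integer, create_engine, delete, event
from sqlalchemy.orm import Session, declarative_base

ToyBase = declarative_base()

class ToyParent(ToyBase):
    __tablename__ = "toy_parent"
    id = Column(Integer, primary_key=True)

class ToyChild(ToyBase):
    __tablename__ = "toy_child"
    id = Column(Integer, primary_key=True)
    parent_id = Column(Integer, ForeignKey("toy_parent.id", ondelete="CASCADE"))

toy_engine = create_engine("sqlite://")

@event.listens_for(toy_engine, "connect")
def _enable_fks(dbapi_conn, _record):  # SQLite enforces FKs only when asked
    dbapi_conn.execute("PRAGMA foreign_keys=ON")

ToyBase.metadata.create_all(toy_engine)
with Session(toy_engine) as session:
    session.add(ToyParent(id=1))
    session.add(ToyChild(id=1, parent_id=1))
    session.commit()
    session.execute(delete(ToyParent).where(ToyParent.id == 1))
    session.commit()
    assert session.query(ToyChild).count() == 0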

View File

@@ -423,12 +423,11 @@ def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]
dbapi_connection = connection.connection
cursor = dbapi_connection.cursor()
try:
# NOTE: don't use `text()` here since we're using the cursor directly
cursor.execute(f'SET search_path = "{tenant_id}"')
if POSTGRES_IDLE_SESSIONS_TIMEOUT:
cursor.execute(
text(
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
f"SET SESSION idle_in_transaction_session_timeout = {POSTGRES_IDLE_SESSIONS_TIMEOUT}"
)
except Exception:
raise RuntimeError(f"search_path not set for {tenant_id}")
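# Short sketch of the distinction the NOTE above draws (SQLAlchemy plus a
# psycopg2-style DBAPI assumed; the tenant name is hypothetical): text() belongs
# to SQLAlchemy's execution layer, while a raw DBAPI cursor takes plain SQL
# strings with no wrapper.
from typing import Any
from sqlalchemy import text
from sqlalchemy.engine import Connection

def set_search_path_sqlalchemy(conn: Connection, tenant_id: str) -> None:
    # SQLAlchemy connection: wrap the statement in text()
    conn.execute(text(f'SET search_path = "{tenant_id}"'))

def set_search_path_dbapi(dbapi_connection: Any, tenant_id: str) -> None:
    # raw DBAPI connection (e.g. psycopg2): pass the SQL string directly
    cursor = dbapi_connection.cursor()
    cursor.execute(f'SET search_path = "{tenant_id}"')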

View File

@@ -354,7 +354,7 @@ class AgentSubQuery__SearchDoc(Base):
__tablename__ = "agent__sub_query__search_doc"
sub_query_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_query.id"), primary_key=True
ForeignKey("agent__sub_query.id", ondelete="CASCADE"), primary_key=True
)
search_doc_id: Mapped[int] = mapped_column(
ForeignKey("search_doc.id"), primary_key=True
@@ -405,7 +405,7 @@ class ChatMessage__StandardAnswer(Base):
__tablename__ = "chat_message__standard_answer"
chat_message_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id"), primary_key=True
ForeignKey("chat_message.id", ondelete="CASCADE"), primary_key=True
)
standard_answer_id: Mapped[int] = mapped_column(
ForeignKey("standard_answer.id"), primary_key=True
@@ -1430,7 +1430,9 @@ class AgentSubQuestion(Base):
__tablename__ = "agent__sub_question"
id: Mapped[int] = mapped_column(primary_key=True)
primary_question_id: Mapped[int] = mapped_column(ForeignKey("chat_message.id"))
primary_question_id: Mapped[int] = mapped_column(
ForeignKey("chat_message.id", ondelete="CASCADE")
)
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
)
@@ -1464,7 +1466,7 @@ class AgentSubQuery(Base):
id: Mapped[int] = mapped_column(primary_key=True)
parent_question_id: Mapped[int] = mapped_column(
ForeignKey("agent__sub_question.id")
ForeignKey("agent__sub_question.id", ondelete="CASCADE")
)
chat_session_id: Mapped[UUID] = mapped_column(
PGUUID(as_uuid=True), ForeignKey("chat_session.id")
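
The ondelete="CASCADE" additions push child cleanup down to the database, which matters because the bulk delete(ChatMessage) in the chat-deletion hunk above emits plain SQL and never triggers ORM-level cascades. A toy sketch of the pattern (models are illustrative, not the real schema; changing an existing constraint would also need an Alembic migration):

    from sqlalchemy import ForeignKey
    from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

    class Base(DeclarativeBase):
        pass

    class Parent(Base):
        __tablename__ = "parent"
        id: Mapped[int] = mapped_column(primary_key=True)

    class Child(Base):
        __tablename__ = "child"
        id: Mapped[int] = mapped_column(primary_key=True)
        # The CASCADE lives on the database constraint, so it still applies when
        # parent rows are removed with a bulk delete() that bypasses the ORM.
        parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id", ondelete="CASCADE"))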

View File

@@ -36,7 +36,7 @@ MAX_OR_CONDITIONS = 10
# up from 500ms for now, since we've seen quite a few timeouts
# in the long term, we are looking to improve the performance of Vespa
# so that we can bring this back to default
VESPA_TIMEOUT = "3s"
VESPA_TIMEOUT = "10s"
BATCH_SIZE = 128 # Specific to Vespa
TENANT_ID = "tenant_id"

View File

@@ -301,7 +301,7 @@ def read_pdf_file(
def docx_to_text_and_images(
file: IO[Any],
file: IO[Any], file_name: str = ""
) -> tuple[str, Sequence[tuple[bytes, str]]]:
"""
Extract text from a docx. If embed_images=True, also extract inline images.
@@ -310,7 +310,11 @@ def docx_to_text_and_images(
paragraphs = []
embedded_images: list[tuple[bytes, str]] = []
doc = docx.Document(file)
try:
doc = docx.Document(file)
except BadZipFile as e:
logger.warning(f"Failed to extract text from {file_name or 'docx file'}: {e}")
return "", []
# Grab text from paragraphs
for paragraph in doc.paragraphs:
@@ -360,6 +364,13 @@ def xlsx_to_text(file: IO[Any], file_name: str = "") -> str:
else:
logger.warning(error_str)
return ""
except Exception as e:
if "File contains no valid workbook part" in str(e):
logger.error(
f"Failed to extract text from {file_name or 'xlsx file'}. This happens due to a bug in openpyxl. {e}"
)
return ""
raise e
text_content = []
for sheet in workbook.worksheets:
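
Both extractors now fail soft on corrupt uploads: a truncated .docx raises BadZipFile (the format is a zip archive) and a bad .xlsx trips an openpyxl workbook error. A hedged sketch of the docx guard in isolation:

    from typing import IO, Any
    from zipfile import BadZipFile

    import docx  # python-docx

    def safe_docx_text(file: IO[Any]) -> str:
        # Return empty text instead of failing the whole indexing batch.
        try:
            document = docx.Document(file)
        except BadZipFile:
            return ""
        return "\n".join(paragraph.text for paragraph in document.paragraphs)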

View File

@@ -2,6 +2,7 @@ import base64
from collections.abc import Callable
from io import BytesIO
from typing import cast
from uuid import UUID
from uuid import uuid4
import requests
@@ -12,11 +13,11 @@ from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatMessage
from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.file_processing.extract_file_text import IMAGE_MEDIA_TYPES
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import FileDescriptor
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
from onyx.utils.b64 import get_image_type
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -119,27 +120,37 @@ def load_user_file(file_id: int, db_session: Session) -> InMemoryChatFile:
if not user_file:
raise ValueError(f"User file with id {file_id} not found")
# Try to load plaintext version first
# Get the file record to determine the appropriate chat file type
file_store = get_default_file_store(db_session)
file_record = file_store.read_file_record(user_file.file_id)
# Determine appropriate chat file type based on the original file's MIME type
chat_file_type = mime_type_to_chat_file_type(file_record.file_type)
# Try to load plaintext version first
plaintext_file_name = user_file_id_to_plaintext_file_name(file_id)
# check for plain text normalized version first, then use original file otherwise
try:
file_io = file_store.read_file(plaintext_file_name, mode="b")
# For plaintext versions, use PLAIN_TEXT type (unless it's an image which doesn't have plaintext)
plaintext_chat_file_type = (
ChatFileType.PLAIN_TEXT
if chat_file_type != ChatFileType.IMAGE
else chat_file_type
)
chat_file = InMemoryChatFile(
file_id=str(user_file.file_id),
content=file_io.read(),
file_type=ChatFileType.USER_KNOWLEDGE,
file_type=plaintext_chat_file_type,
filename=user_file.name,
)
status = "plaintext"
return chat_file
except Exception:
except Exception as e:
logger.warning(f"Failed to load plaintext for user file {user_file.id}: {e}")
# Fall back to original file if plaintext not available
file_io = file_store.read_file(user_file.file_id, mode="b")
file_record = file_store.read_file_record(user_file.file_id)
if file_record.file_type in IMAGE_MEDIA_TYPES:
chat_file_type = ChatFileType.IMAGE
chat_file = InMemoryChatFile(
file_id=str(user_file.file_id),
@@ -235,6 +246,26 @@ def get_user_files(
return user_files
def get_user_files_as_user(
user_file_ids: list[int],
user_folder_ids: list[int],
user_id: UUID | None,
db_session: Session,
) -> list[UserFile]:
"""
Fetches all UserFile database records for a given user.
"""
user_files = get_user_files(user_file_ids, user_folder_ids, db_session)
for user_file in user_files:
# Note: if user_id is None, then every file's user_id should be None as well
# (since auth must be disabled in this case)
if user_file.user_id != user_id:
raise ValueError(
f"User {user_id} does not have access to file {user_file.id}"
)
return user_files
def save_file_from_url(url: str) -> str:
"""NOTE: using multiple sessions here, since this is often called
using multithreading. In practice, sharing a session has resulted in

View File

@@ -264,7 +264,7 @@ class DefaultMultiLLM(LLM):
):
self._timeout = timeout
if timeout is None:
if model_is_reasoning_model(model_name):
if model_is_reasoning_model(model_name, model_provider):
self._timeout = QA_TIMEOUT * 10 # Reasoning models are slow
else:
self._timeout = QA_TIMEOUT

View File

@@ -108,7 +108,7 @@ VERTEXAI_DEFAULT_MODEL = "gemini-2.0-flash"
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.0-flash-lite"
VERTEXAI_MODEL_NAMES = [
# 2.5 pro models
"gemini-2.5-pro-exp-03-25",
"gemini-2.5-pro-preview-05-06",
# 2.0 flash-lite models
VERTEXAI_DEFAULT_FAST_MODEL,
"gemini-2.0-flash-lite-001",

View File

@@ -663,12 +663,34 @@ def model_supports_image_input(model_name: str, model_provider: str) -> bool:
return False
def model_is_reasoning_model(model_name: str) -> bool:
_REASONING_MODEL_NAMES = [
"o1",
"o1-mini",
"o3-mini",
"deepseek-reasoner",
"deepseek-r1",
]
return model_name.lower() in _REASONING_MODEL_NAMES
def model_is_reasoning_model(model_name: str, model_provider: str) -> bool:
model_map = get_model_map()
try:
model_obj = find_model_obj(
model_map,
model_provider,
model_name,
)
if model_obj and "supports_reasoning" in model_obj:
return model_obj["supports_reasoning"]
# Fallback: try using litellm.supports_reasoning() for newer models
try:
logger.debug("Falling back to `litellm.supports_reasoning`")
full_model_name = (
f"{model_provider}/{model_name}"
if model_provider not in model_name
else model_name
)
return litellm.supports_reasoning(model=full_model_name)
except Exception:
logger.exception(
f"Failed to check if {model_provider}/{model_name} supports reasoning"
)
return False
except Exception:
logger.exception(
f"Failed to get model object for {model_provider}/{model_name}"
)
return False
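
Reasoning support is now looked up per provider/model pair, with litellm.supports_reasoning() as a fallback for models missing from the cached model map. A simplified sketch of just the fallback path:

    import litellm

    def reasoning_fallback(model_provider: str, model_name: str) -> bool:
        # Prefix the provider only when it is not already part of the model name.
        full_name = (
            model_name if model_provider in model_name else f"{model_provider}/{model_name}"
        )
        try:
            return bool(litellm.supports_reasoning(model=full_name))
        except Exception:
            return False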

View File

@@ -185,9 +185,7 @@ class EmbeddingModel:
) -> list[Embedding]:
text_batches = batch_list(texts, batch_size)
logger.debug(
f"Encoding {len(texts)} texts in {len(text_batches)} batches for local model"
)
logger.debug(f"Encoding {len(texts)} texts in {len(text_batches)} batches")
embeddings: list[Embedding] = []

View File

@@ -64,7 +64,7 @@ TENANT_HEARTBEAT_INTERVAL = (
15 # How often pods send heartbeats to indicate they are still processing a tenant
)
TENANT_HEARTBEAT_EXPIRATION = (
30 # How long before a tenant's heartbeat expires, allowing other pods to take over
60 # How long before a tenant's heartbeat expires, allowing other pods to take over
)
TENANT_ACQUISITION_INTERVAL = 60 # How often pods attempt to acquire unprocessed tenants and checks for new tokens

View File

@@ -137,7 +137,10 @@ def handle_generate_answer_button(
raise ValueError("Missing thread_ts in the payload")
thread_messages = read_slack_thread(
channel=channel_id, thread=thread_ts, client=client.web_client
tenant_id=client._tenant_id,
channel=channel_id,
thread=thread_ts,
client=client.web_client,
)
# remove all assistant messages till we get to the last user message
# we want the new answer to be generated off of the last "question" in

View File

@@ -419,6 +419,11 @@ def handle_regular_answer(
skip_ai_feedback=skip_ai_feedback,
)
# NOTE(rkuo): Slack has a maximum block list size of 50.
# we should modify build_slack_response_blocks to respect the max
# but enforcing the hard limit here is the last resort.
all_blocks = all_blocks[:50]
try:
respond_in_thread_or_channel(
client=client,
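
Slack rejects messages with more than 50 blocks, so the response is clamped as a last resort before posting. A tiny sketch of the guard:

    MAX_SLACK_BLOCKS = 50  # Slack's hard limit per message

    def clamp_blocks(blocks: list) -> list:
        # Ideally the block builder never exceeds the limit; this is the backstop.
        return blocks[:MAX_SLACK_BLOCKS]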

View File

@@ -1,4 +1,3 @@
import asyncio
import os
import signal
import sys
@@ -11,11 +10,12 @@ from types import FrameType
from typing import Any
from typing import cast
from typing import Dict
from typing import Set
import psycopg2.errors
from prometheus_client import Gauge
from prometheus_client import start_http_server
from redis.lock import Lock
from redis.lock import Lock as RedisLock
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from slack_sdk.http_retry import ConnectionErrorRetryHandler
@@ -86,7 +86,7 @@ from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import check_message_limit
from onyx.onyxbot.slack.utils import decompose_action_id
from onyx.onyxbot.slack.utils import get_channel_name_from_id
from onyx.onyxbot.slack.utils import get_onyx_bot_slack_bot_id
from onyx.onyxbot.slack.utils import get_onyx_bot_auth_ids
from onyx.onyxbot.slack.utils import read_slack_thread
from onyx.onyxbot.slack.utils import remove_onyx_bot_tag
from onyx.onyxbot.slack.utils import rephrase_slack_message
@@ -105,7 +105,6 @@ from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
# Prometheus metric for HPA
@@ -135,7 +134,7 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
class SlackbotHandler:
def __init__(self) -> None:
logger.info("Initializing SlackbotHandler")
self.tenant_ids: Set[str] = set()
self.tenant_ids: set[str] = set()
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
@@ -144,8 +143,11 @@ class SlackbotHandler:
self.redis_locks: Dict[str, Lock] = {}
self.running = True
self.pod_id = self.get_pod_id()
self.pod_id = os.environ.get("HOSTNAME", "unknown_pod")
self._shutdown_event = Event()
self._lock = threading.Lock()
logger.info(f"Pod ID: {self.pod_id}")
# Set up signal handlers for graceful shutdown
@@ -169,12 +171,8 @@ class SlackbotHandler:
self.acquire_thread.start()
self.heartbeat_thread.start()
logger.info("Background threads started")
def get_pod_id(self) -> str:
pod_id = os.environ.get("HOSTNAME", "unknown_pod")
logger.info(f"Retrieved pod ID: {pod_id}")
return pod_id
logger.info("Background threads started")
def acquire_tenants_loop(self) -> None:
while not self._shutdown_event.is_set():
@@ -194,12 +192,18 @@ class SlackbotHandler:
self._shutdown_event.wait(timeout=TENANT_ACQUISITION_INTERVAL)
def heartbeat_loop(self) -> None:
"""This heartbeats into redis.
NOTE(rkuo): this is not thread-safe with acquire_tenants_loop and will
occasionally exception. Fix it!
"""
while not self._shutdown_event.is_set():
try:
self.send_heartbeats()
logger.debug(
f"Sent heartbeats for {len(self.tenant_ids)} active tenants"
)
with self._lock:
tenant_ids = self.tenant_ids.copy()
SlackbotHandler.send_heartbeats(self.pod_id, tenant_ids)
logger.debug(f"Sent heartbeats for {len(tenant_ids)} active tenants")
except Exception as e:
logger.exception(f"Error in heartbeat loop: {e}")
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
@@ -224,7 +228,7 @@ class SlackbotHandler:
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.socket_clients[tenant_bot_pair].close()
del self.socket_clients[tenant_bot_pair]
del self.slack_bot_tokens[tenant_bot_pair]
return
@@ -252,9 +256,20 @@ class SlackbotHandler:
# Close any existing connection first
if tenant_bot_pair in self.socket_clients:
asyncio.run(self.socket_clients[tenant_bot_pair].close())
self.socket_clients[tenant_bot_pair].close()
self.start_socket_client(bot.id, tenant_id, slack_bot_tokens)
socket_client = self.start_socket_client(
bot.id, tenant_id, slack_bot_tokens
)
if socket_client:
# Ensure tenant is tracked as active
self.socket_clients[tenant_id, bot.id] = socket_client
logger.info(
f"Started SocketModeClient: {tenant_id=} {socket_client.bot_name=} {bot.id=}"
)
self.tenant_ids.add(tenant_id)
def acquire_tenants(self) -> None:
"""
@@ -264,6 +279,8 @@ class SlackbotHandler:
- If a tenant in self.tenant_ids no longer has Slack bots, remove it (and release the lock in this scope).
"""
token: Token[str | None]
# tenants that are disabled (e.g. their trial is over and haven't subscribed)
# for non-cloud, this will return an empty set
gated_tenants = fetch_ee_implementation_or_noop(
@@ -271,16 +288,14 @@ class SlackbotHandler:
"get_gated_tenants",
set(),
)()
all_tenants = [
all_active_tenants = [
tenant_id
for tenant_id in get_all_tenant_ids()
if tenant_id not in gated_tenants
]
token: Token[str | None]
# 1) Try to acquire locks for new tenants
for tenant_id in all_tenants:
for tenant_id in all_active_tenants:
if (
DISALLOWED_SLACK_BOT_TENANT_LIST is not None
and tenant_id in DISALLOWED_SLACK_BOT_TENANT_LIST
@@ -295,14 +310,18 @@ class SlackbotHandler:
# Respect max tenant limit per pod
if len(self.tenant_ids) >= MAX_TENANTS_PER_POD:
logger.info(
f"Max tenants per pod reached ({MAX_TENANTS_PER_POD}); not acquiring more."
f"Max tenants per pod reached, not acquiring more: {MAX_TENANTS_PER_POD=}"
)
break
redis_client = get_redis_client(tenant_id=tenant_id)
# Acquire a Redis lock (non-blocking)
rlock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK, timeout=TENANT_LOCK_EXPIRATION
# thread_local=False because the shutdown event is handled
# on an arbitrary thread
rlock: RedisLock = redis_client.lock(
OnyxRedisLocks.SLACK_BOT_LOCK,
timeout=TENANT_LOCK_EXPIRATION,
thread_local=False,
)
lock_acquired = rlock.acquire(blocking=False)
@@ -333,6 +352,10 @@ class SlackbotHandler:
except KvKeyNotFoundError:
# No Slackbot tokens, pass
pass
except psycopg2.errors.UndefinedTable:
logger.error(
"Undefined table error in fetch_slack_bots. Tenant schema may need fixing."
)
except Exception as e:
logger.exception(
f"Error fetching Slack bots for tenant {tenant_id}: {e}"
@@ -409,10 +432,11 @@ class SlackbotHandler:
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
(Lock release now happens in `acquire_tenants()`, not here.)
"""
socket_client_list = list(self.socket_clients.items())
# Close all socket clients for this tenant
for (t_id, slack_bot_id), client in list(self.socket_clients.items()):
for (t_id, slack_bot_id), client in socket_client_list:
if t_id == tenant_id:
asyncio.run(client.close())
client.close()
del self.socket_clients[(t_id, slack_bot_id)]
del self.slack_bot_tokens[(t_id, slack_bot_id)]
logger.info(
@@ -423,19 +447,22 @@ class SlackbotHandler:
if tenant_id in self.tenant_ids:
self.tenant_ids.remove(tenant_id)
def send_heartbeats(self) -> None:
@staticmethod
def send_heartbeats(pod_id: str, tenant_ids: set[str]) -> None:
current_time = int(time.time())
logger.debug(f"Sending heartbeats for {len(self.tenant_ids)} active tenants")
for tenant_id in self.tenant_ids:
logger.debug(f"Sending heartbeats for {len(tenant_ids)} active tenants")
for tenant_id in tenant_ids:
redis_client = get_redis_client(tenant_id=tenant_id)
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{self.pod_id}"
heartbeat_key = f"{OnyxRedisLocks.SLACK_BOT_HEARTBEAT_PREFIX}:{pod_id}"
redis_client.set(
heartbeat_key, current_time, ex=TENANT_HEARTBEAT_EXPIRATION
)
@staticmethod
def start_socket_client(
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
) -> None:
slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
) -> TenantSocketModeClient | None:
"""Returns the socket client if this succeeds"""
socket_client: TenantSocketModeClient = _get_socket_client(
slack_bot_tokens, tenant_id, slack_bot_id
)
@@ -450,18 +477,21 @@ class SlackbotHandler:
bot_name = (
user_info["user"]["real_name"] or user_info["user"]["name"]
)
logger.info(
f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
)
socket_client.bot_name = bot_name
# logger.info(
# f"Started socket client for Slackbot with name '{bot_name}' (tenant: {tenant_id}, app: {slack_bot_id})"
# )
except SlackApiError as e:
# Only error out if we get a not_authed error
if "not_authed" in str(e):
self.tenant_ids.add(tenant_id)
# for some reason we want to add the tenant to the list when this happens?
logger.error(
f"Authentication error: Invalid or expired credentials for tenant: {tenant_id}, app: {slack_bot_id}. "
"Error: {e}"
f"Authentication error - Invalid or expired credentials: "
f"{tenant_id=} {slack_bot_id=}. "
f"Error: {e}"
)
return
return None
# Log other Slack API errors but continue
logger.error(
f"Slack API error fetching bot info: {e} for tenant: {tenant_id}, app: {slack_bot_id}"
@@ -477,23 +507,30 @@ class SlackbotHandler:
socket_client.socket_mode_request_listeners.append(process_slack_event) # type: ignore
# Establish a WebSocket connection to the Socket Mode servers
logger.info(
f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
)
# logger.debug(
# f"Connecting socket client for tenant: {tenant_id}, app: {slack_bot_id}"
# )
socket_client.connect()
self.socket_clients[tenant_id, slack_bot_id] = socket_client
# Ensure tenant is tracked as active
self.tenant_ids.add(tenant_id)
logger.info(
f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
)
# logger.info(
# f"Started SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
# )
def stop_socket_clients(self) -> None:
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
for (tenant_id, slack_bot_id), client in list(self.socket_clients.items()):
asyncio.run(client.close())
return socket_client
@staticmethod
def stop_socket_clients(
pod_id: str, socket_clients: Dict[tuple[str, int], TenantSocketModeClient]
) -> None:
socket_client_list = list(socket_clients.items())
length = len(socket_client_list)
x = 0
for (tenant_id, slack_bot_id), client in socket_client_list:
x += 1
client.close()
logger.info(
f"Stopped SocketModeClient for tenant: {tenant_id}, app: {slack_bot_id}"
f"Stopped SocketModeClient {x}/{length}: "
f"{pod_id=} {tenant_id=} {slack_bot_id=}"
)
def shutdown(self, signum: int | None, frame: FrameType | None) -> None:
@@ -502,11 +539,15 @@ class SlackbotHandler:
logger.info("Shutting down gracefully")
self.running = False
self._shutdown_event.set()
self._shutdown_event.set() # set the shutdown event
# wait for threads to detect the event and exit
self.acquire_thread.join(timeout=60.0)
self.heartbeat_thread.join(timeout=60.0)
# Stop all socket clients
logger.info(f"Stopping {len(self.socket_clients)} socket clients")
self.stop_socket_clients()
SlackbotHandler.stop_socket_clients(self.pod_id, self.socket_clients)
# Release locks for all tenants we currently hold
logger.info(f"Releasing locks for {len(self.tenant_ids)} tenants")
@@ -533,7 +574,13 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
"""True to keep going, False to ignore this Slack request"""
# skip cases where the bot is disabled in the web UI
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
tenant_id = get_current_tenant_id()
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
tenant_id, client.web_client
)
logger.info(f"prefilter_requests: {bot_token_user_id=} {bot_token_bot_id=}")
with get_session_with_current_tenant() as db_session:
slack_bot = fetch_slack_bot(
db_session=db_session, slack_bot_id=client.slack_bot_id
@@ -580,7 +627,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
if (
msg in _SLACK_GREETINGS_TO_IGNORE
or remove_onyx_bot_tag(msg, client=client.web_client)
or remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
in _SLACK_GREETINGS_TO_IGNORE
):
channel_specific_logger.error(
@@ -599,15 +646,38 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
return False
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
bot_token_user_id, bot_token_bot_id = get_onyx_bot_auth_ids(
tenant_id, client.web_client
)
if event_type == "message":
is_onyx_bot_msg = False
is_tagged = False
event_user = event.get("user", "")
event_bot_id = event.get("bot_id", "")
# temporary debugging
if tenant_id == "tenant_i-04224818da13bf695":
logger.warning(
f"{tenant_id=} "
f"{bot_token_user_id=} "
f"{bot_token_bot_id=} "
f"{event=}"
)
is_dm = event.get("channel_type") == "im"
is_tagged = bot_tag_id and f"<@{bot_tag_id}>" in msg
is_onyx_bot_msg = bot_tag_id and bot_tag_id in event.get("user", "")
if bot_token_user_id and f"<@{bot_token_user_id}>" in msg:
is_tagged = True
if bot_token_user_id and bot_token_user_id in event_user:
is_onyx_bot_msg = True
if bot_token_bot_id and bot_token_bot_id in event_bot_id:
is_onyx_bot_msg = True
# OnyxBot should never respond to itself
if is_onyx_bot_msg:
logger.info("Ignoring message from OnyxBot")
logger.info("Ignoring message from OnyxBot (self-message)")
return False
# DMs with the bot don't pick up the @OnyxBot so we have to keep the
@@ -632,7 +702,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
)
# If OnyxBot is not specifically tagged and the channel is not set to respond to bots, ignore the message
if (not bot_tag_id or bot_tag_id not in msg) and (
if (not bot_token_user_id or bot_token_user_id not in msg) and (
not slack_channel_config
or not slack_channel_config.channel_config.get("respond_to_bots")
):
@@ -692,7 +762,7 @@ def prefilter_requests(req: SocketModeRequest, client: TenantSocketModeClient) -
if not check_message_limit():
return False
logger.debug(f"Handling Slack request with Payload: '{req.payload}'")
logger.debug(f"Handling Slack request: {client.bot_name=} '{req.payload=}'")
return True
@@ -731,15 +801,16 @@ def process_feedback(req: SocketModeRequest, client: TenantSocketModeClient) ->
def build_request_details(
req: SocketModeRequest, client: TenantSocketModeClient
) -> SlackMessageInfo:
tagged: bool = False
tenant_id = get_current_tenant_id()
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
msg = cast(str, event["text"])
channel = cast(str, event["channel"])
# Check for both app_mention events and messages containing bot tag
bot_tag_id = get_onyx_bot_slack_bot_id(client.web_client)
tagged = (event.get("type") == "app_mention") or (
event.get("type") == "message" and bot_tag_id and f"<@{bot_tag_id}>" in msg
)
bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, client.web_client)
message_ts = event.get("ts")
thread_ts = event.get("thread_ts")
sender_id = event.get("user") or None
@@ -748,7 +819,7 @@ def build_request_details(
)
email = expert_info.email if expert_info else None
msg = remove_onyx_bot_tag(msg, client=client.web_client)
msg = remove_onyx_bot_tag(tenant_id, msg, client=client.web_client)
if DANSWER_BOT_REPHRASE_MESSAGE:
logger.info(f"Rephrasing Slack message. Original message: {msg}")
@@ -760,12 +831,24 @@ def build_request_details(
else:
logger.info(f"Received Slack message: {msg}")
event_type = event.get("type")
if event_type == "app_mention":
tagged = True
if event_type == "message":
if bot_token_user_id:
if f"<@{bot_token_user_id}>" in msg:
tagged = True
if tagged:
logger.debug("User tagged OnyxBot")
if thread_ts != message_ts and thread_ts is not None:
thread_messages = read_slack_thread(
channel=channel, thread=thread_ts, client=client.web_client
tenant_id=tenant_id,
channel=channel,
thread=thread_ts,
client=client.web_client,
)
else:
sender_display_name = None
@@ -842,12 +925,24 @@ def process_message(
notify_no_answer: bool = NOTIFY_SLACKBOT_NO_ANSWER,
) -> None:
tenant_id = get_current_tenant_id()
logger.debug(
f"Received Slack request of type: '{req.type}' for tenant, {tenant_id}"
)
if req.type == "events_api":
event = cast(dict[str, Any], req.payload["event"])
event_type = event.get("type")
msg = cast(str, event.get("text", ""))
logger.info(
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=} "
f"{event_type=} {msg=}"
)
else:
logger.info(
f"process_message start: {tenant_id=} {req.type=} {req.envelope_id=}"
)
# Throw out requests that can't or shouldn't be handled
if not prefilter_requests(req, client):
logger.info(
f"process_message prefiltered: {tenant_id=} {req.type=} {req.envelope_id=}"
)
return
details = build_request_details(req, client)
@@ -890,6 +985,10 @@ def process_message(
if notify_no_answer:
apologize_for_fail(details, client)
logger.info(
f"process_message finished: success={not failed} {tenant_id=} {req.type=} {req.envelope_id=}"
)
def acknowledge_message(req: SocketModeRequest, client: TenantSocketModeClient) -> None:
response = SocketModeResponse(envelope_id=req.envelope_id)
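
Two of the threading fixes above are worth calling out: Redis locks are created with thread_local=False so a lock acquired in the tenant-acquisition thread can be released during shutdown on another thread, and the heartbeat loop now copies self.tenant_ids under a lock before iterating. A sketch of the lock pattern, with illustrative connection details:

    import redis
    from redis.lock import Lock as RedisLock

    client = redis.Redis(host="localhost", port=6379)
    rlock: RedisLock = client.lock(
        "slack_bot_lock", timeout=300, thread_local=False
    )
    if rlock.acquire(blocking=False):
        try:
            ...  # serve the tenant
        finally:
            rlock.release()  # safe even when called from a different thread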

View File

@@ -2,6 +2,7 @@ import logging
import random
import re
import string
import threading
import time
import uuid
from collections.abc import Generator
@@ -48,17 +49,38 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
slack_token_user_ids: dict[str, str | None] = {}
slack_token_bot_ids: dict[str, str | None] = {}
slack_token_lock = threading.Lock()
_DANSWER_BOT_SLACK_BOT_ID: str | None = None
_DANSWER_BOT_MESSAGE_COUNT: int = 0
_DANSWER_BOT_COUNT_START_TIME: float = time.time()
def get_onyx_bot_slack_bot_id(web_client: WebClient) -> Any:
global _DANSWER_BOT_SLACK_BOT_ID
if _DANSWER_BOT_SLACK_BOT_ID is None:
_DANSWER_BOT_SLACK_BOT_ID = web_client.auth_test().get("user_id")
return _DANSWER_BOT_SLACK_BOT_ID
def get_onyx_bot_auth_ids(
tenant_id: str, web_client: WebClient
) -> tuple[str | None, str | None]:
"""Returns a tuple of user_id and bot_id."""
user_id: str | None
bot_id: str | None
global slack_token_user_ids
global slack_token_bot_ids
with slack_token_lock:
user_id = slack_token_user_ids.get(tenant_id)
bot_id = slack_token_bot_ids.get(tenant_id)
if user_id is None or bot_id is None:
response = web_client.auth_test()
user_id = response.get("user_id")
bot_id = response.get("bot_id")
with slack_token_lock:
slack_token_user_ids[tenant_id] = user_id
slack_token_bot_ids[tenant_id] = bot_id
return user_id, bot_id
def check_message_limit() -> bool:
@@ -117,35 +139,38 @@ def update_emote_react(
remove: bool,
client: WebClient,
) -> None:
try:
if not message_ts:
logger.error(
f"Tried to remove a react in {channel} but no message specified"
)
return
if not message_ts:
action = "remove" if remove else "add"
logger.error(f"update_emote_react - no message specified: {channel=} {action=}")
return
if remove:
if remove:
try:
client.reactions_remove(
name=emoji,
channel=channel,
timestamp=message_ts,
)
else:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
if remove:
except SlackApiError as e:
logger.error(f"Failed to remove Reaction due to: {e}")
else:
logger.error(f"Was not able to react to user message due to: {e}")
return
try:
client.reactions_add(
name=emoji,
channel=channel,
timestamp=message_ts,
)
except SlackApiError as e:
logger.error(f"Was not able to react to user message due to: {e}")
return
def remove_onyx_bot_tag(message_str: str, client: WebClient) -> str:
bot_tag_id = get_onyx_bot_slack_bot_id(web_client=client)
return re.sub(rf"<@{bot_tag_id}>\s*", "", message_str)
def remove_onyx_bot_tag(tenant_id: str, message_str: str, client: WebClient) -> str:
bot_token_user_id, _ = get_onyx_bot_auth_ids(tenant_id, web_client=client)
return re.sub(rf"<@{bot_token_user_id}>\s*", "", message_str)
def _check_for_url_in_block(block: Block) -> bool:
@@ -215,7 +240,8 @@ def respond_in_thread_or_channel(
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
blocks_str = str(blocks)[:1024] # truncate block logging
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
@@ -252,7 +278,8 @@ def respond_in_thread_or_channel(
unfurl_media=unfurl,
)
except Exception as e:
logger.warning(f"Failed to post message: {e} \n blocks: {blocks}")
blocks_str = str(blocks)[:1024] # truncate block logging
logger.warning(f"Failed to post message: {e} \n blocks: {blocks_str}")
logger.warning("Trying again without blocks that have urls")
if not blocks:
@@ -515,7 +542,7 @@ def fetch_user_semantic_id_from_id(
def read_slack_thread(
channel: str, thread: str, client: WebClient
tenant_id: str, channel: str, thread: str, client: WebClient
) -> list[ThreadMessage]:
thread_messages: list[ThreadMessage] = []
response = client.conversations_replies(channel=channel, ts=thread)
@@ -529,9 +556,22 @@ def read_slack_thread(
)
message_type = MessageType.USER
else:
self_slack_bot_id = get_onyx_bot_slack_bot_id(client)
blocks: Any
if reply.get("user") == self_slack_bot_id:
is_onyx_bot_response = False
reply_user = reply.get("user")
reply_bot_id = reply.get("bot_id")
self_slack_bot_user_id, self_slack_bot_bot_id = get_onyx_bot_auth_ids(
tenant_id, client
)
if reply_user is not None and reply_user == self_slack_bot_user_id:
is_onyx_bot_response = True
if reply_bot_id is not None and reply_bot_id == self_slack_bot_bot_id:
is_onyx_bot_response = True
if is_onyx_bot_response:
# OnyxBot response
message_type = MessageType.ASSISTANT
user_sem_id = "Assistant"
@@ -573,7 +613,7 @@ def read_slack_thread(
logger.warning("Skipping Slack thread message, no text found")
continue
message = remove_onyx_bot_tag(message, client=client)
message = remove_onyx_bot_tag(tenant_id, message, client=client)
thread_messages.append(
ThreadMessage(message=message, sender=user_sem_id, role=message_type)
)
@@ -676,6 +716,7 @@ class TenantSocketModeClient(SocketModeClient):
super().__init__(*args, **kwargs)
self._tenant_id = tenant_id
self.slack_bot_id = slack_bot_id
self.bot_name: str = "Unnamed"
@contextmanager
def _set_tenant_context(self) -> Generator[None, None, None]:
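
get_onyx_bot_auth_ids replaces the single cached bot id with a per-tenant cache of (user_id, bot_id) from auth.test, guarded by a lock so concurrent tenants do not race on the dictionaries. A stripped-down sketch of the caching shape (not the real helper):

    import threading
    from collections.abc import Callable

    _auth_ids: dict[str, tuple[str | None, str | None]] = {}
    _auth_ids_lock = threading.Lock()

    def cached_auth_ids(
        tenant_id: str, fetch: Callable[[], tuple[str | None, str | None]]
    ) -> tuple[str | None, str | None]:
        with _auth_ids_lock:
            if tenant_id in _auth_ids:
                return _auth_ids[tenant_id]
        ids = fetch()  # e.g. one Slack auth.test call, done outside the lock
        with _auth_ids_lock:
            _auth_ids[tenant_id] = ids
        return ids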

View File

@@ -51,7 +51,10 @@ def llm_eval_section(
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
logger.debug(model_output)
# NOTE(rkuo): all this does is print "Yes useful" or "Not useful"
# disabling because it's spammy, restore and give more context if this is needed
# logger.debug(model_output)
return _extract_usefulness(model_output)
@@ -64,6 +67,8 @@ def llm_batch_eval_sections(
metadata_list: list[dict[str, str | list[str]]],
use_threads: bool = True,
) -> list[bool]:
answer: list[bool]
if DISABLE_LLM_DOC_RELEVANCE:
raise RuntimeError(
"LLM Doc Relevance is globally disabled, "
@@ -86,12 +91,13 @@ def llm_batch_eval_sections(
)
# In case of failure/timeout, don't throw out the section
return [True if item is None else item for item in parallel_results]
answer = [True if item is None else item for item in parallel_results]
return answer
else:
return [
llm_eval_section(query, section_content, llm, title, metadata)
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
answer = [
llm_eval_section(query, section_content, llm, title, metadata)
for section_content, title, metadata in zip(
section_contents, titles, metadata_list
)
]
return answer

View File

@@ -403,7 +403,7 @@ def get_docs_sync_status(
def get_cc_pair_indexing_errors(
cc_pair_id: int,
include_resolved: bool = Query(False),
page: int = Query(0, ge=0),
page_num: int = Query(0, ge=0),
page_size: int = Query(10, ge=1, le=100),
_: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
@@ -413,7 +413,7 @@ def get_cc_pair_indexing_errors(
Args:
cc_pair_id: ID of the connector-credential pair to get errors for
include_resolved: Whether to include resolved errors in the results
page: Page number for pagination, starting at 0
page_num: Page number for pagination, starting at 0
page_size: Number of errors to return per page
_: Current user, must be curator or admin
db_session: Database session
@@ -431,7 +431,7 @@ def get_cc_pair_indexing_errors(
db_session=db_session,
cc_pair_id=cc_pair_id,
unresolved_only=not include_resolved,
page=page,
page=page_num,
page_size=page_size,
)
return PaginatedReturn(
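
The query parameter is renamed to page_num while the internal pagination helper keeps its page keyword. A sketch of the FastAPI shape, not the actual route:

    from fastapi import FastAPI, Query

    app = FastAPI()

    @app.get("/indexing-errors")
    def list_errors(
        page_num: int = Query(0, ge=0),            # zero-indexed page
        page_size: int = Query(10, ge=1, le=100),  # capped page size
    ) -> dict[str, int]:
        return {"offset": page_num * page_size, "limit": page_size}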

View File

@@ -30,6 +30,7 @@ from onyx.chat.prompt_builder.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.chat_configs import HARD_DELETE_CHATS
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
@@ -203,6 +204,7 @@ def update_chat_session_model(
def get_chat_session(
session_id: UUID,
is_shared: bool = False,
include_deleted: bool = False,
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
@@ -213,6 +215,7 @@ def get_chat_session(
user_id=user_id,
db_session=db_session,
is_shared=is_shared,
include_deleted=include_deleted,
)
except ValueError:
raise ValueError("Chat session does not exist or has been deleted")
@@ -253,6 +256,7 @@ def get_chat_session(
time_created=chat_session.time_created,
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
deleted=chat_session.deleted,
)
@@ -357,12 +361,19 @@ def delete_all_chat_sessions(
@router.delete("/delete-chat-session/{session_id}")
def delete_chat_session_by_id(
session_id: UUID,
hard_delete: bool | None = None,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user is not None else None
try:
delete_chat_session(user_id, session_id, db_session)
# Use the provided hard_delete parameter if specified, otherwise use the default config
actual_hard_delete = (
hard_delete if hard_delete is not None else HARD_DELETE_CHATS
)
delete_chat_session(
user_id, session_id, db_session, hard_delete=actual_hard_delete
)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
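
Clients can now pick the deletion mode per request; when hard_delete is omitted the server-wide HARD_DELETE_CHATS default applies. An illustrative client call (base URL, session id, and auth header are placeholders):

    import requests

    resp = requests.delete(
        "http://localhost:3000/chat/delete-chat-session/"
        "123e4567-e89b-12d3-a456-426614174000",
        params={"hard_delete": "true"},  # omit to fall back to HARD_DELETE_CHATS
        headers={"Authorization": "Bearer <api-key>"},
    )
    resp.raise_for_status()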

View File

@@ -137,8 +137,6 @@ class CreateChatMessageRequest(ChunkContext):
# https://platform.openai.com/docs/guides/structured-outputs/introduction
structured_response_format: dict | None = None
force_user_file_search: bool = False
# If true, ignores most of the search options and uses pro search instead.
# TODO: decide how many of the above options we want to pass through to pro search
use_agentic_search: bool = False
@@ -274,6 +272,7 @@ class ChatSessionDetailResponse(BaseModel):
shared_status: ChatSessionSharedStatus
current_alternate_model: str | None
current_temperature_override: float | None
deleted: bool = False
# This one is not used anymore

View File

@@ -75,8 +75,6 @@ class SearchToolOverrideKwargs(BaseModel):
precomputed_keywords: list[str] | None = None
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
# Flag for fast path when search is only needed for ordering
ordering_only: bool | None = None
document_sources: list[DocumentSource] | None = None
time_cutoff: datetime | None = None
expanded_queries: QueryExpansions | None = None

View File

@@ -16,6 +16,7 @@ from onyx.configs.app_configs import AZURE_DALLE_DEPLOYMENT_NAME
from onyx.configs.chat_configs import BING_API_KEY
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import OptionalSearchSetting
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
@@ -141,12 +142,11 @@ def construct_tools(
user: User | None,
llm: LLM,
fast_llm: LLM,
use_file_search: bool,
run_search_setting: OptionalSearchSetting,
search_tool_config: SearchToolConfig | None = None,
internet_search_tool_config: InternetSearchToolConfig | None = None,
image_generation_tool_config: ImageGenerationToolConfig | None = None,
custom_tool_config: CustomToolConfig | None = None,
user_knowledge_present: bool = False,
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs"""
tool_dict: dict[int, list[Tool]] = {}
@@ -163,7 +163,10 @@ def construct_tools(
)
# Handle Search Tool
if tool_cls.__name__ == SearchTool.__name__ and not user_knowledge_present:
if (
tool_cls.__name__ == SearchTool.__name__
and run_search_setting != OptionalSearchSetting.NEVER
):
if not search_tool_config:
search_tool_config = SearchToolConfig()
@@ -256,33 +259,6 @@ def construct_tools(
for tool_list in tool_dict.values():
tools.extend(tool_list)
if use_file_search:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
db_session=db_session,
user=user,
persona=persona,
retrieval_options=search_tool_config.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=search_tool_config.document_pruning_config,
answer_style_config=search_tool_config.answer_style_config,
selected_sections=search_tool_config.selected_sections,
chunks_above=search_tool_config.chunks_above,
chunks_below=search_tool_config.chunks_below,
full_doc=search_tool_config.full_doc,
evaluation_type=(
LLMEvaluationType.BASIC
if persona.llm_relevance_filter
else LLMEvaluationType.SKIP
),
rerank_settings=search_tool_config.rerank_settings,
bypass_acl=search_tool_config.bypass_acl,
)
tool_dict[1] = [search_tool]
# factor in tool definition size when pruning
if search_tool_config:
search_tool_config.document_pruning_config.tool_num_tokens = (

View File

@@ -1,6 +1,4 @@
import copy
import json
import time
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
@@ -25,13 +23,13 @@ from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
from onyx.context.search.enums import SearchType
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RetrievalDetails
from onyx.context.search.models import SearchRequest
from onyx.context.search.models import UserFileFilters
from onyx.context.search.pipeline import SearchPipeline
from onyx.context.search.pipeline import section_relevance_list_impl
from onyx.db.models import Persona
@@ -295,7 +293,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
skip_query_analysis = False
user_file_ids = None
user_folder_ids = None
ordering_only = False
document_sources = None
time_cutoff = None
expanded_queries = None
@@ -308,46 +305,19 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
user_file_ids = override_kwargs.user_file_ids
user_folder_ids = override_kwargs.user_folder_ids
ordering_only = use_alt_not_None(override_kwargs.ordering_only, False)
document_sources = override_kwargs.document_sources
time_cutoff = override_kwargs.time_cutoff
expanded_queries = override_kwargs.expanded_queries
# Fast path for ordering-only search
if ordering_only:
yield from self._run_ordering_only_search(
query, user_file_ids, user_folder_ids
)
return
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
return
# Create a copy of the retrieval options with user_file_ids if provided
retrieval_options = copy.deepcopy(self.retrieval_options)
if (user_file_ids or user_folder_ids) and retrieval_options:
# Create a copy to avoid modifying the original
filters = (
retrieval_options.filters.model_copy()
if retrieval_options.filters
else BaseFilters()
)
filters.user_file_ids = user_file_ids
retrieval_options = retrieval_options.model_copy(
update={"filters": filters}
)
elif user_file_ids or user_folder_ids:
# Create new retrieval options with user_file_ids
filters = BaseFilters(
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
)
retrieval_options = RetrievalDetails(filters=filters)
retrieval_options = self.retrieval_options or RetrievalDetails()
if document_sources or time_cutoff:
# Get retrieval_options and filters, or create if they don't exist
retrieval_options = retrieval_options or RetrievalDetails()
retrieval_options.filters = retrieval_options.filters or BaseFilters()
# if empty, just start with an empty filters object
if not retrieval_options.filters:
retrieval_options.filters = BaseFilters()
# Handle document sources
if document_sources:
@@ -370,6 +340,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
human_selected_filters=(
retrieval_options.filters if retrieval_options else None
),
user_file_filters=UserFileFilters(
user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
),
persona=self.persona,
offset=(retrieval_options.offset if retrieval_options else None),
limit=retrieval_options.limit if retrieval_options else None,
@@ -451,105 +424,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
prompt_config=self.prompt_config,
)
def _run_ordering_only_search(
self,
query: str,
user_file_ids: list[int] | None,
user_folder_ids: list[int] | None,
) -> Generator[ToolResponse, None, None]:
"""Optimized search that only retrieves document order with minimal processing."""
start_time = time.time()
logger.info("Fast path: Starting optimized ordering-only search")
# Create temporary search pipeline for optimized retrieval
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
evaluation_type=LLMEvaluationType.SKIP, # Force skip evaluation
persona=self.persona,
# Minimal configuration needed
chunks_above=0,
chunks_below=0,
),
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
skip_query_analysis=True, # Skip unnecessary analysis
db_session=self.db_session,
bypass_acl=self.bypass_acl,
prompt_config=self.prompt_config,
contextual_pruning_config=self.contextual_pruning_config,
)
# Log what we're doing
logger.info(
f"Fast path: Using {len(user_file_ids or [])} files and {len(user_folder_ids or [])} folders"
)
# Get chunks using the optimized method in SearchPipeline
retrieval_start = time.time()
retrieved_chunks = search_pipeline.get_ordering_only_chunks(
query=query, user_file_ids=user_file_ids, user_folder_ids=user_folder_ids
)
retrieval_time = time.time() - retrieval_start
logger.info(
f"Fast path: Retrieved {len(retrieved_chunks)} chunks in {retrieval_time:.2f}s"
)
# Convert chunks to minimal sections (we don't need full content)
minimal_sections = []
for chunk in retrieved_chunks:
# Create a minimal section with just center_chunk
minimal_section = InferenceSection(
center_chunk=chunk,
chunks=[chunk],
combined_content=chunk.content, # Use the chunk content as combined content
)
minimal_sections.append(minimal_section)
# Log document IDs found for debugging
doc_ids = [chunk.document_id for chunk in retrieved_chunks]
logger.info(
f"Fast path: Document IDs in order: {doc_ids[:5]}{'...' if len(doc_ids) > 5 else ''}"
)
# Yield just the required responses for document ordering
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=minimal_sections,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=SearchType.SEMANTIC,
final_filters=IndexFilters(
user_file_ids=user_file_ids or [],
user_folder_ids=user_folder_ids or [],
access_control_list=None,
),
recency_bias_multiplier=1.0,
),
)
# For fast path, don't trigger any LLM evaluation for relevance
logger.info(
"Fast path: Skipping section relevance evaluation to optimize performance"
)
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=None,
)
# We need to yield this for the caller to extract document order
minimal_docs = [
llm_doc_from_inference_section(section) for section in minimal_sections
]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=minimal_docs)
total_time = time.time() - start_time
logger.info(f"Fast path: Completed ordering-only search in {total_time:.2f}s")
# Allows yielding the same responses as a SearchTool without being a SearchTool.
# SearchTool passed in to allow for access to SearchTool properties.
@@ -568,10 +442,6 @@ def yield_search_responses(
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
) -> Generator[ToolResponse, None, None]:
# Get the search query to check if we're in ordering-only mode
# We can infer this from the reranked_sections not containing any relevance scoring
is_ordering_only = search_tool.evaluation_type == LLMEvaluationType.SKIP
yield ToolResponse(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
@@ -584,48 +454,26 @@ def yield_search_responses(
),
)
section_relevance: list[SectionRelevancePiece] | None = None
# Skip section relevance in ordering-only mode
if is_ordering_only:
logger.info(
"Fast path: Skipping section relevance evaluation in yield_search_responses"
)
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=None,
)
else:
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=section_relevance,
)
section_relevance = get_section_relevance()
yield ToolResponse(
id=SECTION_RELEVANCE_LIST_ID,
response=section_relevance,
)
final_context_sections = get_final_context_sections()
# Skip pruning sections in ordering-only mode
if is_ordering_only:
logger.info("Fast path: Skipping section pruning in ordering-only mode")
llm_docs = [
llm_doc_from_inference_section(section)
for section in final_context_sections
]
else:
# Use the section_relevance we already computed above
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
section_relevance, final_context_sections
),
prompt_config=search_tool.prompt_config,
llm_config=search_tool.llm.config,
question=query,
contextual_pruning_config=search_tool.contextual_pruning_config,
)
llm_docs = [
llm_doc_from_inference_section(section) for section in pruned_sections
]
# Use the section_relevance we already computed above
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
section_relevance, final_context_sections
),
prompt_config=search_tool.prompt_config,
llm_config=search_tool.llm.config,
question=query,
contextual_pruning_config=search_tool.contextual_pruning_config,
)
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)

View File

@@ -1,7 +1,5 @@
from typing import cast
from langchain_core.messages import HumanMessage
from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import PromptConfig
@@ -10,7 +8,6 @@ from onyx.chat.prompt_builder.citations_prompt import (
build_citations_system_message,
)
from onyx.chat.prompt_builder.citations_prompt import build_citations_user_message
from onyx.llm.utils import build_content_with_imgs
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import ToolResponse
@@ -45,12 +42,8 @@ def build_next_prompt_for_search_like_tool(
build_citations_user_message(
# make sure to use the original user query here in order to avoid duplication
# of the task prompt
message=HumanMessage(
content=build_content_with_imgs(
prompt_builder.raw_user_query,
prompt_builder.raw_user_uploaded_files,
)
),
user_query=prompt_builder.raw_user_query,
files=prompt_builder.raw_user_uploaded_files,
prompt_config=prompt_config,
context_docs=final_context_documents,
all_doc_useful=(

View File

@@ -8,6 +8,7 @@ from onyx.db.connector import check_connectors_exist
from onyx.db.document import check_docs_exist
from onyx.db.models import LLMProvider
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from onyx.llm.utils import find_model_obj
from onyx.llm.utils import get_model_map
from onyx.natural_language_processing.utils import BaseTokenizer
@@ -35,6 +36,10 @@ def explicit_tool_calling_supported(model_provider: str, model_name: str) -> boo
model_supports
and model_provider != ANTHROPIC_PROVIDER_NAME
and model_name not in litellm.anthropic_models
and (
model_provider != BEDROCK_PROVIDER_NAME
or not any(name in model_name for name in litellm.anthropic_models)
)
)
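
The extra clause keeps explicit tool calling disabled for Anthropic models served through Bedrock, whose model ids embed the Anthropic name rather than matching it exactly. A sketch of just that check (provider name assumed to be the lowercase constant):

    import litellm

    def is_bedrock_anthropic(model_provider: str, model_name: str) -> bool:
        # Bedrock ids look like "anthropic.claude-3-5-sonnet-...", so substring
        # matching against litellm's Anthropic model list catches them.
        return model_provider == "bedrock" and any(
            name in model_name for name in litellm.anthropic_models
        )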

View File

@@ -8,3 +8,8 @@ filterwarnings =
ignore::DeprecationWarning
ignore::cryptography.utils.CryptographyDeprecationWarning
ignore::PendingDeprecationWarning:ddtrace.internal.module
# .test.env is gitignored.
# After installing pytest-dotenv,
# you can use it to test credentials locally.
env_files =
.test.env
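
With pytest-dotenv installed, pytest loads the gitignored .test.env into the environment before collection, so tests can read local credentials from os.environ. A minimal sketch with a hypothetical key name:

    import os

    def connector_api_key() -> str | None:
        # CONNECTOR_TEST_API_KEY is a hypothetical key supplied via .test.env.
        return os.environ.get("CONNECTOR_TEST_API_KEY")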

View File

@@ -1,7 +1,7 @@
aioboto3==14.0.0
aiohttp==3.11.16
alembic==1.10.4
asyncpg==0.27.0
asyncpg==0.30.0
atlassian-python-api==3.41.16
beautifulsoup4==4.12.3
boto3==1.36.23

View File

@@ -12,6 +12,7 @@ pandas==2.2.3
posthog==3.7.4
pre-commit==3.2.2
pytest-asyncio==0.22.0
pytest-dotenv==0.5.2
pytest-xdist==3.6.1
pytest==8.3.5
reorder-python-imports-black==3.14.0

View File

@@ -21,6 +21,7 @@ if True: # noqa: E402
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.models import Document
from onyx.db.models import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
@@ -30,6 +31,8 @@ if True: # noqa: E402
class TenantMetadata(BaseModel):
first_email: str | None
user_count: int
num_docs: int
num_chunks: int
@@ -39,7 +42,7 @@ class SQLAlchemyDebugging:
def __init__(self) -> None:
pass
def top_chunks(self, k: int = 10) -> None:
def top_chunks(self, filename: str, k: int = 10) -> None:
tenants_to_total_chunks: dict[str, TenantMetadata] = {}
logger.info("Fetching all tenant id's.")
@@ -56,6 +59,14 @@ class SQLAlchemyDebugging:
try:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
first_email = None
first_user = db_session.query(User).first()
if first_user:
first_email = first_user.email
user_count = db_session.query(User).count()
# Calculate the total number of document rows for the current tenant
total_documents = db_session.query(Document).count()
# marginally useful to skip some tenants ... maybe we can improve on this
@@ -69,15 +80,20 @@ class SQLAlchemyDebugging:
total_chunks = db_session.query(
func.sum(Document.chunk_count)
).scalar()
total_chunks = total_chunks or 0
logger.info(
f"{num_processed} of {num_tenant_ids}: Tenant '{tenant_id}': "
f"first_email={first_email} user_count={user_count} "
f"docs={total_documents} chunks={total_chunks}"
)
tenants_to_total_chunks[tenant_id] = TenantMetadata(
num_docs=total_documents, num_chunks=total_chunks
first_email=first_email,
user_count=user_count,
num_docs=total_documents,
num_chunks=total_chunks,
)
except Exception as e:
logger.error(f"Error processing tenant '{tenant_id}': {e}")
@@ -91,14 +107,23 @@ class SQLAlchemyDebugging:
reverse=True,
)
csv_filename = "tenants_by_num_docs.csv"
with open(csv_filename, "w") as csvfile:
with open(filename, "w") as csvfile:
writer = csv.writer(csvfile)
writer.writerow(["tenant_id", "num_docs", "num_chunks"]) # Write header
writer.writerow(
["tenant_id", "first_user_email", "num_user", "num_docs", "num_chunks"]
) # Write header
# Write data rows (using the sorted list)
for tenant_id, metadata in sorted_tenants:
writer.writerow([tenant_id, metadata.num_docs, metadata.num_chunks])
logger.info(f"Successfully wrote statistics to {csv_filename}")
writer.writerow(
[
tenant_id,
metadata.first_email,
metadata.user_count,
metadata.num_docs,
metadata.num_chunks,
]
)
logger.info(f"Successfully wrote statistics to {filename}")
# output top k by chunks
top_k_tenants = heapq.nlargest(
@@ -118,6 +143,14 @@ def main() -> None:
parser.add_argument("--report", help="Generate the given report")
parser.add_argument(
"--filename",
type=str,
default="tenants_by_num_docs.csv",
help="Generate the given report",
required=False,
)
args = parser.parse_args()
logger.info(f"{args}")
@@ -140,7 +173,7 @@ def main() -> None:
debugger = SQLAlchemyDebugging()
if args.report == "top-chunks":
debugger.top_chunks(10)
debugger.top_chunks(args.filename, 10)
else:
logger.info("No action.")

View File

@@ -0,0 +1,77 @@
import argparse
import requests
API_SERVER_URL = "http://localhost:3000"
API_KEY = "onyx-api-key" # API key here, if auth is enabled
def resume_paused_connectors(
api_server_url: str,
api_key: str | None,
specific_connector_sources: list[str] | None = None,
) -> None:
headers = {"Content-Type": "application/json"}
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
# Get all paused connectors
response = requests.get(
f"{api_server_url}/api/manage/admin/connector/indexing-status",
headers=headers,
)
response.raise_for_status()
# Convert the response to a list of ConnectorIndexingStatus objects
connectors = [cc_pair for cc_pair in response.json()]
# If a specific connector is provided, filter the connectors to only include that one
if specific_connector_sources:
connectors = [
connector
for connector in connectors
if connector["connector"]["source"] in specific_connector_sources
]
for connector in connectors:
if connector["cc_pair_status"] == "PAUSED":
print(f"Resuming connector: {connector['name']}")
response = requests.put(
f"{api_server_url}/api/manage/admin/cc-pair/{connector['cc_pair_id']}/status",
json={"status": "ACTIVE"},
headers=headers,
)
response.raise_for_status()
print(f"Resumed connector: {connector['name']}")
else:
print(f"Connector {connector['name']} is not paused")
def main() -> None:
parser = argparse.ArgumentParser(description="Resume paused connectors")
parser.add_argument(
"--api_server_url",
type=str,
default=API_SERVER_URL,
help="The URL of the API server to use. If not provided, will use the default.",
)
parser.add_argument(
"--api_key",
type=str,
default=None,
help="The API key to use for authentication. If not provided, no authentication will be used.",
)
parser.add_argument(
"--connector_sources",
type=str.lower,
nargs="+",
help="The sources of the connectors to resume. If not provided, will resume all paused connectors.",
)
args = parser.parse_args()
resume_paused_connectors(args.api_server_url, args.api_key, args.connector_sources)
if __name__ == "__main__":
main()
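
The helper can also be driven directly from Python; arguments mirror the CLI flags and the sources list is optional (values below are illustrative):

    # api_key is only needed when auth is enabled on the server.
    resume_paused_connectors(
        api_server_url="http://localhost:3000",
        api_key=None,
        specific_connector_sources=["github", "jira"],
    )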

View File

@@ -30,25 +30,48 @@ def test_github_connector_basic(github_connector: GithubConnector) -> None:
start=0,
end=time.time(),
)
assert len(docs) > 0 # We expect at least one PR to exist
assert len(docs) > 1 # We expect at least one PR and one Issue to exist
# Test the first document's structure
doc = docs[0]
pr_doc = docs[0]
issue_doc = docs[-1]
# Verify basic document properties
assert doc.source == DocumentSource.GITHUB
assert doc.secondary_owners is None
assert doc.from_ingestion_api is False
assert doc.additional_info is None
assert pr_doc.source == DocumentSource.GITHUB
assert pr_doc.secondary_owners is None
assert pr_doc.from_ingestion_api is False
assert pr_doc.additional_info is None
# Verify GitHub-specific properties
assert "github.com" in doc.id # Should be a GitHub URL
assert doc.metadata is not None
assert "state" in doc.metadata
assert "merged" in doc.metadata
assert "github.com" in pr_doc.id # Should be a GitHub URL
# Verify PR-specific properties
assert pr_doc.metadata is not None
assert pr_doc.metadata.get("object_type") == "PullRequest"
assert "id" in pr_doc.metadata
assert "merged" in pr_doc.metadata
assert "state" in pr_doc.metadata
assert "user" in pr_doc.metadata
assert "assignees" in pr_doc.metadata
assert pr_doc.metadata.get("repo") == "onyx-dot-app/documentation"
assert "num_commits" in pr_doc.metadata
assert "num_files_changed" in pr_doc.metadata
assert "labels" in pr_doc.metadata
assert "created_at" in pr_doc.metadata
# Verify Issue-specific properties
assert issue_doc.metadata is not None
assert issue_doc.metadata.get("object_type") == "Issue"
assert "id" in issue_doc.metadata
assert "state" in issue_doc.metadata
assert "user" in issue_doc.metadata
assert "assignees" in issue_doc.metadata
assert issue_doc.metadata.get("repo") == "onyx-dot-app/documentation"
assert "labels" in issue_doc.metadata
assert "created_at" in issue_doc.metadata
# Verify sections
assert len(doc.sections) == 1
section = doc.sections[0]
assert section.link == doc.id # Section link should match document ID
assert len(pr_doc.sections) == 1
section = pr_doc.sections[0]
assert section.link == pr_doc.id # Section link should match document ID
assert isinstance(section.text, str) # Should have some text content

View File

@@ -59,11 +59,19 @@ def test_jira_connector_basic(
assert story.source == DocumentSource.JIRA
assert story.metadata == {
"priority": "Medium",
"status": "Backlog",
"status": "Done",
"resolution": "Done",
"resolution_date": "2025-05-29T15:33:31.031-0700",
"reporter": "Chris Weaver",
"assignee": "Chris Weaver",
"issuetype": "Story",
"created": "2025-04-16T16:44:06.716-0700",
"reporter_email": "chris@onyx.app",
"assignee_email": "chris@onyx.app",
"project_name": "DailyConnectorTestProject",
"project": "AS",
"parent": "AS-4",
"updated": "2025-05-29T15:33:31.085-0700",
}
assert story.secondary_owners is None
assert story.title == "AS-3 test123small"
@@ -86,6 +94,11 @@ def test_jira_connector_basic(
"assignee": "Chris Weaver",
"issuetype": "Epic",
"created": "2025-04-16T16:55:53.068-0700",
"reporter_email": "founders@onyx.app",
"assignee_email": "chris@onyx.app",
"project_name": "DailyConnectorTestProject",
"project": "AS",
"updated": "2025-05-29T14:43:05.312-0700",
}
assert epic.secondary_owners is None
assert epic.title == "AS-4 EPIC"


@@ -31,6 +31,7 @@ def slack_connector(
connector = SlackConnector(
channels=[channel] if channel else None,
channel_regex_enabled=False,
use_redis=False,
)
connector.client = mock_slack_client
connector.set_credentials_provider(credentials_provider=slack_credentials_provider)


@@ -108,7 +108,7 @@ def azure_embedding_model() -> EmbeddingModel:
return EmbeddingModel(
server_host="localhost",
server_port=9000,
model_name="text-embedding-3-large",
model_name="text-embedding-3-small",
normalize=True,
query_prefix=None,
passage_prefix=None,


@@ -60,6 +60,7 @@ class ChatSessionManager:
prompt_override: PromptOverride | None = None,
alternate_assistant_id: int | None = None,
use_existing_user_message: bool = False,
use_agentic_search: bool = False,
) -> StreamedResponse:
chat_message_req = CreateChatMessageRequest(
chat_session_id=chat_session_id,
@@ -76,6 +77,7 @@ class ChatSessionManager:
prompt_override=prompt_override,
alternate_assistant_id=alternate_assistant_id,
use_existing_user_message=use_existing_user_message,
use_agentic_search=use_agentic_search,
)
headers = (
@@ -175,3 +177,136 @@ class ChatSessionManager:
),
)
response.raise_for_status()
@staticmethod
def delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Delete a chat session and all its related records (messages, agent data, etc.)
Uses the default deletion method configured on the server.
Returns True if deletion was successful, False otherwise.
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def soft_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Soft delete a chat session (marks as deleted but keeps in database).
Returns True if deletion was successful, False otherwise.
"""
# There is no dedicated soft-delete endpoint, so we call the existing delete
# endpoint with the hard_delete=false query parameter.
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=false",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def hard_delete(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Hard delete a chat session (completely removes from database).
Returns True if deletion was successful, False otherwise.
"""
response = requests.delete(
f"{API_SERVER_URL}/chat/delete-chat-session/{chat_session.id}?hard_delete=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
return response.ok
@staticmethod
def verify_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been deleted by attempting to retrieve it.
Returns True if the chat session is confirmed deleted, False if it still exists.
"""
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# Chat session should return 400 if it doesn't exist
return response.status_code == 400
@staticmethod
def verify_soft_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been soft deleted (marked as deleted but still in DB).
Returns True if the chat session is soft deleted, False otherwise.
"""
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
if response.status_code == 200:
# Chat exists, check if it's marked as deleted
chat_data = response.json()
return chat_data.get("deleted", False) is True
return False
@staticmethod
def verify_hard_deleted(
chat_session: DATestChatSession,
user_performing_action: DATestUser | None = None,
) -> bool:
"""
Verify that a chat session has been hard deleted (completely removed from DB).
Returns True if the chat session is hard deleted, False otherwise.
"""
# Try to get the chat session with include_deleted=true
response = requests.get(
f"{API_SERVER_URL}/chat/get-chat-session/{chat_session.id}?include_deleted=true",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
# For hard delete, even with include_deleted=true, the record should not exist
return response.status_code != 200


@@ -29,7 +29,6 @@ class UserManager:
def create(
name: str | None = None,
email: str | None = None,
-is_first_user: bool = False,
) -> DATestUser:
if name is None:
name = f"test{str(uuid4())}"
@@ -51,14 +50,14 @@ class UserManager:
)
response.raise_for_status()
-role = UserRole.ADMIN if is_first_user else UserRole.BASIC
test_user = DATestUser(
id=response.json()["id"],
email=email,
password=password,
headers=deepcopy(GENERAL_HEADERS),
-role=role,
+# fill as basic for now, the `login_as_user` call will
+# fill it in correctly
+role=UserRole.BASIC,
is_active=True,
)
print(f"Created user {test_user.email}")
@@ -93,6 +92,17 @@ class UserManager:
# Set cookies in the headers
test_user.headers["Cookie"] = f"fastapiusersauth={session_cookie}; "
test_user.cookies = {"fastapiusersauth": session_cookie}
# Get user role from /me endpoint
me_response = requests.get(
url=f"{API_SERVER_URL}/me",
headers=test_user.headers,
cookies=test_user.cookies,
)
me_response.raise_for_status()
role = UserRole(me_response.json()["role"])
test_user.role = role
return test_user
@staticmethod


@@ -16,6 +16,8 @@ from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
BASIC_USER_NAME = "basic_user"
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
@@ -81,7 +83,7 @@ def new_admin_user(reset: None) -> DATestUser | None:
@pytest.fixture
def admin_user() -> DATestUser:
try:
-user = UserManager.create(name=ADMIN_USER_NAME, is_first_user=True)
+user = UserManager.create(name=ADMIN_USER_NAME)
# if there are other users for some reason, reset and try again
if not UserManager.is_role(user, UserRole.ADMIN):
@@ -115,6 +117,44 @@ def admin_user() -> DATestUser:
raise RuntimeError("Failed to create or login as admin user")
@pytest.fixture
def basic_user(
# make sure the admin user exists first to ensure this new user
# gets the BASIC role
admin_user: DATestUser,
) -> DATestUser:
try:
user = UserManager.create(name=BASIC_USER_NAME)
# Validate that the user has the BASIC role
if user.role != UserRole.BASIC:
raise RuntimeError(
f"Created user {BASIC_USER_NAME} does not have BASIC role"
)
return user
except Exception as e:
print(f"Failed to create basic user, trying to login as existing user: {e}")
# Try to login as existing basic user
user = UserManager.login_as_user(
DATestUser(
id="",
email=build_email(BASIC_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.BASIC,
is_active=True,
)
)
# Validate that the logged-in user has the BASIC role
if not UserManager.is_role(user, UserRole.BASIC):
raise RuntimeError(f"User {BASIC_USER_NAME} does not have BASIC role")
return user
@pytest.fixture
def reset_multitenant() -> None:
reset_all_multitenant()


@@ -17,8 +17,7 @@ from slack_sdk.errors import SlackApiError
from onyx.connectors.slack.connector import default_msg_filter
from onyx.connectors.slack.connector import get_channel_messages
-from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
-from onyx.connectors.slack.utils import make_slack_api_call_w_retries
+from onyx.connectors.slack.utils import make_paginated_slack_api_call
def _get_slack_channel_id(channel: dict[str, Any]) -> str:
@@ -40,7 +39,7 @@ def _get_non_general_channels(
channel_types.append("public_channel")
conversations: list[dict[str, Any]] = []
-for result in make_paginated_slack_api_call_w_retries(
+for result in make_paginated_slack_api_call(
slack_client.conversations_list,
exclude_archived=False,
types=channel_types,
@@ -64,7 +63,7 @@ def _clear_slack_conversation_members(
) -> None:
channel_id = _get_slack_channel_id(channel)
member_ids: list[str] = []
-for result in make_paginated_slack_api_call_w_retries(
+for result in make_paginated_slack_api_call(
slack_client.conversations_members,
channel=channel_id,
):
@@ -140,15 +139,13 @@ def _build_slack_channel_from_name(
if channel:
# If channel is provided, we rename it
channel_id = _get_slack_channel_id(channel)
-channel_response = make_slack_api_call_w_retries(
-slack_client.conversations_rename,
+channel_response = slack_client.conversations_rename(
channel=channel_id,
name=channel_name,
)
else:
# Otherwise, we create a new channel
-channel_response = make_slack_api_call_w_retries(
-slack_client.conversations_create,
+channel_response = slack_client.conversations_create(
name=channel_name,
is_private=is_private,
)
@@ -219,10 +216,13 @@ class SlackManager:
@staticmethod
def build_slack_user_email_id_map(slack_client: WebClient) -> dict[str, str]:
-users_results = make_slack_api_call_w_retries(
+users: list[dict[str, Any]] = []
+for users_results in make_paginated_slack_api_call(
slack_client.users_list,
-)
-users: list[dict[str, Any]] = users_results.get("members", [])
+):
+users.extend(users_results.get("members", []))
user_email_id_map = {}
for user in users:
if not (email := user.get("profile", {}).get("email")):
@@ -253,8 +253,7 @@ class SlackManager:
slack_client: WebClient, channel: dict[str, Any], message: str
) -> None:
channel_id = _get_slack_channel_id(channel)
-make_slack_api_call_w_retries(
-slack_client.chat_postMessage,
+slack_client.chat_postMessage(
channel=channel_id,
text=message,
)
@@ -274,7 +273,7 @@ class SlackManager:
) -> None:
channel_types = ["private_channel", "public_channel"]
channels: list[dict[str, Any]] = []
-for result in make_paginated_slack_api_call_w_retries(
+for result in make_paginated_slack_api_call(
slack_client.conversations_list,
exclude_archived=False,
types=channel_types,


@@ -0,0 +1,429 @@
import pytest
from tests.integration.common_utils.managers.chat import ChatSessionManager
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@pytest.fixture(scope="module", autouse=True)
def reset_for_module() -> None:
"""Reset all data once before running any tests in this module."""
reset_all()
@pytest.fixture
def llm_provider(admin_user: DATestUser) -> DATestLLMProvider:
return LLMProviderManager.create(user_performing_action=admin_user)
def test_soft_delete_chat_session(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test soft deletion of a chat session.
Soft delete should mark the chat as deleted but keep it in the database.
"""
# Create a chat session
test_chat_session = ChatSessionManager.create(
persona_id=0, # Use default persona
description="Test chat session for soft deletion",
user_performing_action=basic_user,
)
# Send a message to create some data
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message="Explain the concept of machine learning in detail",
user_performing_action=basic_user,
)
# Verify that the message was processed successfully
assert len(response.full_message) > 0, "Chat response should not be empty"
# Verify that the chat session can be retrieved before deletion
chat_history = ChatSessionManager.get_chat_history(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
assert len(chat_history) > 0, "Chat session should have messages"
# Test soft deletion of the chat session
deletion_success = ChatSessionManager.soft_delete(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
# Verify that the deletion was successful
assert deletion_success, "Chat session soft deletion should succeed"
# Verify that the chat session is soft deleted (marked as deleted but still in DB)
assert ChatSessionManager.verify_soft_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Chat session should be soft deleted"
# Verify that normal access is blocked
assert ChatSessionManager.verify_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Chat session should not be accessible normally after soft delete"
def test_hard_delete_chat_session(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test hard deletion of a chat session.
Hard delete should completely remove the chat from the database.
"""
# Create a chat session
test_chat_session = ChatSessionManager.create(
persona_id=0, # Use default persona
description="Test chat session for hard deletion",
user_performing_action=basic_user,
)
# Send a message to create some data
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message="Explain the concept of machine learning in detail",
user_performing_action=basic_user,
)
# Verify that the message was processed successfully
assert len(response.full_message) > 0, "Chat response should not be empty"
# Verify that the chat session can be retrieved before deletion
chat_history = ChatSessionManager.get_chat_history(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
assert len(chat_history) > 0, "Chat session should have messages"
# Test hard deletion of the chat session
deletion_success = ChatSessionManager.hard_delete(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
# Verify that the deletion was successful
assert deletion_success, "Chat session hard deletion should succeed"
# Verify that the chat session is hard deleted (completely removed from DB)
assert ChatSessionManager.verify_hard_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Chat session should be hard deleted"
# Verify that the chat session is not accessible at all
assert ChatSessionManager.verify_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Chat session should not be accessible after hard delete"
# Verify it's not soft deleted (since it doesn't exist at all)
assert not ChatSessionManager.verify_soft_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Hard deleted chat should not be found as soft deleted"
def test_soft_delete_with_agentic_search(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test soft deletion of a chat session with agent behavior (sub-questions and sub-queries).
Verifies that soft delete preserves all related agent records in the database.
"""
# Create a chat session
test_chat_session = ChatSessionManager.create(
persona_id=0,
description="Test agentic search soft deletion",
user_performing_action=basic_user,
)
# Send a message using ChatSessionManager with agentic search enabled
# This will create AgentSubQuestion and AgentSubQuery records
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message="What are the key principles of software engineering?",
user_performing_action=basic_user,
use_agentic_search=True,
)
# Verify that the message was processed successfully
assert len(response.full_message) > 0, "Chat response should not be empty"
# Test soft deletion
deletion_success = ChatSessionManager.soft_delete(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
# Verify successful soft deletion
assert deletion_success, "Chat soft deletion should succeed"
# Verify chat session is soft deleted
assert ChatSessionManager.verify_soft_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Soft deleted chat session should be marked as deleted in DB"
# Verify chat session is not accessible normally
assert ChatSessionManager.verify_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Soft deleted chat session should not be accessible"
def test_hard_delete_with_agentic_search(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test hard deletion of a chat session with agent behavior (sub-questions and sub-queries).
Verifies that hard delete removes all related agent records from the database via CASCADE.
"""
# Create a chat session
test_chat_session = ChatSessionManager.create(
persona_id=0,
description="Test agentic search hard deletion",
user_performing_action=basic_user,
)
# Send a message using ChatSessionManager with agentic search enabled
# This will create AgentSubQuestion and AgentSubQuery records
response = ChatSessionManager.send_message(
chat_session_id=test_chat_session.id,
message="What are the key principles of software engineering?",
user_performing_action=basic_user,
use_agentic_search=True,
)
# Verify that the message was processed successfully
assert len(response.full_message) > 0, "Chat response should not be empty"
# Test hard deletion
deletion_success = ChatSessionManager.hard_delete(
chat_session=test_chat_session,
user_performing_action=basic_user,
)
# Verify successful hard deletion
assert deletion_success, "Chat hard deletion should succeed"
# Verify chat session is hard deleted (completely removed)
assert ChatSessionManager.verify_hard_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Hard deleted chat session should be completely removed from DB"
# Verify chat session is not accessible
assert ChatSessionManager.verify_deleted(
chat_session=test_chat_session,
user_performing_action=basic_user,
), "Hard deleted chat session should not be accessible"
def test_multiple_soft_deletions(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test multiple chat session soft deletions to ensure proper handling
when there are multiple related records.
"""
chat_sessions = []
# Create multiple chat sessions with potential agent behavior
for i in range(3):
chat_session = ChatSessionManager.create(
persona_id=0,
description=f"Test chat session {i} for multi-soft-deletion",
user_performing_action=basic_user,
)
# Send a message to create some data
ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message=f"Tell me about topic {i} with detailed analysis",
user_performing_action=basic_user,
)
chat_sessions.append(chat_session)
# Soft delete all chat sessions
for chat_session in chat_sessions:
deletion_success = ChatSessionManager.soft_delete(
chat_session=chat_session,
user_performing_action=basic_user,
)
assert deletion_success, f"Failed to soft delete chat {chat_session.id}"
# Verify all chat sessions are soft deleted
for chat_session in chat_sessions:
assert ChatSessionManager.verify_soft_deleted(
chat_session=chat_session,
user_performing_action=basic_user,
), f"Chat {chat_session.id} should be soft deleted"
assert ChatSessionManager.verify_deleted(
chat_session=chat_session,
user_performing_action=basic_user,
), f"Chat {chat_session.id} should not be accessible normally"
def test_multiple_hard_deletions_with_agent_data(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test multiple chat session hard deletions to ensure CASCADE deletes work correctly
when there are multiple related records.
"""
chat_sessions = []
# Create multiple chat sessions with potential agent behavior
for i in range(3):
chat_session = ChatSessionManager.create(
persona_id=0,
description=f"Test chat session {i} for multi-hard-deletion",
user_performing_action=basic_user,
)
# Send a message to create some data
ChatSessionManager.send_message(
chat_session_id=chat_session.id,
message=f"Tell me about topic {i} with detailed analysis",
user_performing_action=basic_user,
)
chat_sessions.append(chat_session)
# Hard delete all chat sessions
for chat_session in chat_sessions:
deletion_success = ChatSessionManager.hard_delete(
chat_session=chat_session,
user_performing_action=basic_user,
)
assert deletion_success, f"Failed to hard delete chat {chat_session.id}"
# Verify all chat sessions are hard deleted
for chat_session in chat_sessions:
assert ChatSessionManager.verify_hard_deleted(
chat_session=chat_session,
user_performing_action=basic_user,
), f"Chat {chat_session.id} should be hard deleted"
assert ChatSessionManager.verify_deleted(
chat_session=chat_session,
user_performing_action=basic_user,
), f"Chat {chat_session.id} should not be accessible"
def test_soft_vs_hard_delete_edge_cases(
basic_user: DATestUser, llm_provider: DATestLLMProvider
) -> None:
"""
Test edge cases for both soft and hard deletion to ensure robustness.
"""
# Test 1: Soft delete a chat session with no messages
empty_chat_session_soft = ChatSessionManager.create(
persona_id=0,
description="Empty chat session for soft delete",
user_performing_action=basic_user,
)
# Soft delete without sending any messages
deletion_success = ChatSessionManager.soft_delete(
chat_session=empty_chat_session_soft,
user_performing_action=basic_user,
)
assert deletion_success, "Empty chat session should be soft deletable"
assert ChatSessionManager.verify_soft_deleted(
chat_session=empty_chat_session_soft,
user_performing_action=basic_user,
), "Empty chat session should be confirmed as soft deleted"
# Test 2: Hard delete a chat session with no messages
empty_chat_session_hard = ChatSessionManager.create(
persona_id=0,
description="Empty chat session for hard delete",
user_performing_action=basic_user,
)
# Hard delete without sending any messages
deletion_success = ChatSessionManager.hard_delete(
chat_session=empty_chat_session_hard,
user_performing_action=basic_user,
)
assert deletion_success, "Empty chat session should be hard deletable"
assert ChatSessionManager.verify_hard_deleted(
chat_session=empty_chat_session_hard,
user_performing_action=basic_user,
), "Empty chat session should be confirmed as hard deleted"
# Test 3: Soft delete a chat session with multiple messages
multi_message_chat_soft = ChatSessionManager.create(
persona_id=0,
description="Multi-message chat session for soft delete",
user_performing_action=basic_user,
)
# Send multiple messages to create more complex data
for i in range(3):
ChatSessionManager.send_message(
chat_session_id=multi_message_chat_soft.id,
message=f"Message {i}: Tell me about different aspects of this topic",
user_performing_action=basic_user,
)
# Verify messages exist
chat_history = ChatSessionManager.get_chat_history(
chat_session=multi_message_chat_soft,
user_performing_action=basic_user,
)
assert len(chat_history) >= 3, "Chat should have multiple messages"
# Soft delete the chat with multiple messages
deletion_success = ChatSessionManager.soft_delete(
chat_session=multi_message_chat_soft,
user_performing_action=basic_user,
)
assert deletion_success, "Multi-message chat session should be soft deletable"
assert ChatSessionManager.verify_soft_deleted(
chat_session=multi_message_chat_soft,
user_performing_action=basic_user,
), "Multi-message chat session should be confirmed as soft deleted"
# Test 4: Hard delete a chat session with multiple messages
multi_message_chat_hard = ChatSessionManager.create(
persona_id=0,
description="Multi-message chat session for hard delete",
user_performing_action=basic_user,
)
# Send multiple messages to create more complex data
for i in range(3):
ChatSessionManager.send_message(
chat_session_id=multi_message_chat_hard.id,
message=f"Message {i}: Tell me about different aspects of this topic",
user_performing_action=basic_user,
)
# Verify messages exist
chat_history = ChatSessionManager.get_chat_history(
chat_session=multi_message_chat_hard,
user_performing_action=basic_user,
)
assert len(chat_history) >= 3, "Chat should have multiple messages"
# Hard delete the chat with multiple messages
deletion_success = ChatSessionManager.hard_delete(
chat_session=multi_message_chat_hard,
user_performing_action=basic_user,
)
assert deletion_success, "Multi-message chat session should be hard deletable"
assert ChatSessionManager.verify_hard_deleted(
chat_session=multi_message_chat_hard,
user_performing_action=basic_user,
), "Multi-message chat session should be confirmed as hard deleted"


@@ -58,7 +58,6 @@ def test_index_attempt_pagination(reset: None) -> None:
# Create an admin user to perform actions
user_performing_action: DATestUser = UserManager.create(
name="admin_performing_action",
-is_first_user=True,
)
# Create a CC pair to attach index attempts to


@@ -46,8 +46,7 @@ def _verify_user_pagination(
def test_user_pagination(reset: None) -> None:
# Create an admin user to perform actions
user_performing_action: DATestUser = UserManager.create(
name="admin_performing_action",
is_first_user=True,
name="admin_performing_action"
)
# Create 9 admin users


@@ -842,7 +842,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
assert all(isinstance(item, Document) for item in outputs[1].items)
assert {
item.semantic_identifier for item in cast(list[Document], outputs[1].items)
} == {"PR 3 Repo 2", "PR 4 Repo 2"}
} == {"3: PR 3 Repo 2", "4: PR 4 Repo 2"}
cp1 = outputs[1].next_checkpoint
assert (
cp1.has_more
@@ -869,7 +869,7 @@ def test_load_from_checkpoint_cursor_pagination_completion(
assert all(isinstance(item, Document) for item in outputs[3].items)
assert {
item.semantic_identifier for item in cast(list[Document], outputs[3].items)
} == {"PR 1 Repo 1", "PR 2 Repo 1"}
} == {"1: PR 1 Repo 1", "2: PR 2 Repo 1"}
cp3 = outputs[3].next_checkpoint
# This checkpoint is returned early because the offset still had items; has_more reflects the state at that point.
assert cp3.has_more # still need to do issues


@@ -0,0 +1,33 @@
from onyx.llm.utils import model_is_reasoning_model
def test_model_is_reasoning_model() -> None:
"""Test that reasoning models are correctly identified and non-reasoning models are not"""
# Models that should be identified as reasoning models
reasoning_models = [
("o3", "openai"),
("o3-mini", "openai"),
("o4-mini", "openai"),
("deepseek-reasoner", "deepseek"),
("deepseek-r1", "openrouter/deepseek"),
("claude-sonnet-4-20250514", "anthropic"),
]
# Models that should NOT be identified as reasoning models
non_reasoning_models = [
("gpt-4o", "openai"),
("claude-3-5-sonnet-20240620", "anthropic"),
]
# Test reasoning models
for model_name, provider in reasoning_models:
assert (
model_is_reasoning_model(model_name, provider) is True
), f"Expected {provider}/{model_name} to be identified as a reasoning model"
# Test non-reasoning models
for model_name, provider in non_reasoning_models:
assert (
model_is_reasoning_model(model_name, provider) is False
), f"Expected {provider}/{model_name} to NOT be identified as a reasoning model"


@@ -4,6 +4,7 @@ from unittest.mock import patch
import pytest
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import BEDROCK_PROVIDER_NAME
from onyx.tools.utils import explicit_tool_calling_supported
@@ -32,6 +33,40 @@ from onyx.tools.utils import explicit_tool_calling_supported
# === Anthropic Scenarios (expected False due to base support being False) ===
# Provider is Anthropic, base model does NOT claim FC support
(ANTHROPIC_PROVIDER_NAME, "claude-2.1", False, [], False),
# === Bedrock Scenarios ===
# Bedrock provider with model name containing anthropic model name as substring -> False
(
BEDROCK_PROVIDER_NAME,
"anthropic.claude-3-opus-20240229-v1:0",
True,
["claude-3-opus-20240229"],
False,
),
# Bedrock provider with model name containing different anthropic model name as substring -> False
(
BEDROCK_PROVIDER_NAME,
"aws-anthropic-claude-3-haiku-20240307",
True,
["claude-3-haiku-20240307"],
False,
),
# Bedrock provider with model name NOT containing any anthropic model name as substring -> True
(
BEDROCK_PROVIDER_NAME,
"amazon.titan-text-express-v1",
True,
["claude-3-opus-20240229", "claude-3-haiku-20240307"],
True,
),
# Bedrock provider with model name NOT containing any anthropic model
# name as substring, but base model doesn't support FC -> False
(
BEDROCK_PROVIDER_NAME,
"amazon.titan-text-express-v1",
False,
["claude-3-opus-20240229", "claude-3-haiku-20240307"],
False,
),
# === Non-Anthropic Scenarios ===
# Non-Anthropic provider, base model claims FC support -> True
("openai", "gpt-4o", True, [], True),
@@ -73,6 +108,9 @@ def test_explicit_tool_calling_supported(
We don't want to provide that list of tools because our UI doesn't support sequential
tool calling yet for (a) and just looks bad for (b), so for now we just treat anthropic
models as non-tool-calling.
Additionally, for the Bedrock provider, any model whose name contains an Anthropic
model name as a substring should also return False for the same reasons.
"""
mock_find_model_obj.return_value = {
"supports_function_calling": mock_model_supports_fc

View File

@@ -9,7 +9,8 @@ chart-repos:
- vespa=https://onyx-dot-app.github.io/vespa-helm-charts
- postgresql=https://charts.bitnami.com/bitnami
-helm-extra-args: --debug --timeout 600s
+# have seen postgres take 10 min to pull ... so 15 min seems like a good timeout?
+helm-extra-args: --debug --timeout 900s
# nginx appears to not work on kind, likely due to lack of loadbalancer support
# helm-extra-set-args also only works on the command line, not in this yaml


@@ -131,7 +131,7 @@ Resources:
OperatingSystemFamily: LINUX
ContainerDefinitions:
- Name: vespaengine
-Image: vespaengine/vespa:8.277.17
+Image: vespaengine/vespa:8.526.15
Cpu: 0
Essential: true
PortMappings:
@@ -162,7 +162,9 @@ Resources:
awslogs-region: !Ref AWS::Region
awslogs-stream-prefix: ecs
User: "1000"
-Environment: []
+Environment:
+- Name: VESPA_SKIP_UPGRADE_CHECK
+Value: "true"
VolumesFrom: []
SystemControls: []
Volumes:


@@ -378,6 +378,7 @@ services:
relational_db:
image: postgres:15.2-alpine
shm_size: 1g
command: -c 'max_connections=250'
restart: always
environment:
@@ -390,8 +391,10 @@ services:
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
-image: vespaengine/vespa:8.277.17
+image: vespaengine/vespa:8.526.15
restart: always
environment:
- VESPA_SKIP_UPGRADE_CHECK=true
ports:
- "19071:19071"
- "8081:8081"


@@ -324,6 +324,7 @@ services:
relational_db:
image: postgres:15.2-alpine
shm_size: 1g
command: -c 'max_connections=250'
restart: always
environment:
@@ -336,8 +337,10 @@ services:
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
-image: vespaengine/vespa:8.277.17
+image: vespaengine/vespa:8.526.15
restart: always
environment:
- VESPA_SKIP_UPGRADE_CHECK=true
ports:
- "19071:19071"
- "8081:8081"


@@ -351,6 +351,7 @@ services:
relational_db:
image: postgres:15.2-alpine
shm_size: 1g
command: -c 'max_connections=250'
restart: always
environment:
@@ -363,8 +364,10 @@ services:
# This container name cannot have an underscore in it due to Vespa expectations of the URL
index:
-image: vespaengine/vespa:8.277.17
+image: vespaengine/vespa:8.526.15
restart: always
environment:
- VESPA_SKIP_UPGRADE_CHECK=true
ports:
- "19071:19071"
- "8081:8081"

Some files were not shown because too many files have changed in this diff.