Compare commits

...

22 Commits

Author SHA1 Message Date
pablonyx
9523df353d update 2025-04-03 12:45:25 -07:00
pablonyx
93886f0e2c Assistant Prompt length + client side (#4433) 2025-04-03 11:26:53 -07:00
rkuo-danswer
8c3a953b7a add prometheus metrics endpoints via helper package (#4436)
* add prometheus metrics endpoints via helper package

* model server specific requirements

* mark as public endpoint

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-03 16:52:05 +00:00
evan-danswer
54b883d0ca fix large docs selected in chat pruning (#4412)
* fix large docs selected in chat pruning

* better approach to length restriction

* comments

* comments

* fix unit tests and minor pruning bug

* remove prints
2025-04-03 15:48:10 +00:00
pablonyx
91faac5447 minor fix (#4435) 2025-04-03 15:00:27 +00:00
Chris Weaver
1d8f9fc39d Fix weird re-index state (#4439)
* Fix weird re-index state

* Address rkuo's comments
2025-04-03 02:16:34 +00:00
Weves
9390de21e5 More logging on confluence space permissions 2025-04-02 20:01:38 -07:00
rkuo-danswer
3a33433fc9 unit tests for chunk censoring (#4434)
* unit tests for chunk censoring

* type hints for mypy

* pytestification

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-03 01:28:54 +00:00
Chris Weaver
c4865d57b1 Fix tons of users w/o drive access causing timeouts (#4437) 2025-04-03 00:01:05 +00:00
rkuo-danswer
81d04db08f Feature/request id middleware 2 (#4427)
* stubbing out request id

* passthru or create request id's in api and model server

* add onyx request id

* get request id logging into uvicorn

* no logs

* change prefixes

* fix comment

* docker image needs specific shared files

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 22:30:03 +00:00
rkuo-danswer
d50a17db21 add filter unit tests (#4421)
* add filter unit tests

* fix tests

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 20:26:25 +00:00
pablonyx
dc5a1e8fd0 add more flexible vision support check (#4429) 2025-04-02 18:11:33 +00:00
pablonyx
c0b3681650 update (#4428) 2025-04-02 18:09:44 +00:00
Chris Weaver
7ec04484d4 Another fix for Salesforce perm sync (#4432)
* Another fix for Salesforce perm sync

* typing
2025-04-02 11:08:40 -07:00
Weves
1cf966ecc1 Fix Salesforce perm sync 2025-04-02 10:47:26 -07:00
rkuo-danswer
8a8526dbbb harden join function (#4424)
* harden join function

* remove log spam

* use time.monotonic

* add pid logging

* client only celery app

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 01:04:00 -07:00
Weves
be20586ba1 Add retries for confluence calls 2025-04-01 23:00:37 -07:00
Weves
a314462d1e Fix migrations 2025-04-01 21:48:32 -07:00
rkuo-danswer
155f53c3d7 Revert "Add user invitation test (#4161)" (#4422)
This reverts commit 806de92feb.

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-01 19:55:04 -07:00
pablonyx
7c027df186 Fix cc pair doc deletion (#4420) 2025-04-01 18:44:15 -07:00
pablonyx
0a5db96026 update (#4415) 2025-04-02 00:42:42 +00:00
joachim-danswer
daef985b02 Simpler approach (#4414) 2025-04-01 16:52:59 -07:00
70 changed files with 1384 additions and 544 deletions

View File

@@ -46,6 +46,7 @@ WORKDIR /app
# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -0,0 +1,50 @@
"""update prompt length
Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)

View File

@@ -0,0 +1,50 @@
"""add prompt length limit
Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None
def upgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -0,0 +1,77 @@
"""updated constraints for ccpairs
Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None
def upgrade() -> None:
# 1) Drop the old foreign-key constraints
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# 2) Re-add them with ondelete='CASCADE'
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="connector",
local_cols=["connector_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
source_table="document_by_connector_credential_pair",
referent_table="credential",
local_cols=["credential_id"],
remote_cols=["id"],
ondelete="CASCADE",
)
def downgrade() -> None:
# Reverse the changes for rollback
op.drop_constraint(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
op.drop_constraint(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
type_="foreignkey",
)
# Recreate without CASCADE
op.create_foreign_key(
"document_by_connector_credential_pair_connector_id_fkey",
"document_by_connector_credential_pair",
"connector",
["connector_id"],
["id"],
)
op.create_foreign_key(
"document_by_connector_credential_pair_credential_id_fkey",
"document_by_connector_credential_pair",
"credential",
["credential_id"],
["id"],
)

View File

@@ -159,6 +159,9 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
final_chunk_dict[chunk.unique_id] = chunk
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,6 +79,16 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list
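
The hunk above swaps a plain keep-list for a dict keyed by unique_id so censored chunks can be re-emitted in the caller's original order. A minimal standalone sketch of that order-preserving filter, using illustrative stand-ins rather than the real InferenceChunk types:

from dataclasses import dataclass


@dataclass
class Chunk:
    unique_id: str
    source_type: str
    content: str


def censor_preserving_order(chunks, censor_by_source):
    """Filter chunks per source, then re-emit survivors in the original input order."""
    kept: dict[str, Chunk] = {}
    by_source: dict[str, list[Chunk]] = {}

    for chunk in chunks:
        if chunk.source_type in censor_by_source:
            by_source.setdefault(chunk.source_type, []).append(chunk)
        else:
            kept[chunk.unique_id] = chunk

    for source, group in by_source.items():
        for survivor in censor_by_source[source](group):
            kept[survivor.unique_id] = survivor

    # Walk the original list so ordering is identical to the input.
    return [kept[c.unique_id] for c in chunks if c.unique_id in kept]


if __name__ == "__main__":
    chunks = [
        Chunk("a", "salesforce", "one"),
        Chunk("b", "web", "two"),
        Chunk("c", "salesforce", "three"),
    ]
    # Example censor function: drop chunk "c" for the salesforce source.
    censors = {"salesforce": lambda group: [c for c in group if c.unique_id != "c"]}
    print([c.unique_id for c in censor_preserving_order(chunks, censors)])  # ['a', 'b']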

View File

@@ -42,11 +42,18 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# try emails
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
# This contains only the user_ids that we have found in Salesforce.

View File

@@ -1,3 +1,4 @@
import logging
import os
import shutil
from collections.abc import AsyncGenerator
@@ -8,6 +9,7 @@ import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
@@ -20,6 +22,8 @@ from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -36,6 +40,12 @@ transformer_logging.set_verbosity_error()
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
"""
@@ -112,6 +122,15 @@ def get_model_app() -> FastAPI:
application.include_router(encoders_router)
application.include_router(custom_models_router)
request_id_prefix = "INF"
if INDEXING_ONLY:
request_id_prefix = "IDX"
add_onyx_request_id_middleware(application, request_id_prefix, logger)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application
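
For reference, prometheus_fastapi_instrumentator needs only the single instrument/expose call shown above. Here is a minimal sketch against a throwaway FastAPI app; the route and run command are placeholders:

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()


@app.get("/ping")
def ping() -> dict[str, str]:
    return {"status": "ok"}


# Collect default HTTP metrics and serve them on GET /metrics.
Instrumentator().instrument(app).expose(app)

# Run with: uvicorn this_module:app --port 8000, then scrape http://localhost:8000/metrics

The exposed GET /metrics route is unauthenticated by default, which is why the API server's public-endpoint allowlist later in this diff gains a ("/metrics", {"GET"}) entry.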

View File

@@ -15,6 +15,22 @@ class ExternalAccess:
# Whether the document is public in the external system or Onyx
is_public: bool
def __str__(self) -> str:
"""Prevent extremely long logs"""
def truncate_set(s: set[str], max_len: int = 100) -> str:
s_str = str(s)
if len(s_str) > max_len:
return f"{s_str[:max_len]}... ({len(s)} items)"
return s_str
return (
f"ExternalAccess("
f"external_user_emails={truncate_set(self.external_user_emails)}, "
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
f"is_public={self.is_public})"
)
@dataclass(frozen=True)
class DocExternalAccess:
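
The __str__ override above exists only to keep huge permission sets out of the logs. A tiny self-contained sketch of the same truncation helper and what it prints:

def truncate_set(s: set[str], max_len: int = 100) -> str:
    """Render a set, cutting the string form at max_len and noting the true size."""
    s_str = str(s)
    if len(s_str) > max_len:
        return f"{s_str[:max_len]}... ({len(s)} items)"
    return s_str


if __name__ == "__main__":
    emails = {f"user{i}@example.com" for i in range(1000)}
    print(truncate_set(emails))      # first 100 chars of the repr, then "... (1000 items)"
    print(truncate_set({"a", "b"}))  # small sets are printed in full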

View File

@@ -1,5 +1,6 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -305,7 +306,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:
def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
logger.info("Running as a secondary celery worker.")
logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
# Set up variables for waiting on primary worker
WAIT_INTERVAL = 5

View File

@@ -0,0 +1,7 @@
from celery import Celery
import onyx.background.celery.apps.app_base as app_base
celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask # type: ignore [misc]

View File

@@ -1,4 +1,5 @@
import logging
import os
from typing import Any
from typing import cast
@@ -95,7 +96,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
app_base.wait_for_db(sender, **kwargs)
app_base.wait_for_vespa_or_shutdown(sender, **kwargs)
logger.info("Running as the primary celery worker.")
logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
# Less startup checks in multi-tenant case
if MULTI_TENANT:

View File

@@ -0,0 +1,16 @@
import onyx.background.celery.configs.base as shared_config
broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options
redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval
result_backend = shared_config.result_backend
result_expires = shared_config.result_expires # 86400 seconds is the default
task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

View File

@@ -0,0 +1,20 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker.
This is an app stub purely for sending tasks as a client.
"""
from celery import Celery
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
set_is_ee_based_on_env_variable()
def get_app() -> Celery:
from onyx.background.celery.apps.client import celery_app
return celery_app
app = get_app()
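
This stub is only ever used to enqueue work, never to execute it. A hedged sketch of the same client-only pattern with a bare Celery app, assuming a reachable Redis broker; the broker URL and task name are placeholders:

from celery import Celery

# A "client" app: it never registers or runs tasks, it only needs broker settings
# compatible with the workers that will execute them.
client = Celery("client", broker="redis://localhost:6379/0")

# send_task dispatches by task name, so the task's code does not have to be
# importable in this process.
result = client.send_task(
    "check_for_indexing",          # placeholder task name
    kwargs={"tenant_id": "public"},
    priority=0,
)
print(result.id)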

View File

@@ -43,6 +43,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_me
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -692,8 +693,13 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# Add a maximum context size in the case of user-selected docs to prevent
# slight inaccuracies in context window size pruning from causing
# the entire query to fail
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True
is_manually_selected_docs=True,
max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
)
# In case the search doc is deleted, just don't include it

View File

@@ -312,11 +312,14 @@ def prune_sections(
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -324,33 +327,48 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
added_chars = 0
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
sep = (
ADJACENT_CHUNK_SEP
if chunk.chunk_id == prev_chunk_id + 1
else DISTANT_CHUNK_SEP
)
merged_content.append(sep)
added_chars += len(sep)
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
return (
InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
),
added_chars,
)
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
# chunk de-duping and doc ordering
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
combined_section_lengths[section.center_chunk.document_id] += len(
section.combined_content
)
chunks_map = docs_map[section.center_chunk.document_id]
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
@@ -361,8 +379,22 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
for doc_id, section_chunks in docs_map.items():
section_chunks_list = list(section_chunks.values())
merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
previous_length = combined_section_lengths[doc_id] + added_chars
# After merging, ensure the content respects the pruning done earlier. Each
# combined section is restricted to the sum of the lengths of the sections
# from the pruning step. Technically the correct approach would be to prune based
# on tokens AGAIN, but this is a good approximation and worth not adding the
# tokenization overhead. This could also be fixed if we added a way of removing
# chunks from sections in the pruning step; at the moment this issue largely
# exists because we only trim the final section's combined_content.
merged_section.combined_content = merged_section.combined_content[
:previous_length
]
new_sections.append(merged_section)
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
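
The merge above now reports how many separator characters it added so the caller can cap the combined content at the pre-merge pruned length. A simplified sketch of that bookkeeping, with (chunk_id, content) tuples standing in for InferenceChunk:

ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"


def merge_chunks(chunks: list[tuple[int, str]]) -> tuple[str, int]:
    """Merge (chunk_id, content) pairs, returning the text and separator chars added."""
    sorted_chunks = sorted(chunks)
    parts: list[str] = []
    added_chars = 0
    for i, (chunk_id, content) in enumerate(sorted_chunks):
        if i > 0:
            prev_id = sorted_chunks[i - 1][0]
            sep = ADJACENT_CHUNK_SEP if chunk_id == prev_id + 1 else DISTANT_CHUNK_SEP
            parts.append(sep)
            added_chars += len(sep)
        parts.append(content)
    return "".join(parts), added_chars


if __name__ == "__main__":
    merged, added = merge_chunks([(3, "gamma"), (1, "alpha"), (2, "beta"), (7, "eta")])
    # Cap at the original pruned length plus whatever the separators contributed.
    pruned_budget = len("alpha") + len("beta") + len("gamma") + len("eta")
    print(merged[: pruned_budget + added])
    print(added)  # 2 * len("\n") + len("\n\n...\n\n") == 9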

View File

@@ -16,6 +16,9 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# Maximum percentage of the context window to fill with selected sections
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(

View File

@@ -13,6 +13,7 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urljoin
from urllib.parse import urlparse
import requests
@@ -342,9 +343,14 @@ def build_confluence_document_id(
Returns:
str: The document id
"""
if is_cloud and not base_url.endswith("/wiki"):
base_url += "/wiki"
return f"{base_url}{content_url}"
# NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
# end with "/" because it believes that makes it a file.
final_url = base_url.rstrip("/") + "/"
if is_cloud and not final_url.endswith("/wiki/"):
final_url = urljoin(final_url, "wiki") + "/"
final_url = urljoin(final_url, content_url.lstrip("/"))
return final_url
def datetime_from_string(datetime_string: str) -> datetime:
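
The NOTE about urljoin in the hunk above describes real urllib.parse behavior; a quick demonstration of why the trailing slash matters:

from urllib.parse import urljoin

base = "https://example.atlassian.net/wiki"

# Without a trailing slash the last path segment is treated like a file and dropped.
print(urljoin(base, "spaces/KB/pages/1"))        # https://example.atlassian.net/spaces/KB/pages/1
# With a trailing slash the segment is kept and the relative path is appended under it.
print(urljoin(base + "/", "spaces/KB/pages/1"))  # https://example.atlassian.net/wiki/spaces/KB/pages/1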
@@ -454,6 +460,19 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
logger.warning("HTTPError with `None` as response or as headers")
raise e
# Confluence Server returns 403 when rate limited
if e.response.status_code == 403:
FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
FORBIDDEN_RETRY_DELAY = 10
if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
logger.warning(
"403 error. This sometimes happens when we hit "
f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
)
return FORBIDDEN_RETRY_DELAY
raise e
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -445,6 +445,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.warning(
f"User '{user_email}' does not have access to the drive APIs."
)
# mark this user as done so we don't try to retrieve anything for them
# again
curr_stage.stage = DriveRetrievalStage.DONE
return
raise
@@ -581,6 +584,25 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
user_email
for user_email, stage in checkpoint.completion_map.items()
if stage != DriveRetrievalStage.DONE
]
# don't process too many emails before returning a checkpoint. This is
# to resolve the case where there are a ton of emails that don't have access
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
]
email_batch_takes_us_to_completion = False
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -591,10 +613,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
start,
end,
)
for email in all_org_emails
for email in non_completed_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
# if there are more emails to process, don't mark as complete
if not email_batch_takes_us_to_completion:
return
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
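
The checkpointing change above boils down to processing a bounded slice of pending users per run and only marking the stage complete when that slice covers everything left. A generic sketch of the pattern with illustrative names:

MAX_ITEMS_PER_RUN = 50


def pick_batch(completion_map: dict[str, str]) -> tuple[list[str], bool]:
    """Return the next slice of unfinished items and whether it finishes the job."""
    pending = [key for key, stage in completion_map.items() if stage != "DONE"]
    if len(pending) > MAX_ITEMS_PER_RUN:
        return pending[:MAX_ITEMS_PER_RUN], False
    return pending, True


if __name__ == "__main__":
    state = {f"user{i}@example.com": ("DONE" if i % 3 == 0 else "PENDING") for i in range(200)}
    batch, finishes = pick_batch(state)
    print(len(batch), finishes)  # 50 False: checkpoint now and resume on the next run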

View File

@@ -5,11 +5,13 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prune_and_merge import _merge_sections
from onyx.chat.prune_and_merge import ChunkRange
from onyx.chat.prune_and_merge import merge_chunk_intervals
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
@@ -61,6 +63,7 @@ class SearchPipeline:
| None = None,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
prompt_config: PromptConfig | None = None,
contextual_pruning_config: ContextualPruningConfig | None = None,
):
# NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
# and typically are None. The preprocessing will fetch default values to replace these empty overrides.
@@ -77,6 +80,9 @@ class SearchPipeline:
self.search_settings = get_current_search_settings(db_session)
self.document_index = get_default_document_index(self.search_settings, None)
self.prompt_config: PromptConfig | None = prompt_config
self.contextual_pruning_config: ContextualPruningConfig | None = (
contextual_pruning_config
)
# Preprocessing steps generate this
self._search_query: SearchQuery | None = None
@@ -221,13 +227,16 @@ class SearchPipeline:
# If ee is enabled, censor the chunk sections based on user access
# Otherwise, return the retrieved chunks
censored_chunks = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
censored_chunks = cast(
list[InferenceChunk],
fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
),
)
above = self.search_query.chunks_above
@@ -420,7 +429,26 @@ class SearchPipeline:
if self._final_context_sections is not None:
return self._final_context_sections
self._final_context_sections = _merge_sections(sections=self.reranked_sections)
if (
self.contextual_pruning_config is not None
and self.prompt_config is not None
):
self._final_context_sections = prune_and_merge_sections(
sections=self.reranked_sections,
section_relevance_list=None,
prompt_config=self.prompt_config,
llm_config=self.llm.config,
question=self.search_query.query,
contextual_pruning_config=self.contextual_pruning_config,
)
else:
logger.error(
"Contextual pruning or prompt config not set, using default merge"
)
self._final_context_sections = _merge_sections(
sections=self.reranked_sections
)
return self._final_context_sections
@property

View File

@@ -613,8 +613,19 @@ def fetch_connector_credential_pairs(
def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
search_settings_id: int,
db_session: Session,
) -> None:
"""
Updates state stored in the connector_credential_pair table based on the
latest index attempt for the given search settings.
Args:
cc_pair: ConnectorCredentialPair to resync
search_settings_id: SearchSettings to use for resync
db_session: Database session
"""
def find_latest_index_attempt(
connector_id: int,
credential_id: int,
@@ -627,11 +638,10 @@ def resync_cc_pair(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
SearchSettings.status == IndexModelStatus.PRESENT,
IndexAttempt.search_settings_id == search_settings_id,
)
)

View File

@@ -43,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ONE_HOUR_IN_SECONDS = 60 * 60
def check_docs_exist(db_session: Session) -> bool:
stmt = select(exists(DbDocument))
@@ -607,6 +609,46 @@ def delete_documents_complete__no_commit(
delete_documents__no_commit(db_session, document_ids)
def delete_all_documents_for_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
timeout: int = ONE_HOUR_IN_SECONDS,
) -> None:
"""Delete all documents for a given connector credential pair.
This will delete all documents and their associated data (chunks, feedback, tags, etc.)
NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during
an index swap. If we wanted to make this more efficient, we could use a single delete
statement + cascade.
"""
batch_size = 1000
start_time = time.monotonic()
while True:
# Get document IDs in batches
stmt = (
select(DocumentByConnectorCredentialPair.id)
.where(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
.limit(batch_size)
)
document_ids = db_session.scalars(stmt).all()
if not document_ids:
break
delete_documents_complete__no_commit(
db_session=db_session, document_ids=list(document_ids)
)
db_session.commit()
if time.monotonic() - start_time > timeout:
raise RuntimeError("Timeout reached while deleting documents")
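
The new helper above deletes in fixed-size batches with a wall-clock cutoff. A generic SQLAlchemy sketch of the same loop against a hypothetical single-table schema, just to show the shape of the pattern:

import time

from sqlalchemy import create_engine, delete, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class Doc(Base):  # hypothetical stand-in for the real document tables
    __tablename__ = "doc"
    id: Mapped[int] = mapped_column(primary_key=True)


def delete_all_docs(session: Session, batch_size: int = 1000, timeout: float = 3600.0) -> None:
    start = time.monotonic()
    while True:
        ids = session.scalars(select(Doc.id).limit(batch_size)).all()
        if not ids:
            break
        session.execute(delete(Doc).where(Doc.id.in_(ids)))
        session.commit()  # commit per batch so progress survives a crash
        if time.monotonic() - start > timeout:
            raise RuntimeError("Timeout reached while deleting documents")


if __name__ == "__main__":
    engine = create_engine("sqlite:///:memory:")
    Base.metadata.create_all(engine)
    with Session(engine) as session:
        session.add_all(Doc(id=i) for i in range(5000))
        session.commit()
        delete_all_docs(session, batch_size=1000)
        print(session.scalar(select(Doc.id)))  # None: everything is gone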
def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
"""Acquire locks for the specified documents. Ideally this shouldn't be
called with large list of document_ids (an exception could be made if the

View File

@@ -710,6 +710,25 @@ def cancel_indexing_attempts_past_model(
)
def cancel_indexing_attempts_for_search_settings(
search_settings_id: int,
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
the specified search settings."""
db_session.execute(
update(IndexAttempt)
.where(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.search_settings_id == search_settings_id,
)
.values(status=IndexingStatus.FAILED)
)
def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,
db_session: Session,

View File

@@ -703,7 +703,11 @@ class Connector(Base):
)
documents_by_connector: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="connector",
passive_deletes=True,
)
# synchronize this validation logic with RefreshFrequencySchema etc on front end
# until we have a centralized validation schema
@@ -757,7 +761,11 @@ class Credential(Base):
)
documents_by_credential: Mapped[
list["DocumentByConnectorCredentialPair"]
] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")
] = relationship(
"DocumentByConnectorCredentialPair",
back_populates="credential",
passive_deletes=True,
)
user: Mapped[User | None] = relationship("User", back_populates="credentials")
@@ -1110,10 +1118,10 @@ class DocumentByConnectorCredentialPair(Base):
id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
# TODO: transition this to use the ConnectorCredentialPair id directly
connector_id: Mapped[int] = mapped_column(
ForeignKey("connector.id"), primary_key=True
ForeignKey("connector.id", ondelete="CASCADE"), primary_key=True
)
credential_id: Mapped[int] = mapped_column(
ForeignKey("credential.id"), primary_key=True
ForeignKey("credential.id", ondelete="CASCADE"), primary_key=True
)
# used to better keep track of document counts at a connector level
@@ -1123,10 +1131,10 @@ class DocumentByConnectorCredentialPair(Base):
has_been_indexed: Mapped[bool] = mapped_column(Boolean)
connector: Mapped[Connector] = relationship(
"Connector", back_populates="documents_by_connector"
"Connector", back_populates="documents_by_connector", passive_deletes=True
)
credential: Mapped[Credential] = relationship(
"Credential", back_populates="documents_by_credential"
"Credential", back_populates="documents_by_credential", passive_deletes=True
)
__table_args__ = (
@@ -1650,8 +1658,8 @@ class Prompt(Base):
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
system_prompt: Mapped[str] = mapped_column(Text)
task_prompt: Mapped[str] = mapped_column(Text)
system_prompt: Mapped[str] = mapped_column(String(length=8000))
task_prompt: Mapped[str] = mapped_column(String(length=8000))
include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
# Default prompts are configured via backend during deployment

View File

@@ -37,8 +37,8 @@ from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -201,7 +201,7 @@ def create_update_persona(
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
) -> FullPersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
@@ -271,7 +271,7 @@ def create_update_persona(
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
return FullPersonaSnapshot.from_model(persona)
def update_persona_shared_users(

View File

@@ -3,8 +3,9 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.document import delete_all_documents_for_connector_credential_pair
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
@@ -26,31 +27,49 @@ def _perform_index_swap(
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
cleanup_documents: bool = False,
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
cancel_indexing_attempts_for_search_settings(
search_settings_id=current_search_settings.id,
db_session=db_session,
)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
# sync based on the new search settings
search_settings_id=secondary_search_settings.id,
db_session=db_session,
)
if cleanup_documents:
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap and no documents will exist in the new index.
for cc_pair in all_cc_pairs:
delete_all_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# swap over search settings
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
@@ -88,6 +107,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap.
cleanup_documents=True,
)
return current_search_settings

View File

@@ -602,7 +602,7 @@ def get_max_input_tokens(
)
if input_toks <= 0:
raise RuntimeError("No tokens for input for the LLM given settings")
return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
return input_toks

View File

@@ -1,3 +1,4 @@
import logging
import sys
import traceback
from collections.abc import AsyncGenerator
@@ -16,6 +17,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
@@ -102,6 +104,8 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -116,6 +120,12 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
if not isinstance(exc, RequestValidationError):
@@ -421,9 +431,14 @@ def get_application() -> FastAPI:
if LOG_ENDPOINT_LATENCY:
add_latency_logging_middleware(application, logger)
add_onyx_request_id_middleware(application, "API", logger)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application

View File

@@ -49,6 +49,7 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/oauth/callback", {"GET"}),
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
]

View File

@@ -21,7 +21,7 @@ from onyx.background.celery.tasks.external_group_syncing.tasks import (
from onyx.background.celery.tasks.pruning.tasks import (
try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
@@ -219,7 +219,7 @@ def update_cc_pair_status(
continue
# Revoke the task to prevent it from running
primary_app.control.revoke(index_payload.celery_task_id)
client_app.control.revoke(index_payload.celery_task_id)
# If it is running, then signaling for termination will get the
# watchdog thread to kill the spawned task
@@ -238,7 +238,7 @@ def update_cc_pair_status(
db_session.commit()
# this speeds up the start of indexing by firing the check immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
kwargs=dict(tenant_id=tenant_id),
priority=OnyxCeleryPriority.HIGH,
@@ -376,7 +376,7 @@ def prune_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_prune_generator_task(
primary_app, cc_pair, db_session, r, tenant_id
client_app, cc_pair, db_session, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -450,7 +450,7 @@ def sync_cc_pair(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, tenant_id
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -524,7 +524,7 @@ def sync_cc_pair_groups(
f"{cc_pair.connector.name} connector."
)
payload_id = try_creating_external_group_sync_task(
primary_app, cc_pair_id, r, tenant_id
client_app, cc_pair_id, r, tenant_id
)
if not payload_id:
raise HTTPException(
@@ -634,7 +634,7 @@ def associate_credential_to_connector(
)
# trigger indexing immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -20,7 +20,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
from onyx.configs.constants import DocumentSource
@@ -928,7 +928,7 @@ def create_connector_with_mock_credential(
)
# trigger indexing immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},
@@ -1314,7 +1314,7 @@ def trigger_indexing_for_cc_pair(
# run the beat task to pick up the triggers immediately
priority = OnyxCeleryPriority.HIGHEST if is_user_file else OnyxCeleryPriority.HIGH
logger.info(f"Sending indexing check task with priority {priority}")
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_INDEXING,
priority=priority,
kwargs={"tenant_id": tenant_id},

View File

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
@@ -52,7 +52,7 @@ def create_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -85,7 +85,7 @@ def patch_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,
@@ -108,7 +108,7 @@ def delete_document_set(
except Exception as e:
raise HTTPException(status_code=400, detail=str(e))
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
kwargs={"tenant_id": tenant_id},
priority=OnyxCeleryPriority.HIGH,

View File

@@ -43,6 +43,7 @@ from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
@@ -424,8 +425,8 @@ def get_persona(
persona_id: int,
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(
) -> FullPersonaSnapshot:
return FullPersonaSnapshot.from_model(
get_persona_by_id(
persona_id=persona_id,
user=user,

View File

@@ -91,37 +91,80 @@ class PersonaUpsertRequest(BaseModel):
class PersonaSnapshot(BaseModel):
id: int
owner: MinimalUserSnapshot | None
name: str
is_visible: bool
is_public: bool
display_priority: int | None
description: str
num_chunks: float | None
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_provider_override: str | None
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
builtin_persona: bool
prompts: list[PromptSnapshot]
tools: list[ToolSnapshot]
document_sets: list[DocumentSet]
users: list[MinimalUserSnapshot]
groups: list[int]
icon_color: str | None
icon_shape: int | None
is_public: bool
is_visible: bool
icon_shape: int | None = None
icon_color: str | None = None
uploaded_image_id: str | None = None
is_default_persona: bool
user_file_ids: list[int] = Field(default_factory=list)
user_folder_ids: list[int] = Field(default_factory=list)
display_priority: int | None = None
is_default_persona: bool = False
builtin_persona: bool = False
starter_messages: list[StarterMessage] | None = None
tools: list[ToolSnapshot] = Field(default_factory=list)
labels: list["PersonaLabelSnapshot"] = Field(default_factory=list)
owner: MinimalUserSnapshot | None = None
users: list[MinimalUserSnapshot] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
document_sets: list[DocumentSet] = Field(default_factory=list)
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
num_chunks: float | None = None
@classmethod
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
return PersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks,
)
# Model with full context on persona's internal settings
# This is used for flows which need to know all settings
class FullPersonaSnapshot(PersonaSnapshot):
search_start_date: datetime | None = None
labels: list["PersonaLabelSnapshot"] = []
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
prompts: list[PromptSnapshot] = Field(default_factory=list)
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
@classmethod
def from_model(
cls, persona: Persona, allow_deleted: bool = False
) -> "PersonaSnapshot":
) -> "FullPersonaSnapshot":
if persona.deleted:
error_msg = f"Persona with ID {persona.id} has been deleted"
if not allow_deleted:
@@ -129,44 +172,32 @@ class PersonaSnapshot(BaseModel):
else:
logger.warning(error_msg)
return PersonaSnapshot(
return FullPersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
is_visible=persona.is_visible,
is_public=persona.is_public,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
search_start_date=persona.search_start_date,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
)

View File

@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -192,7 +192,7 @@ def create_deletion_attempt_for_connector_id(
db_session.commit()
# run the beat task to pick up this deletion from the db immediately
primary_app.send_task(
client_app.send_task(
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
priority=OnyxCeleryPriority.HIGH,
kwargs={"tenant_id": tenant_id},

View File

@@ -19,6 +19,7 @@ from onyx.db.models import SlackBot as SlackAppModel
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
from onyx.db.models import User
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
@@ -245,7 +246,7 @@ class SlackChannelConfig(BaseModel):
id=slack_channel_config_model.id,
slack_bot_id=slack_channel_config_model.slack_bot_id,
persona=(
PersonaSnapshot.from_model(
FullPersonaSnapshot.from_model(
slack_channel_config_model.persona, allow_deleted=True
)
if slack_channel_config_model.persona

View File

@@ -117,7 +117,11 @@ def set_new_search_settings(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
db_session.commit()
return IdReturn(id=new_search_settings.id)

View File

@@ -96,7 +96,11 @@ def setup_onyx(
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=search_settings.id,
db_session=db_session,
)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)

View File

@@ -376,6 +376,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=alternate_db_session or self.db_session,
prompt_config=self.prompt_config,
retrieved_sections_callback=retrieved_sections_callback,
contextual_pruning_config=self.contextual_pruning_config,
)
search_query_info = SearchQueryInfo(
@@ -447,6 +448,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
db_session=self.db_session,
bypass_acl=self.bypass_acl,
prompt_config=self.prompt_config,
contextual_pruning_config=self.contextual_pruning_config,
)
# Log what we're doing

View File

@@ -13,6 +13,7 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
logging.addLevelName(logging.INFO + 5, "NOTICE")
@@ -71,6 +72,14 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
class OnyxRequestIDFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
record.request_id = ONYX_REQUEST_ID_CONTEXTVAR.get() or "-"
return True
class OnyxLoggingAdapter(logging.LoggerAdapter):
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
@@ -103,6 +112,7 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:
@@ -115,6 +125,11 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
)
msg = f"[t:{short_tenant}] {msg}"
# request id within a fastapi route
fastapi_request_id = ONYX_REQUEST_ID_CONTEXTVAR.get()
if fastapi_request_id:
msg = f"[{fastapi_request_id}] {msg}"
# For Slack Bot, logs the channel relevant to the request
channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
if channel_id:
@@ -165,6 +180,14 @@ class ColoredFormatter(logging.Formatter):
return super().format(record)
def get_uvicorn_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: [%(request_id)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
def get_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
@@ -201,12 +224,6 @@ def setup_logger(
logger.addHandler(handler)
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
is_containerized = is_running_in_container()
if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
log_levels = ["debug", "info", "notice"]
@@ -225,14 +242,37 @@ def setup_logger(
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if uvicorn_logger:
uvicorn_logger.addHandler(file_handler)
logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs) # type: ignore
return OnyxLoggingAdapter(logger, extra=extra)
def setup_uvicorn_logger(
log_level: int = get_log_level_from_str(),
shared_file_handlers: list[logging.FileHandler] | None = None,
) -> None:
uvicorn_logger = logging.getLogger("uvicorn.access")
if not uvicorn_logger:
return
formatter = get_uvicorn_standard_formatter()
handler = logging.StreamHandler()
handler.setLevel(log_level)
handler.setFormatter(formatter)
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
uvicorn_logger.addFilter(OnyxRequestIDFilter())
if shared_file_handlers:
for fh in shared_file_handlers:
uvicorn_logger.addHandler(fh)
return
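
The uvicorn wiring above hinges on a logging.Filter that copies a contextvar onto every record so the formatter can reference %(request_id)s. A self-contained stdlib sketch of that mechanism:

import contextvars
import logging

REQUEST_ID: contextvars.ContextVar[str | None] = contextvars.ContextVar("request_id", default=None)


class RequestIDFilter(logging.Filter):
    def filter(self, record: logging.LogRecord) -> bool:
        # Attach the current request id (or "-") so the formatter can use it.
        record.request_id = REQUEST_ID.get() or "-"
        return True


handler = logging.StreamHandler()
handler.setFormatter(logging.Formatter("%(asctime)s [%(request_id)s] %(message)s"))

logger = logging.getLogger("demo")
logger.setLevel(logging.INFO)
logger.addHandler(handler)
logger.addFilter(RequestIDFilter())

logger.info("no request in flight")          # printed with [-]
REQUEST_ID.set("API:AbCd1234")
logger.info("handling an incoming request")  # printed with [API:AbCd1234]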
def print_loggers() -> None:
"""Print information about all loggers. Use to debug logging issues."""
root_logger = logging.getLogger()

View File

@@ -0,0 +1,62 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Awaitable
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
def add_onyx_request_id_middleware(
app: FastAPI, prefix: str, logger: logging.LoggerAdapter
) -> None:
@app.middleware("http")
async def set_request_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Generate a request hash that can be used to track the lifecycle
of a request. The hash is prefixed to help indicate where the request id
originated.
Format is f"{PREFIX}:{ID}" where PREFIX is 3 chars and ID is 8 chars.
Total length is 12 chars.
"""
onyx_request_id = request.headers.get("X-Onyx-Request-ID")
if not onyx_request_id:
onyx_request_id = make_randomized_onyx_request_id(prefix)
ONYX_REQUEST_ID_CONTEXTVAR.set(onyx_request_id)
return await call_next(request)
def make_randomized_onyx_request_id(prefix: str) -> str:
"""generates a randomized request id"""
hash_input = str(uuid.uuid4())
return _make_onyx_request_id(prefix, hash_input)
def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
"""Not used yet, but could be in the future!"""
hash_input = f"{request_url}:{datetime.now(timezone.utc)}"
return _make_onyx_request_id(prefix, hash_input)
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
# NOTE: possible we'll want more input bytes if id's aren't unique enough
hash_str = base64.urlsafe_b64encode(hash_bytes).decode("utf-8").rstrip("=")
onyx_request_id = f"{prefix}:{hash_str}"
return onyx_request_id
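A brief usage sketch of the middleware defined above; the FastAPI app, the "API" prefix, and the /ping route are illustrative assumptions rather than values taken from this diff, and the import path is assumed from the repository layout.

import logging

from fastapi import FastAPI

from onyx.utils.middleware import add_onyx_request_id_middleware  # path assumed
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR

logger = logging.LoggerAdapter(logging.getLogger(__name__), {})
app = FastAPI()

# "API" is an illustrative 3-character prefix; each server would pick its own.
add_onyx_request_id_middleware(app, "API", logger)


@app.get("/ping")
async def ping() -> dict[str, str | None]:
    # Anything running inside the request can read the id for log correlation,
    # e.g. a value shaped like "API:AbC12xYz" (3-char prefix + ":" + 8 chars).
    return {"request_id": ONYX_REQUEST_ID_CONTEXTVAR.get()}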

View File

@@ -332,14 +332,15 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
return task.result
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
return ind, next(g, None)
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
return ind, next(gen, None)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {
executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
executor.submit(_next_or_none, ind, gen): ind
for ind, gen in enumerate(gens)
}
next_ind = len(gens)
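A small usage sketch of the parallel_yield shown above (the generators here are illustrative, and the import path is an assumption): it drains several iterators on a thread pool and yields items as the workers produce them, so the interleaving order depends on scheduling rather than on list order.

import time
from collections.abc import Iterator

from onyx.utils.threadpool_concurrency import parallel_yield  # module path assumed


def slow_range(name: str, delay: float) -> Iterator[str]:
    for i in range(3):
        time.sleep(delay)  # stand-in for a blocking call such as a paginated API fetch
        yield f"{name}-{i}"


# Items from the faster generator tend to arrive first, but no ordering is guaranteed.
for item in parallel_yield([slow_range("fast", 0.01), slow_range("slow", 0.05)]):
    print(item)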

View File

@@ -95,4 +95,5 @@ urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
fastapi-limiter==0.1.6
prometheus_fastapi_instrumentator==7.1.0

View File

@@ -15,4 +15,5 @@ uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0
aioboto3==13.4.0
aioboto3==13.4.0
prometheus_fastapi_instrumentator==7.1.0

View File

@@ -58,6 +58,7 @@ INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
# The process needs to have this for the log file to write to
# otherwise, it will not create additional log files
# This should just be the filename base without extension or path.
LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"
# Enable generating persistent log files for local dev environments

View File

@@ -11,6 +11,15 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
"current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
# set by every route in the API server
INDEXING_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[
str | None
] = contextvars.ContextVar("indexing_request_id", default=None)
# set by every route in the API server
ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"onyx_request_id", default=None
)
"""Utils related to contextvars"""

View File

@@ -34,7 +34,7 @@ def confluence_connector(space: str) -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -165,17 +165,18 @@ class DocumentManager:
doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict
}
# NOTE(rkuo): too much log spam
# Left this here for debugging purposes.
import json
# import json
print("DEBUGGING DOCUMENTS")
print(retrieved_docs)
for doc in retrieved_docs.values():
printable_doc = doc.copy()
print(printable_doc.keys())
printable_doc.pop("embeddings")
printable_doc.pop("title_embedding")
print(json.dumps(printable_doc, indent=2))
# print("DEBUGGING DOCUMENTS")
# print(retrieved_docs)
# for doc in retrieved_docs.values():
# printable_doc = doc.copy()
# print(printable_doc.keys())
# printable_doc.pop("embeddings")
# printable_doc.pop("title_embedding")
# print(json.dumps(printable_doc, indent=2))
for document in cc_pair.documents:
retrieved_doc = retrieved_docs.get(document.id)

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from datetime import timedelta
from urllib.parse import urlencode
@@ -191,7 +192,7 @@ class IndexAttemptManager:
user_performing_action: DATestUser | None = None,
) -> None:
"""Wait for an IndexAttempt to complete"""
start = datetime.now()
start = time.monotonic()
while True:
index_attempt = IndexAttemptManager.get_index_attempt_by_id(
index_attempt_id=index_attempt_id,
@@ -203,7 +204,7 @@ class IndexAttemptManager:
print(f"IndexAttempt {index_attempt_id} completed")
return
elapsed = (datetime.now() - start).total_seconds()
elapsed = time.monotonic() - start
if elapsed > timeout:
raise TimeoutError(
f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -181,7 +181,7 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=user_performing_action.headers
@@ -189,13 +189,13 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**persona) for persona in response.json()]
return [FullPersonaSnapshot(**persona) for persona in response.json()]
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers
@@ -203,7 +203,7 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**response.json())]
return [FullPersonaSnapshot(**response.json())]
@staticmethod
def verify(

View File

@@ -313,29 +313,3 @@ class UserManager:
)
response.raise_for_status()
return UserInfo(**response.json())
@staticmethod
def invite_users(
user_performing_action: DATestUser,
emails: list[str],
) -> int:
response = requests.put(
url=f"{API_SERVER_URL}/manage/admin/users",
json={"emails": emails},
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def remove_invited_user(
user_performing_action: DATestUser,
user_email: str,
) -> int:
response = requests.patch(
url=f"{API_SERVER_URL}/manage/admin/remove-invited-user",
json={"user_email": user_email},
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()

View File

@@ -22,7 +22,6 @@ from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import get_redis_client
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
@@ -238,12 +237,6 @@ def reset_vespa() -> None:
time.sleep(5)
def reset_redis() -> None:
"""Reset the Redis database."""
redis_client = get_redis_client()
redis_client.flushall()
def reset_postgres_multitenant() -> None:
"""Reset the Postgres database for all tenants in a multitenant setup."""
@@ -348,8 +341,6 @@ def reset_all() -> None:
reset_postgres()
logger.info("Resetting Vespa...")
reset_vespa()
logger.info("Resetting Redis...")
reset_redis()
def reset_all_multitenant() -> None:

View File

@@ -1,38 +0,0 @@
import pytest
from requests import HTTPError
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
def test_inviting_users_flow(reset: None) -> None:
"""
Test that verifies the functionality around inviting users:
1. Creating an admin user
2. Admin inviting a new user
3. Invited user successfully signing in
4. Non-invited user attempting to sign in (should result in an error)
"""
# 1) Create an admin user (the first user created is automatically admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
assert admin_user is not None
assert UserManager.is_role(admin_user, UserRole.ADMIN)
# 2) Admin invites a new user
invited_email = "invited_user@test.com"
invite_response = UserManager.invite_users(admin_user, [invited_email])
assert invite_response == 1
# 3) The invited user successfully registers/logs in
invited_user: DATestUser = UserManager.create(
name="invited_user", email=invited_email
)
assert invited_user is not None
assert invited_user.email == invited_email
assert UserManager.is_role(invited_user, UserRole.BASIC)
# 4) A non-invited user attempts to sign in/register (should fail)
with pytest.raises(HTTPError):
UserManager.create(name="uninvited_user", email="uninvited_user@test.com")

View File

@@ -4,6 +4,7 @@ from onyx.chat.prune_and_merge import _merge_sections
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.utils import inference_section_from_chunks
# This large test accounts for all of the following:
@@ -111,7 +112,7 @@ Content 17
# Sections
[
# Document 1, top/middle/bot connected + disconnected section
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[
DOC_1_FILLER_1,
@@ -120,9 +121,8 @@ Content 17
DOC_1_MID_CHUNK,
DOC_1_FILLER_3,
],
combined_content="N/A", # Not used
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[
DOC_1_FILLER_2,
@@ -131,9 +131,8 @@ Content 17
DOC_1_FILLER_3,
DOC_1_FILLER_4,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_BOTTOM_CHUNK,
chunks=[
DOC_1_FILLER_3,
@@ -142,9 +141,8 @@ Content 17
DOC_1_FILLER_5,
DOC_1_FILLER_6,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_DISCONNECTED,
chunks=[
DOC_1_FILLER_7,
@@ -153,9 +151,8 @@ Content 17
DOC_1_FILLER_9,
DOC_1_FILLER_10,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_2_TOP_CHUNK,
chunks=[
DOC_2_FILLER_1,
@@ -164,9 +161,8 @@ Content 17
DOC_2_FILLER_3,
DOC_2_BOTTOM_CHUNK,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_2_BOTTOM_CHUNK,
chunks=[
DOC_2_TOP_CHUNK,
@@ -175,7 +171,6 @@ Content 17
DOC_2_FILLER_4,
DOC_2_FILLER_5,
],
combined_content="N/A",
),
],
# Expected Content
@@ -204,15 +199,13 @@ def test_merge_sections(
(
# Sections
[
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[DOC_1_TOP_CHUNK],
combined_content="N/A", # Not used
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[DOC_1_MID_CHUNK],
combined_content="N/A",
),
],
# Expected Content

View File

@@ -0,0 +1,208 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
_post_query_chunk_censoring = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring", "_post_query_chunk_censoring"
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permissions tests are enterprise only",
)
class TestPostQueryChunkCensoring:
@pytest.fixture(autouse=True)
def setUp(self) -> None:
self.mock_user = User(id=1, email="test@example.com")
self.mock_chunk_1 = InferenceChunk(
document_id="doc1",
chunk_id=1,
content="chunk1 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc1_1",
title="doc1",
boost=1,
recency_bias=1.0,
score=0.9,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc1 summary",
chunk_context="doc1 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk1",
)
self.mock_chunk_2 = InferenceChunk(
document_id="doc2",
chunk_id=2,
content="chunk2 content",
source_type=DocumentSource.SLACK,
semantic_identifier="doc2_2",
title="doc2",
boost=1,
recency_bias=1.0,
score=0.8,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc2 summary",
chunk_context="doc2 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk2",
)
self.mock_chunk_3 = InferenceChunk(
document_id="doc3",
chunk_id=3,
content="chunk3 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc3_3",
title="doc3",
boost=1,
recency_bias=1.0,
score=0.7,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc3 summary",
chunk_context="doc3 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk3",
)
self.mock_chunk_4 = InferenceChunk(
document_id="doc4",
chunk_id=4,
content="chunk4 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc4_4",
title="doc4",
boost=1,
recency_bias=1.0,
score=0.6,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc4 summary",
chunk_context="doc4 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk4",
)
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
def test_post_query_chunk_censoring_no_user(
self, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
chunks = [self.mock_chunk_1, self.mock_chunk_2]
result = _post_query_chunk_censoring(chunks, None)
assert result == chunks
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_salesforce_censored(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(
return_value=[self.mock_chunk_1]
) # Only return chunk 1
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 2
assert self.mock_chunk_1 in result
assert self.mock_chunk_2 in result
assert self.mock_chunk_3 not in result
mock_censor_func_impl.assert_called_once()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_salesforce_error(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(side_effect=Exception("Censoring error"))
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 1
assert self.mock_chunk_2 in result
mock_censor_func_impl.assert_called_once()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_no_censoring(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = set() # No sources to censor
mock_censor_func_impl = MagicMock()
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert result == chunks
mock_censor_func_impl.assert_not_called()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_order_maintained(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(
return_value=[self.mock_chunk_3, self.mock_chunk_1]
) # Return chunk 3 and 1
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [
self.mock_chunk_1,
self.mock_chunk_2,
self.mock_chunk_3,
self.mock_chunk_4,
]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 3
assert result[0] == self.mock_chunk_1
assert result[1] == self.mock_chunk_2
assert result[2] == self.mock_chunk_3
assert self.mock_chunk_4 not in result
mock_censor_func_impl.assert_called_once()

View File

@@ -0,0 +1,270 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from shared_configs.configs import MULTI_TENANT
# Import the function under test
class TestBuildVespaFilters:
def test_empty_filters(self) -> None:
"""Test with empty filters object."""
filters = IndexFilters(access_control_list=[])
result = build_vespa_filters(filters)
assert result == f"!({HIDDEN}=true) and "
# With trailing AND removed
result = build_vespa_filters(filters, remove_trailing_and=True)
assert result == f"!({HIDDEN}=true)"
def test_include_hidden(self) -> None:
"""Test with include_hidden flag."""
filters = IndexFilters(access_control_list=[])
result = build_vespa_filters(filters, include_hidden=True)
assert result == "" # No filters applied when including hidden
# With some other filter to ensure proper AND chaining
filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
result = build_vespa_filters(filters, include_hidden=True)
assert result == f'({SOURCE_TYPE} contains "web") and '
def test_acl(self) -> None:
"""Test with acls."""
# Single ACL
filters = IndexFilters(access_control_list=["user1"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and (access_control_list contains "user1") and '
)
# Multiple ACL's
filters = IndexFilters(access_control_list=["user2", "group2"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and (access_control_list contains "user2" or access_control_list contains "group2") and '
)
def test_tenant_filter(self) -> None:
"""Test tenant ID filtering."""
# With tenant ID
if MULTI_TENANT:
filters = IndexFilters(access_control_list=[], tenant_id="tenant1")
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({TENANT_ID} contains "tenant1") and ' == result
)
# No tenant ID
filters = IndexFilters(access_control_list=[], tenant_id=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_source_type_filter(self) -> None:
"""Test source type filtering."""
# Single source type
filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web") and ' == result
# Multiple source types
filters = IndexFilters(
access_control_list=[],
source_type=[DocumentSource.WEB, DocumentSource.JIRA],
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web" or {SOURCE_TYPE} contains "jira") and '
== result
)
# Empty source type list
filters = IndexFilters(access_control_list=[], source_type=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_tag_filters(self) -> None:
"""Test tag filtering."""
# Single tag
filters = IndexFilters(
access_control_list=[], tags=[Tag(tag_key="color", tag_value="red")]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
== result
)
# Multiple tags
filters = IndexFilters(
access_control_list=[],
tags=[
Tag(tag_key="color", tag_value="red"),
Tag(tag_key="size", tag_value="large"),
],
)
result = build_vespa_filters(filters)
expected = (
f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red" '
f'or {METADATA_LIST} contains "size{INDEX_SEPARATOR}large") and '
)
assert expected == result
# Empty tags list
filters = IndexFilters(access_control_list=[], tags=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_document_sets_filter(self) -> None:
"""Test document sets filtering."""
# Single document set
filters = IndexFilters(access_control_list=[], document_set=["set1"])
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
# Multiple document sets
filters = IndexFilters(access_control_list=[], document_set=["set1", "set2"])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
== result
)
# Empty document sets
filters = IndexFilters(access_control_list=[], document_set=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_file_ids_filter(self) -> None:
"""Test user file IDs filtering."""
# Single user file ID
filters = IndexFilters(access_control_list=[], user_file_ids=[123])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and ({USER_FILE} = 123) and " == result
# Multiple user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[123, 456])
result = build_vespa_filters(filters)
assert (
f"!({HIDDEN}=true) and ({USER_FILE} = 123 or {USER_FILE} = 456) and "
== result
)
# Empty user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_user_folder_ids_filter(self) -> None:
"""Test user folder IDs filtering."""
# Single user folder ID
filters = IndexFilters(access_control_list=[], user_folder_ids=[789])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and ({USER_FOLDER} = 789) and " == result
# Multiple user folder IDs
filters = IndexFilters(access_control_list=[], user_folder_ids=[789, 101])
result = build_vespa_filters(filters)
assert (
f"!({HIDDEN}=true) and ({USER_FOLDER} = 789 or {USER_FOLDER} = 101) and "
== result
)
# Empty user folder IDs
filters = IndexFilters(access_control_list=[], user_folder_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_time_cutoff_filter(self) -> None:
"""Test time cutoff filtering."""
# With cutoff time
cutoff_time = datetime(2023, 1, 1, tzinfo=timezone.utc)
filters = IndexFilters(access_control_list=[], time_cutoff=cutoff_time)
result = build_vespa_filters(filters)
cutoff_secs = int(cutoff_time.timestamp())
assert (
f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {cutoff_secs}) and " == result
)
# No cutoff time
filters = IndexFilters(access_control_list=[], time_cutoff=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# Test untimed logic (when cutoff is old enough)
old_cutoff = datetime.now(timezone.utc) - timedelta(days=100)
filters = IndexFilters(access_control_list=[], time_cutoff=old_cutoff)
result = build_vespa_filters(filters)
old_cutoff_secs = int(old_cutoff.timestamp())
assert (
f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {old_cutoff_secs}) and "
== result
)
def test_combined_filters(self) -> None:
"""Test combining multiple filter types."""
filters = IndexFilters(
access_control_list=["user1", "group1"],
source_type=[DocumentSource.WEB],
tags=[Tag(tag_key="color", tag_value="red")],
document_set=["set1"],
user_file_ids=[123],
user_folder_ids=[789],
time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
)
result = build_vespa_filters(filters)
# Build expected result piece by piece for readability
expected = f"!({HIDDEN}=true) and "
expected += (
'(access_control_list contains "user1" or '
'access_control_list contains "group1") and '
)
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
expected += f'({DOCUMENT_SETS} contains "set1") and '
expected += f"({USER_FILE} = 123) and "
expected += f"({USER_FOLDER} = 789) and "
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
assert expected == result
# With trailing AND removed
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
assert expected[:-5] == result_no_trailing # Remove trailing " and "
def test_empty_or_none_values(self) -> None:
"""Test with empty or None values in filter lists."""
# Empty strings in document set
filters = IndexFilters(
access_control_list=[], document_set=["set1", "", "set2"]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
== result
)
# All empty strings in document set
filters = IndexFilters(access_control_list=[], document_set=["", ""])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result

View File

@@ -42,9 +42,7 @@ import Link from "next/link";
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, PersonaLabel, StarterMessage } from "./interfaces";
import { FullPersona, PersonaLabel, StarterMessage } from "./interfaces";
import {
PersonaUpsertParameters,
createPersona,
@@ -101,6 +99,7 @@ import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import TextView from "@/components/chat/TextView";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import { TabToggle } from "@/components/ui/TabToggle";
import { MAX_CHARACTERS_PERSONA_DESCRIPTION } from "@/lib/constants";
function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === SEARCH_TOOL_ID);
@@ -136,7 +135,7 @@ export function AssistantEditor({
shouldAddAssistantToUserPreferences,
admin,
}: {
existingPersona?: Persona | null;
existingPersona?: FullPersona | null;
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
user: User | null;
@@ -184,8 +183,6 @@ export function AssistantEditor({
}
}, [defaultIconShape]);
const [isIconDropdownOpen, setIsIconDropdownOpen] = useState(false);
const [removePersonaImage, setRemovePersonaImage] = useState(false);
const autoStarterMessageEnabled = useMemo(
@@ -242,15 +239,7 @@ export function AssistantEditor({
enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id);
});
const {
selectedFiles,
selectedFolders,
addSelectedFile,
removeSelectedFile,
addSelectedFolder,
removeSelectedFolder,
clearSelectedItems,
} = useDocumentsContext();
const { selectedFiles, selectedFolders } = useDocumentsContext();
const [showVisibilityWarning, setShowVisibilityWarning] = useState(false);
@@ -470,12 +459,12 @@ export function AssistantEditor({
"Must provide a description for the Assistant"
),
system_prompt: Yup.string().max(
8000,
"Instructions must be less than 8000 characters"
MAX_CHARACTERS_PERSONA_DESCRIPTION,
"Instructions must be less than 5000000 characters"
),
task_prompt: Yup.string().max(
8000,
"Reminders must be less than 8000 characters"
MAX_CHARACTERS_PERSONA_DESCRIPTION,
"Reminders must be less than 5000000 characters"
),
is_public: Yup.boolean().required(),
document_set_ids: Yup.array().of(Yup.number()),

View File

@@ -18,35 +18,37 @@ export interface Prompt {
datetime_aware: boolean;
default_prompt: boolean;
}
export interface Persona {
id: number;
name: string;
search_start_date: Date | null;
owner: MinimalUserSnapshot | null;
is_visible: boolean;
is_public: boolean;
display_priority: number | null;
description: string;
document_sets: DocumentSet[];
prompts: Prompt[];
tools: ToolSnapshot[];
num_chunks?: number;
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
llm_model_provider_override?: string;
llm_model_version_override?: string;
starter_messages: StarterMessage[] | null;
builtin_persona: boolean;
is_default_persona: boolean;
users: MinimalUserSnapshot[];
groups: number[];
is_public: boolean;
is_visible: boolean;
icon_shape?: number;
icon_color?: string;
uploaded_image_id?: string;
labels?: PersonaLabel[];
user_file_ids: number[];
user_folder_ids: number[];
display_priority: number | null;
is_default_persona: boolean;
builtin_persona: boolean;
starter_messages: StarterMessage[] | null;
tools: ToolSnapshot[];
labels?: PersonaLabel[];
owner: MinimalUserSnapshot | null;
users: MinimalUserSnapshot[];
groups: number[];
document_sets: DocumentSet[];
llm_model_provider_override?: string;
llm_model_version_override?: string;
num_chunks?: number;
}
export interface FullPersona extends Persona {
search_start_date: Date | null;
prompts: Prompt[];
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
}
export interface PersonaLabel {

View File

@@ -331,28 +331,3 @@ export function providersContainImageGeneratingSupport(
) {
return providers.some((provider) => provider.provider === "openai");
}
// Default fallback persona for when we must display a persona
// but assistant has access to none
export const defaultPersona: Persona = {
id: 0,
name: "Default Assistant",
description: "A default assistant",
is_visible: true,
is_public: true,
builtin_persona: false,
is_default_persona: true,
users: [],
groups: [],
document_sets: [],
prompts: [],
tools: [],
starter_messages: null,
display_priority: null,
search_start_date: null,
owner: null,
icon_shape: 50910,
icon_color: "#FF6F6F",
user_file_ids: [],
user_folder_ids: [],
};

View File

@@ -487,11 +487,6 @@ export default function EmbeddingForm() {
};
const handleReIndex = async () => {
console.log("handleReIndex");
console.log(selectedProvider);
console.log(advancedEmbeddingDetails);
console.log(rerankingDetails);
console.log(reindexType);
if (!selectedProvider) {
return;
}

View File

@@ -1383,7 +1383,7 @@ export function ChatPage({
regenerationRequest?.parentMessage.messageId ||
lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: liveAssistant?.prompts[0]?.id || 0,
promptId: null,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,

View File

@@ -9,11 +9,6 @@ import { redirect } from "next/navigation";
import { BackendChatSession } from "../../interfaces";
import { SharedChatDisplay } from "./SharedChatDisplay";
import { Persona } from "@/app/admin/assistants/interfaces";
import {
FetchAssistantsResponse,
fetchAssistantsSS,
} from "@/lib/assistants/fetchAssistantsSS";
import { defaultPersona } from "@/app/admin/assistants/lib";
import { constructMiniFiedPersona } from "@/lib/assistantIconUtils";
async function getSharedChat(chatId: string) {

View File

@@ -2,7 +2,7 @@ import { User } from "@/lib/types";
import { FiPlus, FiX } from "react-icons/fi";
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
import { UsersIcon } from "@/components/icons/icons";
import { Button } from "@/components/Button";
import { Button } from "@/components/ui/button";
interface UserEditorProps {
selectedUserIds: string[];

View File

@@ -22,7 +22,7 @@ export const AddMemberForm: React.FC<AddMemberFormProps> = ({
return (
<Modal
className="max-w-xl"
className="max-w-xl overflow-visible"
title="Add New User"
onOutsideClick={() => onClose()}
>

View File

@@ -43,13 +43,15 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({
if (href) {
return (
<Link
// Use a plain anchor instead of a Next.js Link to avoid the default Next.js behavior,
// which caches existing contexts and leads to wonky behavior.
<a
href={href}
target={openInNewTab ? "_blank" : undefined}
rel={openInNewTab ? "noopener noreferrer" : undefined}
>
{content}
</Link>
</a>
);
} else {
return <div onClick={onClick}>{content}</div>;

View File

@@ -167,9 +167,7 @@ export const constructMiniFiedPersona = (
display_priority: 0,
description: "",
document_sets: [],
prompts: [],
tools: [],
search_start_date: null,
owner: null,
starter_messages: null,
builtin_persona: false,

View File

@@ -1,4 +1,4 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { FullPersona, Persona } from "@/app/admin/assistants/interfaces";
import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
@@ -18,7 +18,7 @@ export async function fetchAssistantEditorInfoSS(
documentSets: DocumentSet[];
llmProviders: LLMProviderView[];
user: User | null;
existingPersona: Persona | null;
existingPersona: FullPersona | null;
tools: ToolSnapshot[];
},
null,
@@ -94,7 +94,7 @@ export async function fetchAssistantEditorInfoSS(
}
const existingPersona = personaResponse
? ((await personaResponse.json()) as Persona)
? ((await personaResponse.json()) as FullPersona)
: null;
let error: string | null = null;

View File

@@ -105,3 +105,5 @@ export const ALLOWED_URL_PROTOCOLS = [
"spotify:",
"zoommtg:",
];
export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;

View File

@@ -77,6 +77,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"claude-3-haiku-20240307",
// custom claude names
"claude-3.5-sonnet-v2@20241022",
"claude-3-7-sonnet@20250219",
// claude names with AWS Bedrock Suffix
"claude-3-opus-20240229-v1:0",
"claude-3-sonnet-20240229-v1:0",
@@ -125,12 +126,27 @@ export function checkLLMSupportsImageInput(model: string) {
const modelParts = model.split(/[/.]/);
const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
// Try matching the last part
const lastPartMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
const modelNameParts = modelName.split(/[/.]/);
const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
// lastPart is already lowercased above for tiny performance gain
return modelNameLastPart?.toLowerCase() === lastPart;
});
if (lastPartMatch) {
return true;
}
// If no match found, try getting the text after the first slash
if (model.includes("/")) {
const afterSlash = model.split("/")[1]?.toLowerCase();
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) =>
modelName.toLowerCase().includes(afterSlash)
);
}
return false;
}
export const structureValue = (

View File

@@ -1,201 +0,0 @@
import {
BackendMessage,
LLMRelevanceFilterPacket,
} from "@/app/chat/interfaces";
import {
AnswerPiecePacket,
OnyxDocument,
ErrorMessagePacket,
DocumentInfoPacket,
Quote,
QuotesInfoPacket,
RelevanceChunk,
SearchRequestArgs,
} from "./interfaces";
import { processRawChunkString } from "./streamingUtils";
import { buildFilters, endsWithLetterOrNumber } from "./utils";
export const searchRequestStreamed = async ({
query,
sources,
documentSets,
timeRange,
tags,
persona,
agentic,
updateCurrentAnswer,
updateQuotes,
updateDocs,
updateSuggestedSearchType,
updateSuggestedFlowType,
updateSelectedDocIndices,
updateError,
updateMessageAndThreadId,
finishedSearching,
updateDocumentRelevance,
updateComments,
}: SearchRequestArgs) => {
let answer = "";
let quotes: Quote[] | null = null;
let relevantDocuments: OnyxDocument[] | null = null;
try {
const filters = buildFilters(sources, documentSets, timeRange, tags);
const threadMessage = {
message: query,
sender: null,
role: "user",
};
const response = await fetch("/api/query/stream-answer-with-quote", {
method: "POST",
body: JSON.stringify({
messages: [threadMessage],
persona_id: persona.id,
agentic,
prompt_id: persona.id === 0 ? null : persona.prompts[0]?.id,
retrieval_options: {
run_search: "always",
real_time: true,
filters: filters,
enable_auto_detect_filters: false,
},
evaluation_type: agentic ? "agentic" : "basic",
}),
headers: {
"Content-Type": "application/json",
},
});
const reader = response.body?.getReader();
const decoder = new TextDecoder("utf-8");
let previousPartialChunk: string | null = null;
while (true) {
const rawChunk = await reader?.read();
if (!rawChunk) {
throw new Error("Unable to process chunk");
}
const { done, value } = rawChunk;
if (done) {
break;
}
// Process each chunk as it arrives
const [completedChunks, partialChunk] = processRawChunkString<
| AnswerPiecePacket
| ErrorMessagePacket
| QuotesInfoPacket
| DocumentInfoPacket
| LLMRelevanceFilterPacket
| BackendMessage
| DocumentInfoPacket
| RelevanceChunk
>(decoder.decode(value, { stream: true }), previousPartialChunk);
if (!completedChunks.length && !partialChunk) {
break;
}
previousPartialChunk = partialChunk as string | null;
completedChunks.forEach((chunk) => {
// check for answer piece / end of answer
if (Object.hasOwn(chunk, "relevance_summaries")) {
const relevanceChunk = chunk as RelevanceChunk;
updateDocumentRelevance(relevanceChunk.relevance_summaries);
}
if (Object.hasOwn(chunk, "answer_piece")) {
const answerPiece = (chunk as AnswerPiecePacket).answer_piece;
if (answerPiece !== null) {
answer += (chunk as AnswerPiecePacket).answer_piece;
updateCurrentAnswer(answer);
} else {
// set quotes as non-null to signify that the answer is finished and
// we're now looking for quotes
updateQuotes([]);
if (
answer &&
!answer.endsWith(".") &&
!answer.endsWith("?") &&
!answer.endsWith("!") &&
endsWithLetterOrNumber(answer)
) {
answer += ".";
updateCurrentAnswer(answer);
}
}
return;
}
if (Object.hasOwn(chunk, "error")) {
updateError((chunk as ErrorMessagePacket).error);
return;
}
// These all come together
if (Object.hasOwn(chunk, "top_documents")) {
chunk = chunk as DocumentInfoPacket;
const topDocuments = chunk.top_documents as OnyxDocument[] | null;
if (topDocuments) {
relevantDocuments = topDocuments;
updateDocs(relevantDocuments);
}
if (chunk.predicted_flow) {
updateSuggestedFlowType(chunk.predicted_flow);
}
if (chunk.predicted_search) {
updateSuggestedSearchType(chunk.predicted_search);
}
return;
}
if (Object.hasOwn(chunk, "relevant_chunk_indices")) {
const relevantChunkIndices = (chunk as LLMRelevanceFilterPacket)
.relevant_chunk_indices;
if (relevantChunkIndices) {
updateSelectedDocIndices(relevantChunkIndices);
}
return;
}
// Check for quote section
if (Object.hasOwn(chunk, "quotes")) {
quotes = (chunk as QuotesInfoPacket).quotes;
updateQuotes(quotes);
return;
}
// Check for the final chunk
if (Object.hasOwn(chunk, "message_id")) {
const backendChunk = chunk as BackendMessage;
updateComments(backendChunk.comments);
updateMessageAndThreadId(
backendChunk.message_id,
backendChunk.chat_session_id
);
}
});
}
} catch (err) {
console.error("Fetch error:", err);
let errorMessage = "An error occurred while fetching the answer.";
if (err instanceof Error) {
if (err.message.includes("rate_limit_error")) {
errorMessage =
"Rate limit exceeded. Please try again later or reduce the length of your query.";
} else {
errorMessage = err.message;
}
}
updateError(errorMessage);
}
return { answer, quotes, relevantDocuments };
};