Compare commits

15 Commits

Author SHA1 Message Date
pablonyx
59c454debe k 2025-04-03 11:45:38 -07:00
pablonyx
93886f0e2c Assistant Prompt length + client side (#4433) 2025-04-03 11:26:53 -07:00
rkuo-danswer
8c3a953b7a add prometheus metrics endpoints via helper package (#4436)
* add prometheus metrics endpoints via helper package

* model server specific requirements

* mark as public endpoint

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-03 16:52:05 +00:00
evan-danswer
54b883d0ca fix large docs selected in chat pruning (#4412)
* fix large docs selected in chat pruning

* better approach to length restriction

* comments

* comments

* fix unit tests and minor pruning bug

* remove prints
2025-04-03 15:48:10 +00:00
pablonyx
91faac5447 minor fix (#4435) 2025-04-03 15:00:27 +00:00
Chris Weaver
1d8f9fc39d Fix weird re-index state (#4439)
* Fix weird re-index state

* Address rkuo's comments
2025-04-03 02:16:34 +00:00
Weves
9390de21e5 More logging on confluence space permissions 2025-04-02 20:01:38 -07:00
rkuo-danswer
3a33433fc9 unit tests for chunk censoring (#4434)
* unit tests for chunk censoring

* type hints for mypy

* pytestification

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-03 01:28:54 +00:00
Chris Weaver
c4865d57b1 Fix tons of users w/o drive access causing timeouts (#4437) 2025-04-03 00:01:05 +00:00
rkuo-danswer
81d04db08f Feature/request id middleware 2 (#4427)
* stubbing out request id

* passthru or create request id's in api and model server

* add onyx request id

* get request id logging into uvicorn

* no logs

* change prefixes

* fix comment

* docker image needs specific shared files

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 22:30:03 +00:00
rkuo-danswer
d50a17db21 add filter unit tests (#4421)
* add filter unit tests

* fix tests

---------

Co-authored-by: Richard Kuo (Onyx) <rkuo@onyx.app>
2025-04-02 20:26:25 +00:00
pablonyx
dc5a1e8fd0 add more flexible vision support check (#4429) 2025-04-02 18:11:33 +00:00
pablonyx
c0b3681650 update (#4428) 2025-04-02 18:09:44 +00:00
Chris Weaver
7ec04484d4 Another fix for Salesforce perm sync (#4432)
* Another fix for Salesforce perm sync

* typing
2025-04-02 11:08:40 -07:00
Weves
1cf966ecc1 Fix Salesforce perm sync 2025-04-02 10:47:26 -07:00
51 changed files with 1141 additions and 450 deletions

View File

@@ -46,6 +46,7 @@ WORKDIR /app
# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py
# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

View File

@@ -0,0 +1,50 @@
"""update prompt length
Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=5000000),
existing_nullable=False,
)
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=5000000),
type_=sa.TEXT(),
existing_nullable=False,
)
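
For reference, a quick post-migration sanity check (a sketch only; the DSN and the use of SQLAlchemy's inspector are assumptions, not part of this diff):

from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql://localhost/onyx")  # hypothetical DSN
for col in inspect(engine).get_columns("prompt"):
    if col["name"] in ("system_prompt", "task_prompt"):
        # expect VARCHAR(5000000) after upgrade, TEXT after downgrade
        print(col["name"], col["type"])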

View File

@@ -5,8 +5,6 @@ Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
@@ -17,34 +15,36 @@ depends_on = None
def upgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.TEXT(),
type_=sa.String(length=8000),
existing_nullable=False,
)
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.TEXT(),
# type_=sa.String(length=8000),
# existing_nullable=False,
# )
pass
def downgrade() -> None:
op.alter_column(
"prompt",
"system_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)
op.alter_column(
"prompt",
"task_prompt",
existing_type=sa.String(length=8000),
type_=sa.TEXT(),
existing_nullable=False,
)
# op.alter_column(
# "prompt",
# "system_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
# op.alter_column(
# "prompt",
# "task_prompt",
# existing_type=sa.String(length=8000),
# type_=sa.TEXT(),
# existing_nullable=False,
# )
pass

View File

@@ -159,6 +159,9 @@ def _get_space_permissions(
# Stores the permissions for each space
space_permissions_by_space_key[space_key] = space_permissions
logger.info(
f"Found space permissions for space '{space_key}': {space_permissions}"
)
return space_permissions_by_space_key

View File

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
# if user is None, permissions are not enforced
return chunks
chunks_to_keep = []
final_chunk_dict: dict[str, InferenceChunk] = {}
chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}
sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
if chunk.source_type in sources_to_censor:
chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
else:
chunks_to_keep.append(chunk)
final_chunk_dict[chunk.unique_id] = chunk
# For each source, filter out the chunks using the permission
# check function for that source
@@ -79,6 +79,16 @@ def _post_query_chunk_censoring(
f" chunks for this source and continuing: {e}"
)
continue
chunks_to_keep.extend(censored_chunks)
return chunks_to_keep
for censored_chunk in censored_chunks:
final_chunk_dict[censored_chunk.unique_id] = censored_chunk
# IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
final_chunk_list: list[InferenceChunk] = []
for chunk in chunks:
# only if the chunk is in the final censored chunks, add it to the final list
# if it is missing, that means it was intentionally left out
if chunk.unique_id in final_chunk_dict:
final_chunk_list.append(final_chunk_dict[chunk.unique_id])
return final_chunk_list
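
The ordering guarantee called out in the comment above can be shown in isolation. A minimal sketch, with plain strings standing in for InferenceChunk objects keyed by unique_id:

chunks = ["a", "b", "c", "d"]                  # original retrieval order
survivors = {"c", "a"}                         # censoring may return any order
final = [c for c in chunks if c in survivors]
print(final)  # ['a', 'c']: original order retained, censored ids dropped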

View File

@@ -42,11 +42,18 @@ def get_any_salesforce_client_for_doc_id(
def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) == 0:
return None
return result["records"][0]["Id"]
if len(result["records"]) > 0:
return result["records"][0]["Id"]
# try emails
query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
result = sf_client.query(query)
if len(result["records"]) > 0:
return result["records"][0]["Id"]
return None
# This contains only the user_ids that we have found in Salesforce.

View File

@@ -1,3 +1,4 @@
import logging
import os
import shutil
from collections.abc import AsyncGenerator
@@ -8,6 +9,7 @@ import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
@@ -20,6 +22,8 @@ from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -36,6 +40,12 @@ transformer_logging.set_verbosity_error()
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
"""
@@ -112,6 +122,15 @@ def get_model_app() -> FastAPI:
application.include_router(encoders_router)
application.include_router(custom_models_router)
request_id_prefix = "INF"
if INDEXING_ONLY:
request_id_prefix = "IDX"
add_onyx_request_id_middleware(application, request_id_prefix, logger)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application
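
A minimal sketch of the instrumentation pattern introduced here, as a standalone app (names are illustrative): after expose(), the app serves Prometheus-format metrics at /metrics.

from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator

app = FastAPI()

@app.get("/ping")
def ping() -> dict[str, str]:
    return {"status": "ok"}

# instrument request counts/latencies and expose them on GET /metrics
Instrumentator().instrument(app).expose(app)

Hitting /ping under uvicorn and then fetching /metrics should show the request counters and latency histograms; exposing /metrics is also why that path is added to PUBLIC_ENDPOINT_SPECS later in this diff.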

View File

@@ -15,6 +15,22 @@ class ExternalAccess:
# Whether the document is public in the external system or Onyx
is_public: bool
def __str__(self) -> str:
"""Prevent extremely long logs"""
def truncate_set(s: set[str], max_len: int = 100) -> str:
s_str = str(s)
if len(s_str) > max_len:
return f"{s_str[:max_len]}... ({len(s)} items)"
return s_str
return (
f"ExternalAccess("
f"external_user_emails={truncate_set(self.external_user_emails)}, "
f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
f"is_public={self.is_public})"
)
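
The truncation keeps log lines bounded even for very large permission sets. A standalone sketch of the same idea:

emails = {f"user{i}@example.com" for i in range(10_000)}
s = str(emails)
print(f"{s[:100]}... ({len(emails)} items)" if len(s) > 100 else s)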
@dataclass(frozen=True)
class DocExternalAccess:

View File

@@ -43,6 +43,7 @@ from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_me
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -692,8 +693,13 @@ def stream_chat_message_objects(
doc_identifiers=identifier_tuples,
document_index=document_index,
)
# Add a maximum context size in the case of user-selected docs to prevent
# slight inaccuracies in context window size pruning from causing
# the entire query to fail
document_pruning_config = DocumentPruningConfig(
is_manually_selected_docs=True
is_manually_selected_docs=True,
max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
)
# In case the search doc is deleted, just don't include it

View File

@@ -312,11 +312,14 @@ def prune_sections(
)
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
assert (
len(set([chunk.document_id for chunk in chunks])) == 1
), "One distinct document must be passed into merge_doc_chunks"
ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"
# Assuming there are no duplicates by this point
sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)
@@ -324,33 +327,48 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
chunks, key=lambda x: x.score if x.score is not None else float("-inf")
)
added_chars = 0
merged_content = []
for i, chunk in enumerate(sorted_chunks):
if i > 0:
prev_chunk_id = sorted_chunks[i - 1].chunk_id
if chunk.chunk_id == prev_chunk_id + 1:
merged_content.append("\n")
else:
merged_content.append("\n\n...\n\n")
sep = (
ADJACENT_CHUNK_SEP
if chunk.chunk_id == prev_chunk_id + 1
else DISTANT_CHUNK_SEP
)
merged_content.append(sep)
added_chars += len(sep)
merged_content.append(chunk.content)
combined_content = "".join(merged_content)
return InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
return (
InferenceSection(
center_chunk=center_chunk,
chunks=sorted_chunks,
combined_content=combined_content,
),
added_chars,
)
def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
doc_order: dict[str, int] = {}
combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)
# chunk de-duping and doc ordering
for index, section in enumerate(sections):
if section.center_chunk.document_id not in doc_order:
doc_order[section.center_chunk.document_id] = index
combined_section_lengths[section.center_chunk.document_id] += len(
section.combined_content
)
chunks_map = docs_map[section.center_chunk.document_id]
for chunk in [section.center_chunk] + section.chunks:
chunks_map = docs_map[section.center_chunk.document_id]
existing_chunk = chunks_map.get(chunk.chunk_id)
if (
existing_chunk is None
@@ -361,8 +379,22 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
chunks_map[chunk.chunk_id] = chunk
new_sections = []
for section_chunks in docs_map.values():
new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))
for doc_id, section_chunks in docs_map.items():
section_chunks_list = list(section_chunks.values())
merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)
previous_length = combined_section_lengths[doc_id] + added_chars
# After merging, ensure the content respects the pruning done earlier. Each
# combined section is restricted to the sum of the lengths of the sections
# from the pruning step. Technically the correct approach would be to prune based
# on tokens AGAIN, but this is a good approximation and worth not adding the
# tokenization overhead. This could also be fixed if we added a way of removing
# chunks from sections in the pruning step; at the moment this issue largely
# exists because we only trim the final section's combined_content.
merged_section.combined_content = merged_section.combined_content[
:previous_length
]
new_sections.append(merged_section)
# Sort by highest score, then by original document order
# It is now 1 large section per doc, the center chunk being the one with the highest score
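
The separator accounting in _merge_doc_chunks can be illustrated with toy chunks (ids and contents here are made up): adjacent chunk ids join with "\n", gaps join with "\n\n...\n\n", and added_chars counts only the separator characters so the later length cap can exclude them.

ADJACENT_CHUNK_SEP = "\n"
DISTANT_CHUNK_SEP = "\n\n...\n\n"

toy_chunks = [(1, "A"), (2, "B"), (5, "C")]  # (chunk_id, content), pre-sorted
parts: list[str] = []
added_chars = 0
for i, (chunk_id, content) in enumerate(toy_chunks):
    if i > 0:
        prev_id = toy_chunks[i - 1][0]
        sep = ADJACENT_CHUNK_SEP if chunk_id == prev_id + 1 else DISTANT_CHUNK_SEP
        parts.append(sep)
        added_chars += len(sep)
    parts.append(content)
print(repr("".join(parts)), added_chars)  # 'A\nB\n\n...\n\nC' 8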

View File

@@ -16,6 +16,9 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072
# Maximum percentage of the context window to fill with selected sections
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8
# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(

View File

@@ -445,6 +445,9 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
logger.warning(
f"User '{user_email}' does not have access to the drive APIs."
)
# mark this user as done so we don't try to retrieve anything for them
# again
curr_stage.stage = DriveRetrievalStage.DONE
return
raise
@@ -581,6 +584,25 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
drive_ids_to_retrieve, checkpoint
)
# only process emails that we haven't already completed retrieval for
non_completed_org_emails = [
user_email
for user_email, stage in checkpoint.completion_map.items()
if stage != DriveRetrievalStage.DONE
]
# don't process too many emails before returning a checkpoint. This is
# to resolve the case where there are a ton of emails that don't have access
# to the drive APIs. Without this, we could loop through these emails for
# more than 3 hours, causing a timeout and stalling progress.
email_batch_takes_us_to_completion = True
MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
non_completed_org_emails = non_completed_org_emails[
:MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
]
email_batch_takes_us_to_completion = False
user_retrieval_gens = [
self._impersonate_user_for_retrieval(
email,
@@ -591,10 +613,14 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
start,
end,
)
for email in all_org_emails
for email in non_completed_org_emails
]
yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)
# if there are more emails to process, don't mark as complete
if not email_batch_takes_us_to_completion:
return
remaining_folders = (
drive_ids_to_retrieve | folder_ids_to_retrieve
) - self._retrieved_ids
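
The email batching above bounds how long a single run can spend on users without drive access. A reduced sketch of the checkpointing decision (toy completion map, with stages simplified to strings):

completion_map = {
    f"user{i}@example.com": "DONE" if i < 80 else "START" for i in range(200)
}
MAX_EMAILS_PER_RUN = 50  # mirrors MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
pending = [email for email, stage in completion_map.items() if stage != "DONE"]
batch_completes_retrieval = len(pending) <= MAX_EMAILS_PER_RUN
batch = pending[:MAX_EMAILS_PER_RUN]
print(len(batch), batch_completes_retrieval)  # 50 False: checkpoint, resume later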

View File

@@ -227,13 +227,16 @@ class SearchPipeline:
# If ee is enabled, censor the chunk sections based on user access
# Otherwise, return the retrieved chunks
censored_chunks = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
censored_chunks = cast(
list[InferenceChunk],
fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring",
"_post_query_chunk_censoring",
retrieved_chunks,
)(
chunks=retrieved_chunks,
user=self.user,
),
)
above = self.search_query.chunks_above

View File

@@ -613,8 +613,19 @@ def fetch_connector_credential_pairs(
def resync_cc_pair(
cc_pair: ConnectorCredentialPair,
search_settings_id: int,
db_session: Session,
) -> None:
"""
Updates state stored in the connector_credential_pair table based on the
latest index attempt for the given search settings.
Args:
cc_pair: ConnectorCredentialPair to resync
search_settings_id: SearchSettings to use for resync
db_session: Database session
"""
def find_latest_index_attempt(
connector_id: int,
credential_id: int,
@@ -627,11 +638,10 @@ def resync_cc_pair(
ConnectorCredentialPair,
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
)
.join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
.filter(
ConnectorCredentialPair.connector_id == connector_id,
ConnectorCredentialPair.credential_id == credential_id,
SearchSettings.status == IndexModelStatus.PRESENT,
IndexAttempt.search_settings_id == search_settings_id,
)
)

View File

@@ -43,6 +43,8 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
ONE_HOUR_IN_SECONDS = 60 * 60
def check_docs_exist(db_session: Session) -> bool:
stmt = select(exists(DbDocument))
@@ -607,6 +609,46 @@ def delete_documents_complete__no_commit(
delete_documents__no_commit(db_session, document_ids)
def delete_all_documents_for_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
timeout: int = ONE_HOUR_IN_SECONDS,
) -> None:
"""Delete all documents for a given connector credential pair.
This will delete all documents and their associated data (chunks, feedback, tags, etc.)
NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during
an index swap. If we wanted to make this more efficient, we could use a single delete
statement + cascade.
"""
batch_size = 1000
start_time = time.monotonic()
while True:
# Get document IDs in batches
stmt = (
select(DocumentByConnectorCredentialPair.id)
.where(
DocumentByConnectorCredentialPair.connector_id == connector_id,
DocumentByConnectorCredentialPair.credential_id == credential_id,
)
.limit(batch_size)
)
document_ids = db_session.scalars(stmt).all()
if not document_ids:
break
delete_documents_complete__no_commit(
db_session=db_session, document_ids=list(document_ids)
)
db_session.commit()
if time.monotonic() - start_time > timeout:
raise RuntimeError("Timeout reached while deleting documents")
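
The new helper follows a generic batch-and-timeout shape. A sketch of that pattern with the ORM details stripped out (the callables stand in for the select and delete statements above):

import time
from collections.abc import Callable

def delete_in_batches(
    fetch_batch: Callable[[], list[str]],
    delete_batch: Callable[[list[str]], None],
    timeout_s: float = 3600.0,
) -> None:
    start = time.monotonic()
    while batch := fetch_batch():
        delete_batch(batch)
        if time.monotonic() - start > timeout_s:
            raise RuntimeError("Timeout reached while deleting documents")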
def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
"""Acquire locks for the specified documents. Ideally this shouldn't be
called with large list of document_ids (an exception could be made if the

View File

@@ -710,6 +710,25 @@ def cancel_indexing_attempts_past_model(
)
def cancel_indexing_attempts_for_search_settings(
search_settings_id: int,
db_session: Session,
) -> None:
"""Stops all indexing attempts that are in progress or not started for
the specified search settings."""
db_session.execute(
update(IndexAttempt)
.where(
IndexAttempt.status.in_(
[IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
),
IndexAttempt.search_settings_id == search_settings_id,
)
.values(status=IndexingStatus.FAILED)
)
def count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id: int | None,
db_session: Session,

View File

@@ -37,8 +37,8 @@ from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -201,7 +201,7 @@ def create_update_persona(
create_persona_request: PersonaUpsertRequest,
user: User | None,
db_session: Session,
) -> PersonaSnapshot:
) -> FullPersonaSnapshot:
"""Higher level function than upsert_persona, although either is valid to use."""
# Permission to actually use these is checked later
@@ -271,7 +271,7 @@ def create_update_persona(
logger.exception("Failed to create persona")
raise HTTPException(status_code=400, detail=str(e))
return PersonaSnapshot.from_model(persona)
return FullPersonaSnapshot.from_model(persona)
def update_persona_shared_users(

View File

@@ -3,8 +3,9 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.document import delete_all_documents_for_connector_credential_pair
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
@@ -26,31 +27,49 @@ def _perform_index_swap(
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
cleanup_documents: bool = False,
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
cancel_indexing_attempts_for_search_settings(
search_settings_id=current_search_settings.id,
db_session=db_session,
)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
# sync based on the new search settings
search_settings_id=secondary_search_settings.id,
db_session=db_session,
)
if cleanup_documents:
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap and no documents will exist in the new index.
for cc_pair in all_cc_pairs:
delete_all_documents_for_connector_credential_pair(
db_session=db_session,
connector_id=cc_pair.connector_id,
credential_id=cc_pair.credential_id,
)
# swap over search settings
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
@@ -88,6 +107,9 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
# clean up all DocumentByConnectorCredentialPair / Document rows, since we're
# doing an instant swap.
cleanup_documents=True,
)
return current_search_settings

View File

@@ -1,3 +1,4 @@
import logging
import sys
import traceback
from collections.abc import AsyncGenerator
@@ -16,6 +17,7 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
@@ -102,6 +104,8 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -116,6 +120,12 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
file_handlers = [
h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]
setup_uvicorn_logger(shared_file_handlers=file_handlers)
def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
if not isinstance(exc, RequestValidationError):
@@ -421,9 +431,14 @@ def get_application() -> FastAPI:
if LOG_ENDPOINT_LATENCY:
add_latency_logging_middleware(application, logger)
add_onyx_request_id_middleware(application, "API", logger)
# Ensure all routes have auth enabled or are explicitly marked as public
check_router_auth(application)
# Initialize and instrument the app
Instrumentator().instrument(application).expose(application)
return application

View File

@@ -49,6 +49,7 @@ PUBLIC_ENDPOINT_SPECS = [
("/auth/oauth/callback", {"GET"}),
# anonymous user on cloud
("/tenants/anonymous-user", {"POST"}),
("/metrics", {"GET"}), # added by prometheus_fastapi_instrumentator
]

View File

@@ -43,6 +43,7 @@ from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
@@ -424,8 +425,8 @@ def get_persona(
persona_id: int,
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(
) -> FullPersonaSnapshot:
return FullPersonaSnapshot.from_model(
get_persona_by_id(
persona_id=persona_id,
user=user,

View File

@@ -91,37 +91,80 @@ class PersonaUpsertRequest(BaseModel):
class PersonaSnapshot(BaseModel):
id: int
owner: MinimalUserSnapshot | None
name: str
is_visible: bool
is_public: bool
display_priority: int | None
description: str
num_chunks: float | None
llm_relevance_filter: bool
llm_filter_extraction: bool
llm_model_provider_override: str | None
llm_model_version_override: str | None
starter_messages: list[StarterMessage] | None
builtin_persona: bool
prompts: list[PromptSnapshot]
tools: list[ToolSnapshot]
document_sets: list[DocumentSet]
users: list[MinimalUserSnapshot]
groups: list[int]
icon_color: str | None
icon_shape: int | None
is_public: bool
is_visible: bool
icon_shape: int | None = None
icon_color: str | None = None
uploaded_image_id: str | None = None
is_default_persona: bool
user_file_ids: list[int] = Field(default_factory=list)
user_folder_ids: list[int] = Field(default_factory=list)
display_priority: int | None = None
is_default_persona: bool = False
builtin_persona: bool = False
starter_messages: list[StarterMessage] | None = None
tools: list[ToolSnapshot] = Field(default_factory=list)
labels: list["PersonaLabelSnapshot"] = Field(default_factory=list)
owner: MinimalUserSnapshot | None = None
users: list[MinimalUserSnapshot] = Field(default_factory=list)
groups: list[int] = Field(default_factory=list)
document_sets: list[DocumentSet] = Field(default_factory=list)
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
num_chunks: float | None = None
@classmethod
def from_model(cls, persona: Persona) -> "PersonaSnapshot":
return PersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks,
)
# Model with full context on persona's internal settings
# This is used for flows which need to know all settings
class FullPersonaSnapshot(PersonaSnapshot):
search_start_date: datetime | None = None
labels: list["PersonaLabelSnapshot"] = []
user_file_ids: list[int] | None = None
user_folder_ids: list[int] | None = None
prompts: list[PromptSnapshot] = Field(default_factory=list)
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
@classmethod
def from_model(
cls, persona: Persona, allow_deleted: bool = False
) -> "PersonaSnapshot":
) -> "FullPersonaSnapshot":
if persona.deleted:
error_msg = f"Persona with ID {persona.id} has been deleted"
if not allow_deleted:
@@ -129,44 +172,32 @@ class PersonaSnapshot(BaseModel):
else:
logger.warning(error_msg)
return PersonaSnapshot(
return FullPersonaSnapshot(
id=persona.id,
name=persona.name,
description=persona.description,
is_public=persona.is_public,
is_visible=persona.is_visible,
icon_shape=persona.icon_shape,
icon_color=persona.icon_color,
uploaded_image_id=persona.uploaded_image_id,
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
if persona.user
else None
),
is_visible=persona.is_visible,
is_public=persona.is_public,
display_priority=persona.display_priority,
description=persona.description,
num_chunks=persona.num_chunks,
search_start_date=persona.search_start_date,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
starter_messages=persona.starter_messages,
builtin_persona=persona.builtin_persona,
is_default_persona=persona.is_default_persona,
prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
document_sets=[
DocumentSet.from_model(document_set_model)
for document_set_model in persona.document_sets
],
users=[
MinimalUserSnapshot(id=user.id, email=user.email)
for user in persona.users
],
groups=[user_group.id for user_group in persona.groups],
icon_color=persona.icon_color,
icon_shape=persona.icon_shape,
uploaded_image_id=persona.uploaded_image_id,
search_start_date=persona.search_start_date,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
user_file_ids=[file.id for file in persona.user_files],
user_folder_ids=[folder.id for folder in persona.user_folders],
)

View File

@@ -19,6 +19,7 @@ from onyx.db.models import SlackBot as SlackAppModel
from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
from onyx.db.models import User
from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.models import FullUserSnapshot
from onyx.server.models import InvitedUserSnapshot
@@ -245,7 +246,7 @@ class SlackChannelConfig(BaseModel):
id=slack_channel_config_model.id,
slack_bot_id=slack_channel_config_model.slack_bot_id,
persona=(
PersonaSnapshot.from_model(
FullPersonaSnapshot.from_model(
slack_channel_config_model.persona, allow_deleted=True
)
if slack_channel_config_model.persona

View File

@@ -117,7 +117,11 @@ def set_new_search_settings(
search_settings_id=search_settings.id, db_session=db_session
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=new_search_settings.id,
db_session=db_session,
)
db_session.commit()
return IdReturn(id=new_search_settings.id)

View File

@@ -96,7 +96,11 @@ def setup_onyx(
)
for cc_pair in get_connector_credential_pairs(db_session):
resync_cc_pair(cc_pair, db_session=db_session)
resync_cc_pair(
cc_pair=cc_pair,
search_settings_id=search_settings.id,
db_session=db_session,
)
# Expire all old embedding models indexing attempts, technically redundant
cancel_indexing_attempts_past_model(db_session)

View File

@@ -13,6 +13,7 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SLACK_CHANNEL_ID
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
logging.addLevelName(logging.INFO + 5, "NOTICE")
@@ -71,6 +72,14 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))
class OnyxRequestIDFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
record.request_id = ONYX_REQUEST_ID_CONTEXTVAR.get() or "-"
return True
class OnyxLoggingAdapter(logging.LoggerAdapter):
def process(
self, msg: str, kwargs: MutableMapping[str, Any]
@@ -103,6 +112,7 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:
@@ -115,6 +125,11 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
)
msg = f"[t:{short_tenant}] {msg}"
# request id within a fastapi route
fastapi_request_id = ONYX_REQUEST_ID_CONTEXTVAR.get()
if fastapi_request_id:
msg = f"[{fastapi_request_id}] {msg}"
# For Slack Bot, logs the channel relevant to the request
channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
if channel_id:
@@ -165,6 +180,14 @@ class ColoredFormatter(logging.Formatter):
return super().format(record)
def get_uvicorn_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
"%(asctime)s %(filename)30s %(lineno)4s: [%(request_id)s] %(message)s",
datefmt="%m/%d/%Y %I:%M:%S %p",
)
def get_standard_formatter() -> ColoredFormatter:
"""Returns a standard colored logging formatter."""
return ColoredFormatter(
@@ -201,12 +224,6 @@ def setup_logger(
logger.addHandler(handler)
uvicorn_logger = logging.getLogger("uvicorn.access")
if uvicorn_logger:
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
is_containerized = is_running_in_container()
if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
log_levels = ["debug", "info", "notice"]
@@ -225,14 +242,37 @@ def setup_logger(
file_handler.setFormatter(formatter)
logger.addHandler(file_handler)
if uvicorn_logger:
uvicorn_logger.addHandler(file_handler)
logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs) # type: ignore
return OnyxLoggingAdapter(logger, extra=extra)
def setup_uvicorn_logger(
log_level: int = get_log_level_from_str(),
shared_file_handlers: list[logging.FileHandler] | None = None,
) -> None:
uvicorn_logger = logging.getLogger("uvicorn.access")
if not uvicorn_logger:
return
formatter = get_uvicorn_standard_formatter()
handler = logging.StreamHandler()
handler.setLevel(log_level)
handler.setFormatter(formatter)
uvicorn_logger.handlers = []
uvicorn_logger.addHandler(handler)
uvicorn_logger.setLevel(log_level)
uvicorn_logger.addFilter(OnyxRequestIDFilter())
if shared_file_handlers:
for fh in shared_file_handlers:
uvicorn_logger.addHandler(fh)
return
def print_loggers() -> None:
"""Print information about all loggers. Use to debug logging issues."""
root_logger = logging.getLogger()

View File

@@ -0,0 +1,62 @@
import base64
import hashlib
import logging
import uuid
from collections.abc import Awaitable
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from fastapi import FastAPI
from fastapi import Request
from fastapi import Response
from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
def add_onyx_request_id_middleware(
app: FastAPI, prefix: str, logger: logging.LoggerAdapter
) -> None:
@app.middleware("http")
async def set_request_id(
request: Request, call_next: Callable[[Request], Awaitable[Response]]
) -> Response:
"""Generate a request hash that can be used to track the lifecycle
of a request. The hash is prefixed to help indicate where the request id
originated.
Format is f"{PREFIX}:{ID}" where PREFIX is 3 chars and ID is 8 chars.
Total length is 12 chars.
"""
onyx_request_id = request.headers.get("X-Onyx-Request-ID")
if not onyx_request_id:
onyx_request_id = make_randomized_onyx_request_id(prefix)
ONYX_REQUEST_ID_CONTEXTVAR.set(onyx_request_id)
return await call_next(request)
def make_randomized_onyx_request_id(prefix: str) -> str:
"""generates a randomized request id"""
hash_input = str(uuid.uuid4())
return _make_onyx_request_id(prefix, hash_input)
def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
"""Not used yet, but could be in the future!"""
hash_input = f"{request_url}:{datetime.now(timezone.utc)}"
return _make_onyx_request_id(prefix, hash_input)
def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
"""helper function to return an id given a string input"""
hash_obj = hashlib.md5(hash_input.encode("utf-8"))
hash_bytes = hash_obj.digest()[:6] # Truncate to 6 bytes
# 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
# NOTE: possible we'll want more input bytes if id's aren't unique enough
hash_str = base64.urlsafe_b64encode(hash_bytes).decode("utf-8").rstrip("=")
onyx_request_id = f"{prefix}:{hash_str}"
return onyx_request_id
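
The 12-character format documented in the docstring follows from the construction: 6 digest bytes base64-encode to exactly 8 urlsafe characters with no padding. A standalone sketch (the "API" prefix matches the API server; "INF"/"IDX" are used on the model server per this diff):

import base64
import hashlib
import uuid

prefix = "API"
digest = hashlib.md5(str(uuid.uuid4()).encode("utf-8")).digest()[:6]
request_id = f"{prefix}:{base64.urlsafe_b64encode(digest).decode('utf-8').rstrip('=')}"
print(request_id, len(request_id))  # e.g. API:9vJqL3xK 12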

View File

@@ -332,14 +332,15 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
return task.result
def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
return ind, next(g, None)
def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
return ind, next(gen, None)
def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
with ThreadPoolExecutor(max_workers=max_workers) as executor:
future_to_index: dict[Future[tuple[int, R | None]], int] = {
executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
executor.submit(_next_or_none, ind, gen): ind
for ind, gen in enumerate(gens)
}
next_ind = len(gens)
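
parallel_yield interleaves values from several generators as worker threads produce them. For intuition only, a much-simplified single-threaded round-robin with the same interleaving effect (not the concurrent implementation above):

from collections.abc import Iterator
from itertools import zip_longest

def round_robin(gens: list[Iterator[str]]) -> Iterator[str]:
    for group in zip_longest(*gens):
        yield from (item for item in group if item is not None)

print(list(round_robin([iter("ab"), iter("cd")])))  # ['a', 'c', 'b', 'd']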

View File

@@ -95,4 +95,5 @@ urllib3==2.2.3
mistune==0.8.4
sentry-sdk==2.14.0
prometheus_client==0.21.0
fastapi-limiter==0.1.6
fastapi-limiter==0.1.6
prometheus_fastapi_instrumentator==7.1.0

View File

@@ -15,4 +15,5 @@ uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16
sentry-sdk[fastapi,celery,starlette]==2.14.0
aioboto3==13.4.0
aioboto3==13.4.0
prometheus_fastapi_instrumentator==7.1.0

View File

@@ -58,6 +58,7 @@ INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"
# The process needs to have this for the log file to write to
# otherwise, it will not create additional log files
# This should just be the filename base without extension or path.
LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"
# Enable generating persistent log files for local dev environments

View File

@@ -11,6 +11,15 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
"current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
)
# set by every route in the API server
INDEXING_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[
str | None
] = contextvars.ContextVar("indexing_request_id", default=None)
# set by every route in the API server
ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
"onyx_request_id", default=None
)
"""Utils related to contextvars"""

View File

@@ -34,7 +34,7 @@ def confluence_connector(space: str) -> ConfluenceConnector:
return connector
@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
@patch(
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,

View File

@@ -4,7 +4,7 @@ from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -181,7 +181,7 @@ class PersonaManager:
@staticmethod
def get_all(
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/admin/persona",
headers=user_performing_action.headers
@@ -189,13 +189,13 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**persona) for persona in response.json()]
return [FullPersonaSnapshot(**persona) for persona in response.json()]
@staticmethod
def get_one(
persona_id: int,
user_performing_action: DATestUser | None = None,
) -> list[PersonaSnapshot]:
) -> list[FullPersonaSnapshot]:
response = requests.get(
f"{API_SERVER_URL}/persona/{persona_id}",
headers=user_performing_action.headers
@@ -203,7 +203,7 @@ class PersonaManager:
else GENERAL_HEADERS,
)
response.raise_for_status()
return [PersonaSnapshot(**response.json())]
return [FullPersonaSnapshot(**response.json())]
@staticmethod
def verify(

View File

@@ -4,6 +4,7 @@ from onyx.chat.prune_and_merge import _merge_sections
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.utils import inference_section_from_chunks
# This large test accounts for all of the following:
@@ -111,7 +112,7 @@ Content 17
# Sections
[
# Document 1, top/middle/bot connected + disconnected section
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[
DOC_1_FILLER_1,
@@ -120,9 +121,8 @@ Content 17
DOC_1_MID_CHUNK,
DOC_1_FILLER_3,
],
combined_content="N/A", # Not used
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[
DOC_1_FILLER_2,
@@ -131,9 +131,8 @@ Content 17
DOC_1_FILLER_3,
DOC_1_FILLER_4,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_BOTTOM_CHUNK,
chunks=[
DOC_1_FILLER_3,
@@ -142,9 +141,8 @@ Content 17
DOC_1_FILLER_5,
DOC_1_FILLER_6,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_DISCONNECTED,
chunks=[
DOC_1_FILLER_7,
@@ -153,9 +151,8 @@ Content 17
DOC_1_FILLER_9,
DOC_1_FILLER_10,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_2_TOP_CHUNK,
chunks=[
DOC_2_FILLER_1,
@@ -164,9 +161,8 @@ Content 17
DOC_2_FILLER_3,
DOC_2_BOTTOM_CHUNK,
],
combined_content="N/A",
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_2_BOTTOM_CHUNK,
chunks=[
DOC_2_TOP_CHUNK,
@@ -175,7 +171,6 @@ Content 17
DOC_2_FILLER_4,
DOC_2_FILLER_5,
],
combined_content="N/A",
),
],
# Expected Content
@@ -204,15 +199,13 @@ def test_merge_sections(
(
# Sections
[
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_TOP_CHUNK,
chunks=[DOC_1_TOP_CHUNK],
combined_content="N/A", # Not used
),
InferenceSection(
inference_section_from_chunks(
center_chunk=DOC_1_MID_CHUNK,
chunks=[DOC_1_MID_CHUNK],
combined_content="N/A",
),
],
# Expected Content

View File

@@ -0,0 +1,208 @@
import os
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import InferenceChunk
from onyx.db.models import User
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
_post_query_chunk_censoring = fetch_ee_implementation_or_noop(
"onyx.external_permissions.post_query_censoring", "_post_query_chunk_censoring"
)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permissions tests are enterprise only",
)
class TestPostQueryChunkCensoring:
@pytest.fixture(autouse=True)
def setUp(self) -> None:
self.mock_user = User(id=1, email="test@example.com")
self.mock_chunk_1 = InferenceChunk(
document_id="doc1",
chunk_id=1,
content="chunk1 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc1_1",
title="doc1",
boost=1,
recency_bias=1.0,
score=0.9,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc1 summary",
chunk_context="doc1 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk1",
)
self.mock_chunk_2 = InferenceChunk(
document_id="doc2",
chunk_id=2,
content="chunk2 content",
source_type=DocumentSource.SLACK,
semantic_identifier="doc2_2",
title="doc2",
boost=1,
recency_bias=1.0,
score=0.8,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc2 summary",
chunk_context="doc2 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk2",
)
self.mock_chunk_3 = InferenceChunk(
document_id="doc3",
chunk_id=3,
content="chunk3 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc3_3",
title="doc3",
boost=1,
recency_bias=1.0,
score=0.7,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc3 summary",
chunk_context="doc3 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk3",
)
self.mock_chunk_4 = InferenceChunk(
document_id="doc4",
chunk_id=4,
content="chunk4 content",
source_type=DocumentSource.SALESFORCE,
semantic_identifier="doc4_4",
title="doc4",
boost=1,
recency_bias=1.0,
score=0.6,
hidden=False,
metadata={},
match_highlights=[],
doc_summary="doc4 summary",
chunk_context="doc4 context",
updated_at=None,
image_file_name=None,
source_links={},
section_continuation=False,
blurb="chunk4",
)
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
def test_post_query_chunk_censoring_no_user(
self, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
chunks = [self.mock_chunk_1, self.mock_chunk_2]
result = _post_query_chunk_censoring(chunks, None)
assert result == chunks
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_salesforce_censored(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(
return_value=[self.mock_chunk_1]
) # Only return chunk 1
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 2
assert self.mock_chunk_1 in result
assert self.mock_chunk_2 in result
assert self.mock_chunk_3 not in result
mock_censor_func_impl.assert_called_once()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_salesforce_error(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(side_effect=Exception("Censoring error"))
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 1
assert self.mock_chunk_2 in result
mock_censor_func_impl.assert_called_once()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_no_censoring(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = set() # No sources to censor
mock_censor_func_impl = MagicMock()
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert result == chunks
mock_censor_func_impl.assert_not_called()
@patch(
"ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
)
@patch(
"ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
)
def test_post_query_chunk_censoring_order_maintained(
self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
) -> None:
mock_get_sources.return_value = {DocumentSource.SALESFORCE}
mock_censor_func_impl = MagicMock(
return_value=[self.mock_chunk_3, self.mock_chunk_1]
) # Return chunk 3 and 1
mock_censor_func.__getitem__.return_value = mock_censor_func_impl
chunks = [
self.mock_chunk_1,
self.mock_chunk_2,
self.mock_chunk_3,
self.mock_chunk_4,
]
result = _post_query_chunk_censoring(chunks, self.mock_user)
assert len(result) == 3
assert result[0] == self.mock_chunk_1
assert result[1] == self.mock_chunk_2
assert result[2] == self.mock_chunk_3
assert self.mock_chunk_4 not in result
mock_censor_func_impl.assert_called_once()

View File

@@ -0,0 +1,270 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import Tag
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
build_vespa_filters,
)
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_SETS
from onyx.document_index.vespa_constants import HIDDEN
from onyx.document_index.vespa_constants import METADATA_LIST
from onyx.document_index.vespa_constants import SOURCE_TYPE
from onyx.document_index.vespa_constants import TENANT_ID
from onyx.document_index.vespa_constants import USER_FILE
from onyx.document_index.vespa_constants import USER_FOLDER
from shared_configs.configs import MULTI_TENANT
# Import the function under test
class TestBuildVespaFilters:
def test_empty_filters(self) -> None:
"""Test with empty filters object."""
filters = IndexFilters(access_control_list=[])
result = build_vespa_filters(filters)
assert result == f"!({HIDDEN}=true) and "
# With trailing AND removed
result = build_vespa_filters(filters, remove_trailing_and=True)
assert result == f"!({HIDDEN}=true)"
def test_include_hidden(self) -> None:
"""Test with include_hidden flag."""
filters = IndexFilters(access_control_list=[])
result = build_vespa_filters(filters, include_hidden=True)
assert result == "" # No filters applied when including hidden
# With some other filter to ensure proper AND chaining
filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
result = build_vespa_filters(filters, include_hidden=True)
assert result == f'({SOURCE_TYPE} contains "web") and '
def test_acl(self) -> None:
"""Test with acls."""
# Single ACL
filters = IndexFilters(access_control_list=["user1"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and (access_control_list contains "user1") and '
)
# Multiple ACLs
filters = IndexFilters(access_control_list=["user2", "group2"])
result = build_vespa_filters(filters)
assert (
result
== f'!({HIDDEN}=true) and (access_control_list contains "user2" or access_control_list contains "group2") and '
)
def test_tenant_filter(self) -> None:
"""Test tenant ID filtering."""
# With tenant ID
if MULTI_TENANT:
filters = IndexFilters(access_control_list=[], tenant_id="tenant1")
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({TENANT_ID} contains "tenant1") and ' == result
)
# No tenant ID
filters = IndexFilters(access_control_list=[], tenant_id=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_source_type_filter(self) -> None:
"""Test source type filtering."""
# Single source type
filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web") and ' == result
# Multiple source types
filters = IndexFilters(
access_control_list=[],
source_type=[DocumentSource.WEB, DocumentSource.JIRA],
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web" or {SOURCE_TYPE} contains "jira") and '
== result
)
# Empty source type list
filters = IndexFilters(access_control_list=[], source_type=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_tag_filters(self) -> None:
"""Test tag filtering."""
# Single tag
filters = IndexFilters(
access_control_list=[], tags=[Tag(tag_key="color", tag_value="red")]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
== result
)
# Multiple tags
filters = IndexFilters(
access_control_list=[],
tags=[
Tag(tag_key="color", tag_value="red"),
Tag(tag_key="size", tag_value="large"),
],
)
result = build_vespa_filters(filters)
expected = (
f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red" '
f'or {METADATA_LIST} contains "size{INDEX_SEPARATOR}large") and '
)
assert expected == result
# Empty tags list
filters = IndexFilters(access_control_list=[], tags=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
def test_document_sets_filter(self) -> None:
"""Test document sets filtering."""
# Single document set
filters = IndexFilters(access_control_list=[], document_set=["set1"])
result = build_vespa_filters(filters)
assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
# Multiple document sets
filters = IndexFilters(access_control_list=[], document_set=["set1", "set2"])
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
== result
)
# Empty document sets
filters = IndexFilters(access_control_list=[], document_set=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result

def test_user_file_ids_filter(self) -> None:
"""Test user file IDs filtering."""
# Single user file ID
filters = IndexFilters(access_control_list=[], user_file_ids=[123])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and ({USER_FILE} = 123) and " == result
# Multiple user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[123, 456])
result = build_vespa_filters(filters)
assert (
f"!({HIDDEN}=true) and ({USER_FILE} = 123 or {USER_FILE} = 456) and "
== result
)
# Empty user file IDs
filters = IndexFilters(access_control_list=[], user_file_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result

def test_user_folder_ids_filter(self) -> None:
"""Test user folder IDs filtering."""
# Single user folder ID
filters = IndexFilters(access_control_list=[], user_folder_ids=[789])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and ({USER_FOLDER} = 789) and " == result
# Multiple user folder IDs
filters = IndexFilters(access_control_list=[], user_folder_ids=[789, 101])
result = build_vespa_filters(filters)
assert (
f"!({HIDDEN}=true) and ({USER_FOLDER} = 789 or {USER_FOLDER} = 101) and "
== result
)
# Empty user folder IDs
filters = IndexFilters(access_control_list=[], user_folder_ids=[])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result

def test_time_cutoff_filter(self) -> None:
"""Test time cutoff filtering."""
# With cutoff time
cutoff_time = datetime(2023, 1, 1, tzinfo=timezone.utc)
filters = IndexFilters(access_control_list=[], time_cutoff=cutoff_time)
result = build_vespa_filters(filters)
cutoff_secs = int(cutoff_time.timestamp())
assert (
f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {cutoff_secs}) and " == result
)
# No cutoff time
filters = IndexFilters(access_control_list=[], time_cutoff=None)
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result
# Test untimed logic (when cutoff is old enough)
old_cutoff = datetime.now(timezone.utc) - timedelta(days=100)
filters = IndexFilters(access_control_list=[], time_cutoff=old_cutoff)
result = build_vespa_filters(filters)
old_cutoff_secs = int(old_cutoff.timestamp())
assert (
f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {old_cutoff_secs}) and "
== result
)

def test_combined_filters(self) -> None:
"""Test combining multiple filter types."""
filters = IndexFilters(
access_control_list=["user1", "group1"],
source_type=[DocumentSource.WEB],
tags=[Tag(tag_key="color", tag_value="red")],
document_set=["set1"],
user_file_ids=[123],
user_folder_ids=[789],
time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
)
result = build_vespa_filters(filters)
# Build expected result piece by piece for readability
expected = f"!({HIDDEN}=true) and "
expected += (
'(access_control_list contains "user1" or '
'access_control_list contains "group1") and '
)
expected += f'({SOURCE_TYPE} contains "web") and '
expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
expected += f'({DOCUMENT_SETS} contains "set1") and '
expected += f"({USER_FILE} = 123) and "
expected += f"({USER_FOLDER} = 789) and "
cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
assert expected == result
# With trailing AND removed
result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
assert expected[:-5] == result_no_trailing # Remove trailing " and "

def test_empty_or_none_values(self) -> None:
"""Test with empty or None values in filter lists."""
# Empty strings in document set
filters = IndexFilters(
access_control_list=[], document_set=["set1", "", "set2"]
)
result = build_vespa_filters(filters)
assert (
f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
== result
)
# All empty strings in document set
filters = IndexFilters(access_control_list=[], document_set=["", ""])
result = build_vespa_filters(filters)
assert f"!({HIDDEN}=true) and " == result

View File

@@ -42,9 +42,7 @@ import Link from "next/link";
import { useRouter, useSearchParams } from "next/navigation";
import { useEffect, useMemo, useState } from "react";
import * as Yup from "yup";
import CollapsibleSection from "./CollapsibleSection";
import { SuccessfulPersonaUpdateRedirectType } from "./enums";
import { Persona, PersonaLabel, StarterMessage } from "./interfaces";
import { FullPersona, PersonaLabel, StarterMessage } from "./interfaces";
import {
PersonaUpsertParameters,
createPersona,
@@ -101,6 +99,7 @@ import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import TextView from "@/components/chat/TextView";
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
import { TabToggle } from "@/components/ui/TabToggle";
import { MAX_CHARACTERS_PERSONA_DESCRIPTION } from "@/lib/constants";

function findSearchTool(tools: ToolSnapshot[]) {
return tools.find((tool) => tool.in_code_tool_id === SEARCH_TOOL_ID);
@@ -136,7 +135,7 @@ export function AssistantEditor({
shouldAddAssistantToUserPreferences,
admin,
}: {
existingPersona?: Persona | null;
existingPersona?: FullPersona | null;
ccPairs: CCPairBasicInfo[];
documentSets: DocumentSet[];
user: User | null;
@@ -184,8 +183,6 @@ export function AssistantEditor({
}
}, [defaultIconShape]);
const [isIconDropdownOpen, setIsIconDropdownOpen] = useState(false);
const [removePersonaImage, setRemovePersonaImage] = useState(false);
const autoStarterMessageEnabled = useMemo(
@@ -462,12 +459,12 @@ export function AssistantEditor({
"Must provide a description for the Assistant"
),
system_prompt: Yup.string().max(
8000,
"Instructions must be less than 8000 characters"
MAX_CHARACTERS_PERSONA_DESCRIPTION,
"Instructions must be less than 5000000 characters"
),
task_prompt: Yup.string().max(
8000,
"Reminders must be less than 8000 characters"
MAX_CHARACTERS_PERSONA_DESCRIPTION,
"Reminders must be less than 5000000 characters"
),
is_public: Yup.boolean().required(),
document_set_ids: Yup.array().of(Yup.number()),
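The schema above now takes its limit from MAX_CHARACTERS_PERSONA_DESCRIPTION but still hardcodes "5000000" in the user-facing message, so the two can drift if the constant changes. A minimal sketch, assuming the same Yup API and import path, that interpolates the constant instead (the helper name is hypothetical, not part of this PR):

import * as Yup from "yup";
import { MAX_CHARACTERS_PERSONA_DESCRIPTION } from "@/lib/constants";

// Hypothetical schema builder: one definition for both system_prompt and
// task_prompt, with the limit interpolated into the message so the copy can
// never disagree with the constant.
const promptLengthSchema = (label: string) =>
  Yup.string().max(
    MAX_CHARACTERS_PERSONA_DESCRIPTION,
    `${label} must be less than ${MAX_CHARACTERS_PERSONA_DESCRIPTION.toLocaleString()} characters`
  );

// Usage sketch:
//   system_prompt: promptLengthSchema("Instructions"),
//   task_prompt: promptLengthSchema("Reminders"),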

View File

@@ -18,35 +18,37 @@ export interface Prompt {
datetime_aware: boolean;
default_prompt: boolean;
}

export interface Persona {
id: number;
name: string;
search_start_date: Date | null;
owner: MinimalUserSnapshot | null;
is_visible: boolean;
is_public: boolean;
display_priority: number | null;
description: string;
document_sets: DocumentSet[];
prompts: Prompt[];
tools: ToolSnapshot[];
num_chunks?: number;
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
llm_model_provider_override?: string;
llm_model_version_override?: string;
starter_messages: StarterMessage[] | null;
builtin_persona: boolean;
is_default_persona: boolean;
users: MinimalUserSnapshot[];
groups: number[];
is_public: boolean;
is_visible: boolean;
icon_shape?: number;
icon_color?: string;
uploaded_image_id?: string;
labels?: PersonaLabel[];
user_file_ids: number[];
user_folder_ids: number[];
display_priority: number | null;
is_default_persona: boolean;
builtin_persona: boolean;
starter_messages: StarterMessage[] | null;
tools: ToolSnapshot[];
labels?: PersonaLabel[];
owner: MinimalUserSnapshot | null;
users: MinimalUserSnapshot[];
groups: number[];
document_sets: DocumentSet[];
llm_model_provider_override?: string;
llm_model_version_override?: string;
num_chunks?: number;
}

export interface FullPersona extends Persona {
search_start_date: Date | null;
prompts: Prompt[];
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
}
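With this split, a plain Persona no longer carries prompts, search_start_date, or the llm_relevance_filter / llm_filter_extraction flags, so code holding a mixed value needs to narrow before reading them. A minimal type-guard sketch, assuming a Persona | FullPersona union (the helper is hypothetical, not part of this PR):

// Hypothetical narrowing helper: "prompts" and "search_start_date" exist
// only on FullPersona after this change, so their presence identifies one.
function isFullPersona(p: Persona | FullPersona): p is FullPersona {
  return "prompts" in p && "search_start_date" in p;
}

// Usage sketch:
//   if (isFullPersona(persona)) {
//     const firstPromptId = persona.prompts[0]?.id ?? null;
//   }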

export interface PersonaLabel {

View File

@@ -331,28 +331,3 @@ export function providersContainImageGeneratingSupport(
) {
return providers.some((provider) => provider.provider === "openai");
}
// Default fallback persona for when we must display a persona
// but assistant has access to none
export const defaultPersona: Persona = {
id: 0,
name: "Default Assistant",
description: "A default assistant",
is_visible: true,
is_public: true,
builtin_persona: false,
is_default_persona: true,
users: [],
groups: [],
document_sets: [],
prompts: [],
tools: [],
starter_messages: null,
display_priority: null,
search_start_date: null,
owner: null,
icon_shape: 50910,
icon_color: "#FF6F6F",
user_file_ids: [],
user_folder_ids: [],
};

View File

@@ -487,11 +487,6 @@ export default function EmbeddingForm() {
};
const handleReIndex = async () => {
console.log("handleReIndex");
console.log(selectedProvider);
console.log(advancedEmbeddingDetails);
console.log(rerankingDetails);
console.log(reindexType);
if (!selectedProvider) {
return;
}

View File

@@ -1383,7 +1383,7 @@ export function ChatPage({
regenerationRequest?.parentMessage.messageId ||
lastSuccessfulMessageId,
chatSessionId: currChatSessionId,
promptId: liveAssistant?.prompts[0]?.id || 0,
promptId: null,
filters: buildFilters(
filterManager.selectedSources,
filterManager.selectedDocumentSets,

View File

@@ -9,11 +9,6 @@ import { redirect } from "next/navigation";
import { BackendChatSession } from "../../interfaces";
import { SharedChatDisplay } from "./SharedChatDisplay";
import { Persona } from "@/app/admin/assistants/interfaces";
import {
FetchAssistantsResponse,
fetchAssistantsSS,
} from "@/lib/assistants/fetchAssistantsSS";
import { defaultPersona } from "@/app/admin/assistants/lib";
import { constructMiniFiedPersona } from "@/lib/assistantIconUtils";

async function getSharedChat(chatId: string) {

View File

@@ -2,7 +2,7 @@ import { User } from "@/lib/types";
import { FiPlus, FiX } from "react-icons/fi";
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
import { UsersIcon } from "@/components/icons/icons";
import { Button } from "@/components/Button";
import { Button } from "@/components/ui/button";

interface UserEditorProps {
selectedUserIds: string[];

View File

@@ -22,7 +22,7 @@ export const AddMemberForm: React.FC<AddMemberFormProps> = ({
return (
<Modal
className="max-w-xl"
className="max-w-xl overflow-visible"
title="Add New User"
onOutsideClick={() => onClose()}
>

View File

@@ -43,13 +43,13 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({
if (href) {
return (
<Link
<a
href={href}
target={openInNewTab ? "_blank" : undefined}
rel={openInNewTab ? "noopener noreferrer" : undefined}
>
{content}
</Link>
</a>
);
} else {
return <div onClick={onClick}>{content}</div>;
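Replacing next/link's Link with a plain anchor makes target and rel behave as ordinary HTML attributes, at the cost of client-side navigation for in-app routes. A small illustrative sketch of that trade-off, assuming hrefs may be internal or external (the component is hypothetical, not part of this PR):

import Link from "next/link";
import type { ReactNode } from "react";

// Hypothetical wrapper: keep client-side routing for in-app paths, fall back
// to a plain anchor (with safe rel defaults) for everything else.
const SmartLink = ({ href, children }: { href: string; children: ReactNode }) =>
  href.startsWith("/") ? (
    <Link href={href}>{children}</Link>
  ) : (
    <a href={href} target="_blank" rel="noopener noreferrer">
      {children}
    </a>
  );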

View File

@@ -167,9 +167,7 @@ export const constructMiniFiedPersona = (
display_priority: 0,
description: "",
document_sets: [],
prompts: [],
tools: [],
search_start_date: null,
owner: null,
starter_messages: null,
builtin_persona: false,

View File

@@ -1,4 +1,4 @@
import { Persona } from "@/app/admin/assistants/interfaces";
import { FullPersona, Persona } from "@/app/admin/assistants/interfaces";
import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";
@@ -18,7 +18,7 @@ export async function fetchAssistantEditorInfoSS(
documentSets: DocumentSet[];
llmProviders: LLMProviderView[];
user: User | null;
existingPersona: Persona | null;
existingPersona: FullPersona | null;
tools: ToolSnapshot[];
},
null,
@@ -94,7 +94,7 @@ export async function fetchAssistantEditorInfoSS(
}
const existingPersona = personaResponse
? ((await personaResponse.json()) as Persona)
? ((await personaResponse.json()) as FullPersona)
: null;
let error: string | null = null;

View File

@@ -105,3 +105,5 @@ export const ALLOWED_URL_PROTOCOLS = [
"spotify:",
"zoommtg:",
];

export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;

View File

@@ -77,6 +77,7 @@ const MODEL_NAMES_SUPPORTING_IMAGE_INPUT = [
"claude-3-haiku-20240307",
// custom claude names
"claude-3.5-sonnet-v2@20241022",
"claude-3-7-sonnet@20250219",
// claude names with AWS Bedrock Suffix
"claude-3-opus-20240229-v1:0",
"claude-3-sonnet-20240229-v1:0",
@@ -125,12 +126,27 @@ export function checkLLMSupportsImageInput(model: string) {
const modelParts = model.split(/[/.]/);
const lastPart = modelParts[modelParts.length - 1]?.toLowerCase();
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
// Try matching the last part
const lastPartMatch = MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) => {
const modelNameParts = modelName.split(/[/.]/);
const modelNameLastPart = modelNameParts[modelNameParts.length - 1];
// lastPart is already lowercased above for tiny performance gain
return modelNameLastPart?.toLowerCase() === lastPart;
});
if (lastPartMatch) {
return true;
}
// If no match found, try getting the text after the first slash
if (model.includes("/")) {
const afterSlash = model.split("/")[1]?.toLowerCase();
return MODEL_NAMES_SUPPORTING_IMAGE_INPUT.some((modelName) =>
modelName.toLowerCase().includes(afterSlash)
);
}
return false;
}
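The check now runs in two stages: an exact match on the final slash- or dot-delimited segment, then a substring match on the text after the first slash. Illustrative calls (the model strings are examples chosen for this sketch, not taken from the PR):

// Last segment equals the allowlist entry exactly, so stage one matches.
checkLLMSupportsImageInput("claude-3-7-sonnet@20250219"); // true

// No exact last-segment match, but "claude-3-7-sonnet" is a substring of
// "claude-3-7-sonnet@20250219", so the after-slash fallback returns true.
checkLLMSupportsImageInput("vertex_ai/claude-3-7-sonnet"); // true

// Assuming this name appears nowhere in the allowlist, both stages fail.
checkLLMSupportsImageInput("text-only-model"); // false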

export const structureValue = (

View File

@@ -1,201 +0,0 @@
import {
BackendMessage,
LLMRelevanceFilterPacket,
} from "@/app/chat/interfaces";
import {
AnswerPiecePacket,
OnyxDocument,
ErrorMessagePacket,
DocumentInfoPacket,
Quote,
QuotesInfoPacket,
RelevanceChunk,
SearchRequestArgs,
} from "./interfaces";
import { processRawChunkString } from "./streamingUtils";
import { buildFilters, endsWithLetterOrNumber } from "./utils";

export const searchRequestStreamed = async ({
query,
sources,
documentSets,
timeRange,
tags,
persona,
agentic,
updateCurrentAnswer,
updateQuotes,
updateDocs,
updateSuggestedSearchType,
updateSuggestedFlowType,
updateSelectedDocIndices,
updateError,
updateMessageAndThreadId,
finishedSearching,
updateDocumentRelevance,
updateComments,
}: SearchRequestArgs) => {
let answer = "";
let quotes: Quote[] | null = null;
let relevantDocuments: OnyxDocument[] | null = null;
try {
const filters = buildFilters(sources, documentSets, timeRange, tags);
const threadMessage = {
message: query,
sender: null,
role: "user",
};
const response = await fetch("/api/query/stream-answer-with-quote", {
method: "POST",
body: JSON.stringify({
messages: [threadMessage],
persona_id: persona.id,
agentic,
prompt_id: persona.id === 0 ? null : persona.prompts[0]?.id,
retrieval_options: {
run_search: "always",
real_time: true,
filters: filters,
enable_auto_detect_filters: false,
},
evaluation_type: agentic ? "agentic" : "basic",
}),
headers: {
"Content-Type": "application/json",
},
});
const reader = response.body?.getReader();
const decoder = new TextDecoder("utf-8");
let previousPartialChunk: string | null = null;
while (true) {
const rawChunk = await reader?.read();
if (!rawChunk) {
throw new Error("Unable to process chunk");
}
const { done, value } = rawChunk;
if (done) {
break;
}
// Process each chunk as it arrives
const [completedChunks, partialChunk] = processRawChunkString<
| AnswerPiecePacket
| ErrorMessagePacket
| QuotesInfoPacket
| DocumentInfoPacket
| LLMRelevanceFilterPacket
| BackendMessage
| DocumentInfoPacket
| RelevanceChunk
>(decoder.decode(value, { stream: true }), previousPartialChunk);
if (!completedChunks.length && !partialChunk) {
break;
}
previousPartialChunk = partialChunk as string | null;
completedChunks.forEach((chunk) => {
// check for answer piece / end of answer
if (Object.hasOwn(chunk, "relevance_summaries")) {
const relevanceChunk = chunk as RelevanceChunk;
updateDocumentRelevance(relevanceChunk.relevance_summaries);
}
if (Object.hasOwn(chunk, "answer_piece")) {
const answerPiece = (chunk as AnswerPiecePacket).answer_piece;
if (answerPiece !== null) {
answer += (chunk as AnswerPiecePacket).answer_piece;
updateCurrentAnswer(answer);
} else {
// set quotes as non-null to signify that the answer is finished and
// we're now looking for quotes
updateQuotes([]);
if (
answer &&
!answer.endsWith(".") &&
!answer.endsWith("?") &&
!answer.endsWith("!") &&
endsWithLetterOrNumber(answer)
) {
answer += ".";
updateCurrentAnswer(answer);
}
}
return;
}
if (Object.hasOwn(chunk, "error")) {
updateError((chunk as ErrorMessagePacket).error);
return;
}
// These all come together
if (Object.hasOwn(chunk, "top_documents")) {
chunk = chunk as DocumentInfoPacket;
const topDocuments = chunk.top_documents as OnyxDocument[] | null;
if (topDocuments) {
relevantDocuments = topDocuments;
updateDocs(relevantDocuments);
}
if (chunk.predicted_flow) {
updateSuggestedFlowType(chunk.predicted_flow);
}
if (chunk.predicted_search) {
updateSuggestedSearchType(chunk.predicted_search);
}
return;
}
if (Object.hasOwn(chunk, "relevant_chunk_indices")) {
const relevantChunkIndices = (chunk as LLMRelevanceFilterPacket)
.relevant_chunk_indices;
if (relevantChunkIndices) {
updateSelectedDocIndices(relevantChunkIndices);
}
return;
}
// Check for quote section
if (Object.hasOwn(chunk, "quotes")) {
quotes = (chunk as QuotesInfoPacket).quotes;
updateQuotes(quotes);
return;
}
// Check for the final chunk
if (Object.hasOwn(chunk, "message_id")) {
const backendChunk = chunk as BackendMessage;
updateComments(backendChunk.comments);
updateMessageAndThreadId(
backendChunk.message_id,
backendChunk.chat_session_id
);
}
});
}
} catch (err) {
console.error("Fetch error:", err);
let errorMessage = "An error occurred while fetching the answer.";
if (err instanceof Error) {
if (err.message.includes("rate_limit_error")) {
errorMessage =
"Rate limit exceeded. Please try again later or reduce the length of your query.";
} else {
errorMessage = err.message;
}
}
updateError(errorMessage);
}
return { answer, quotes, relevantDocuments };
};