Mirror of https://github.com/onyx-dot-app/onyx.git
Synced 2026-02-18 08:15:48 +00:00

Compare commits: testing_li ... additional (6 commits)

| Author | SHA1 | Date |
|---|---|---|
|  | 5c33f33f6e |  |
|  | a8f87588ff |  |
|  | 4bae1318bb |  |
|  | 11c3f44c76 |  |
|  | cb38ac8a97 |  |
|  | b2120b9f39 |  |
@@ -46,7 +46,6 @@ WORKDIR /app

# Utils used by model server
COPY ./onyx/utils/logger.py /app/onyx/utils/logger.py
COPY ./onyx/utils/middleware.py /app/onyx/utils/middleware.py

# Place to fetch version information
COPY ./onyx/__init__.py /app/onyx/__init__.py

@@ -1,50 +0,0 @@
"""update prompt length

Revision ID: 4794bc13e484
Revises: f7505c5b0284
Create Date: 2025-04-02 11:26:36.180328

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "4794bc13e484"
down_revision = "f7505c5b0284"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.alter_column(
        "prompt",
        "system_prompt",
        existing_type=sa.TEXT(),
        type_=sa.String(length=5000000),
        existing_nullable=False,
    )
    op.alter_column(
        "prompt",
        "task_prompt",
        existing_type=sa.TEXT(),
        type_=sa.String(length=5000000),
        existing_nullable=False,
    )


def downgrade() -> None:
    op.alter_column(
        "prompt",
        "system_prompt",
        existing_type=sa.String(length=5000000),
        type_=sa.TEXT(),
        existing_nullable=False,
    )
    op.alter_column(
        "prompt",
        "task_prompt",
        existing_type=sa.String(length=5000000),
        type_=sa.TEXT(),
        existing_nullable=False,
    )

@@ -1,50 +0,0 @@
"""add prompt length limit

Revision ID: f71470ba9274
Revises: 6a804aeb4830
Create Date: 2025-04-01 15:07:14.977435

"""


# revision identifiers, used by Alembic.
revision = "f71470ba9274"
down_revision = "6a804aeb4830"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # op.alter_column(
    # "prompt",
    # "system_prompt",
    # existing_type=sa.TEXT(),
    # type_=sa.String(length=8000),
    # existing_nullable=False,
    # )
    # op.alter_column(
    # "prompt",
    # "task_prompt",
    # existing_type=sa.TEXT(),
    # type_=sa.String(length=8000),
    # existing_nullable=False,
    # )
    pass


def downgrade() -> None:
    # op.alter_column(
    # "prompt",
    # "system_prompt",
    # existing_type=sa.String(length=8000),
    # type_=sa.TEXT(),
    # existing_nullable=False,
    # )
    # op.alter_column(
    # "prompt",
    # "task_prompt",
    # existing_type=sa.String(length=8000),
    # type_=sa.TEXT(),
    # existing_nullable=False,
    # )
    pass

@@ -1,77 +0,0 @@
"""updated constraints for ccpairs

Revision ID: f7505c5b0284
Revises: f71470ba9274
Create Date: 2025-04-01 17:50:42.504818

"""
from alembic import op


# revision identifiers, used by Alembic.
revision = "f7505c5b0284"
down_revision = "f71470ba9274"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # 1) Drop the old foreign-key constraints
    op.drop_constraint(
        "document_by_connector_credential_pair_connector_id_fkey",
        "document_by_connector_credential_pair",
        type_="foreignkey",
    )
    op.drop_constraint(
        "document_by_connector_credential_pair_credential_id_fkey",
        "document_by_connector_credential_pair",
        type_="foreignkey",
    )

    # 2) Re-add them with ondelete='CASCADE'
    op.create_foreign_key(
        "document_by_connector_credential_pair_connector_id_fkey",
        source_table="document_by_connector_credential_pair",
        referent_table="connector",
        local_cols=["connector_id"],
        remote_cols=["id"],
        ondelete="CASCADE",
    )
    op.create_foreign_key(
        "document_by_connector_credential_pair_credential_id_fkey",
        source_table="document_by_connector_credential_pair",
        referent_table="credential",
        local_cols=["credential_id"],
        remote_cols=["id"],
        ondelete="CASCADE",
    )


def downgrade() -> None:
    # Reverse the changes for rollback
    op.drop_constraint(
        "document_by_connector_credential_pair_connector_id_fkey",
        "document_by_connector_credential_pair",
        type_="foreignkey",
    )
    op.drop_constraint(
        "document_by_connector_credential_pair_credential_id_fkey",
        "document_by_connector_credential_pair",
        type_="foreignkey",
    )

    # Recreate without CASCADE
    op.create_foreign_key(
        "document_by_connector_credential_pair_connector_id_fkey",
        "document_by_connector_credential_pair",
        "connector",
        ["connector_id"],
        ["id"],
    )
    op.create_foreign_key(
        "document_by_connector_credential_pair_credential_id_fkey",
        "document_by_connector_credential_pair",
        "credential",
        ["credential_id"],
        ["id"],
    )

@@ -159,9 +159,6 @@ def _get_space_permissions(

        # Stores the permissions for each space
        space_permissions_by_space_key[space_key] = space_permissions
        logger.info(
            f"Found space permissions for space '{space_key}': {space_permissions}"
        )

    return space_permissions_by_space_key

@@ -55,7 +55,7 @@ def _post_query_chunk_censoring(
        # if user is None, permissions are not enforced
        return chunks

    final_chunk_dict: dict[str, InferenceChunk] = {}
    chunks_to_keep = []
    chunks_to_process: dict[DocumentSource, list[InferenceChunk]] = {}

    sources_to_censor = _get_all_censoring_enabled_sources()
@@ -64,7 +64,7 @@ def _post_query_chunk_censoring(
        if chunk.source_type in sources_to_censor:
            chunks_to_process.setdefault(chunk.source_type, []).append(chunk)
        else:
            final_chunk_dict[chunk.unique_id] = chunk
            chunks_to_keep.append(chunk)

    # For each source, filter out the chunks using the permission
    # check function for that source
@@ -79,16 +79,6 @@ def _post_query_chunk_censoring(
                f" chunks for this source and continuing: {e}"
            )
            continue
        chunks_to_keep.extend(censored_chunks)

        for censored_chunk in censored_chunks:
            final_chunk_dict[censored_chunk.unique_id] = censored_chunk

    # IMPORTANT: make sure to retain the same ordering as the original `chunks` passed in
    final_chunk_list: list[InferenceChunk] = []
    for chunk in chunks:
        # only if the chunk is in the final censored chunks, add it to the final list
        # if it is missing, that means it was intentionally left out
        if chunk.unique_id in final_chunk_dict:
            final_chunk_list.append(final_chunk_dict[chunk.unique_id])

    return final_chunk_list
    return chunks_to_keep

@@ -58,7 +58,6 @@ def _get_objects_access_for_user_email_from_salesforce(
        f"Time taken to get Salesforce user ID: {end_time - start_time} seconds"
    )
    if user_id is None:
        logger.warning(f"User '{user_email}' not found in Salesforce")
        return None

    # This is the only query that is not cached in the function
@@ -66,7 +65,6 @@ def _get_objects_access_for_user_email_from_salesforce(
    object_id_to_access = get_objects_access_for_user_id(
        salesforce_client, user_id, list(object_ids)
    )
    logger.debug(f"Object ID to access: {object_id_to_access}")
    return object_id_to_access

@@ -42,18 +42,11 @@ def get_any_salesforce_client_for_doc_id(


def _query_salesforce_user_id(sf_client: Salesforce, user_email: str) -> str | None:
    query = f"SELECT Id FROM User WHERE Username = '{user_email}' AND IsActive = true"
    query = f"SELECT Id FROM User WHERE Email = '{user_email}'"
    result = sf_client.query(query)
    if len(result["records"]) > 0:
        return result["records"][0]["Id"]

    # try emails
    query = f"SELECT Id FROM User WHERE Email = '{user_email}' AND IsActive = true"
    result = sf_client.query(query)
    if len(result["records"]) > 0:
        return result["records"][0]["Id"]

    return None
    if len(result["records"]) == 0:
        return None
    return result["records"][0]["Id"]


# This contains only the user_ids that we have found in Salesforce.
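
Note: both the old and new versions of _query_salesforce_user_id interpolate user_email directly into the SOQL string. A minimal escaping sketch (the soql_quote helper is hypothetical, not part of this codebase):

def soql_quote(value: str) -> str:
    # SOQL string literals support backslash escapes; neutralize the two
    # characters that could break out of the single-quoted literal.
    return value.replace("\\", "\\\\").replace("'", "\\'")

query = f"SELECT Id FROM User WHERE Email = '{soql_quote(user_email)}'"
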
@@ -36,6 +36,9 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/auth/saml")

# Define non-authenticated user roles that should be re-created during SAML login
NON_AUTHENTICATED_ROLES = {UserRole.SLACK_USER, UserRole.EXT_PERM_USER}


async def upsert_saml_user(email: str) -> User:
    logger.debug(f"Attempting to upsert SAML user with email: {email}")
@@ -51,7 +54,7 @@ async def upsert_saml_user(email: str) -> User:
    try:
        user = await user_manager.get_by_email(email)
        # If user has a non-authenticated role, treat as non-existent
        if not user.role.is_web_login():
        if user.role in NON_AUTHENTICATED_ROLES:
            raise exceptions.UserNotExists()
        return user
    except exceptions.UserNotExists:

@@ -1,4 +1,3 @@
import logging
import os
import shutil
from collections.abc import AsyncGenerator
@@ -9,7 +8,6 @@ import sentry_sdk
import torch
import uvicorn
from fastapi import FastAPI
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging  # type:ignore
@@ -22,8 +20,6 @@ from model_server.management_endpoints import router as management_router
from model_server.utils import get_gpu_type
from onyx import __version__
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import MIN_THREADS_ML_MODELS
from shared_configs.configs import MODEL_SERVER_ALLOWED_HOST
@@ -40,12 +36,6 @@ transformer_logging.set_verbosity_error()

logger = setup_logger()

file_handlers = [
    h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]

setup_uvicorn_logger(shared_file_handlers=file_handlers)


def _move_files_recursively(source: Path, dest: Path, overwrite: bool = False) -> None:
    """
@@ -122,15 +112,6 @@ def get_model_app() -> FastAPI:
    application.include_router(encoders_router)
    application.include_router(custom_models_router)

    request_id_prefix = "INF"
    if INDEXING_ONLY:
        request_id_prefix = "IDX"

    add_onyx_request_id_middleware(application, request_id_prefix, logger)

    # Initialize and instrument the app
    Instrumentator().instrument(application).expose(application)

    return application
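
For context on the add_onyx_request_id_middleware call removed above: a request-id middleware in plain FastAPI terms looks roughly like the sketch below (onyx's actual helper differs; the header name and prefix here are illustrative).

import uuid

from fastapi import FastAPI, Request

app = FastAPI()

@app.middleware("http")
async def add_request_id(request: Request, call_next):
    # Tag each response with a short id so logs can be correlated per request.
    request_id = f"INF:{uuid.uuid4().hex[:8]}"
    response = await call_next(request)
    response.headers["X-Request-ID"] = request_id
    return response
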
@@ -15,22 +15,6 @@ class ExternalAccess:
    # Whether the document is public in the external system or Onyx
    is_public: bool

    def __str__(self) -> str:
        """Prevent extremely long logs"""

        def truncate_set(s: set[str], max_len: int = 100) -> str:
            s_str = str(s)
            if len(s_str) > max_len:
                return f"{s_str[:max_len]}... ({len(s)} items)"
            return s_str

        return (
            f"ExternalAccess("
            f"external_user_emails={truncate_set(self.external_user_emails)}, "
            f"external_user_group_ids={truncate_set(self.external_user_group_ids)}, "
            f"is_public={self.is_public})"
        )


@dataclass(frozen=True)
class DocExternalAccess:

@@ -23,7 +23,6 @@ from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import MULTI_TENANT


HTML_EMAIL_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">

@@ -56,7 +56,6 @@ from httpx_oauth.oauth2 import OAuth2Token
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from onyx.auth.api_key import get_hashed_api_key_from_request
from onyx.auth.email_utils import send_forgot_password_email
from onyx.auth.email_utils import send_user_verification_email
@@ -514,25 +513,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

        return user

    async def on_after_login(
        self,
        user: User,
        request: Optional[Request] = None,
        response: Optional[Response] = None,
    ) -> None:
        try:
            if response and request and ANONYMOUS_USER_COOKIE_NAME in request.cookies:
                response.delete_cookie(
                    ANONYMOUS_USER_COOKIE_NAME,
                    # Ensure cookie deletion doesn't override other cookies by setting the same path/domain
                    path="/",
                    domain=None,
                    secure=WEB_DOMAIN.startswith("https"),
                )
                logger.debug(f"Deleted anonymous user cookie for user {user.email}")
        except Exception:
            logger.exception("Error deleting anonymous user cookie")

    async def on_after_register(
        self, user: User, request: Optional[Request] = None
    ) -> None:
@@ -1322,7 +1302,6 @@ def get_oauth_router(
        # Login user
        response = await backend.login(strategy, user)
        await user_manager.on_after_login(user, request, response)

        # Prepare redirect response
        if tenant_id is None:
            # Use URL utility to add parameters
@@ -1332,14 +1311,9 @@ def get_oauth_router(
            # No parameters to add
            redirect_response = RedirectResponse(next_url, status_code=302)

        # Copy headers from auth response to redirect response, with special handling for Set-Cookie
        # Copy headers and other attributes from 'response' to 'redirect_response'
        for header_name, header_value in response.headers.items():
            # FastAPI can have multiple Set-Cookie headers as a list
            if header_name.lower() == "set-cookie" and isinstance(header_value, list):
                for cookie_value in header_value:
                    redirect_response.headers.append(header_name, cookie_value)
            else:
                redirect_response.headers[header_name] = header_value
            redirect_response.headers[header_name] = header_value

        if hasattr(response, "body"):
            redirect_response.body = response.body
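
The header-copying change above matters because Set-Cookie can legitimately appear multiple times. A small sketch using Starlette's MutableHeaders shows why appending differs from plain assignment (assignment removes duplicates and keeps a single value):

from starlette.datastructures import MutableHeaders

headers = MutableHeaders()
headers.append("set-cookie", "a=1; Path=/")
headers.append("set-cookie", "b=2; Path=/")  # both cookies survive
print(headers.getlist("set-cookie"))         # ['a=1; Path=/', 'b=2; Path=/']
headers["set-cookie"] = "c=3; Path=/"        # assignment collapses to one value
print(headers.getlist("set-cookie"))         # ['c=3; Path=/']
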
@@ -1,6 +1,5 @@
import logging
import multiprocessing
import os
import time
from typing import Any
from typing import cast
@@ -306,7 +305,7 @@ def wait_for_db(sender: Any, **kwargs: Any) -> None:


def on_secondary_worker_init(sender: Any, **kwargs: Any) -> None:
    logger.info(f"Running as a secondary celery worker: pid={os.getpid()}")
    logger.info("Running as a secondary celery worker.")

    # Set up variables for waiting on primary worker
    WAIT_INTERVAL = 5

@@ -1,7 +0,0 @@
from celery import Celery

import onyx.background.celery.apps.app_base as app_base

celery_app = Celery(__name__)
celery_app.config_from_object("onyx.background.celery.configs.client")
celery_app.Task = app_base.TenantAwareTask  # type: ignore [misc]

@@ -1,5 +1,4 @@
import logging
import os
from typing import Any
from typing import cast

@@ -96,7 +95,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
    app_base.wait_for_db(sender, **kwargs)
    app_base.wait_for_vespa_or_shutdown(sender, **kwargs)

    logger.info(f"Running as the primary celery worker: pid={os.getpid()}")
    logger.info("Running as the primary celery worker.")

    # Less startup checks in multi-tenant case
    if MULTI_TENANT:

@@ -1,16 +0,0 @@
import onyx.background.celery.configs.base as shared_config

broker_url = shared_config.broker_url
broker_connection_retry_on_startup = shared_config.broker_connection_retry_on_startup
broker_pool_limit = shared_config.broker_pool_limit
broker_transport_options = shared_config.broker_transport_options

redis_socket_keepalive = shared_config.redis_socket_keepalive
redis_retry_on_timeout = shared_config.redis_retry_on_timeout
redis_backend_health_check_interval = shared_config.redis_backend_health_check_interval

result_backend = shared_config.result_backend
result_expires = shared_config.result_expires  # 86400 seconds is the default

task_default_priority = shared_config.task_default_priority
task_acks_late = shared_config.task_acks_late

@@ -886,8 +886,11 @@ def monitor_ccpair_permissions_taskset(
        record_type=RecordType.PERMISSION_SYNC_PROGRESS,
        data={
            "cc_pair_id": cc_pair_id,
            "total_docs_synced": initial if initial is not None else 0,
            "remaining_docs_to_sync": remaining,
            "id": payload.id if payload else None,
            "total_docs": initial if initial is not None else 0,
            "remaining_docs": remaining,
            "synced_docs": (initial - remaining) if initial is not None else 0,
            "is_complete": remaining == 0,
        },
        tenant_id=tenant_id,
    )
@@ -903,13 +906,6 @@ def monitor_ccpair_permissions_taskset(
        f"num_synced={initial}"
    )

    # Add telemetry for permission syncing complete
    optional_telemetry(
        record_type=RecordType.PERMISSION_SYNC_COMPLETE,
        data={"cc_pair_id": cc_pair_id},
        tenant_id=tenant_id,
    )

    update_sync_record_status(
        db_session=db_session,
        entity_id=cc_pair_id,

@@ -1,20 +0,0 @@
"""Factory stub for running celery worker / celery beat.
This code is different from the primary/beat stubs because there is no EE version to
fetch. Port over the code in those files if we add an EE version of this worker.

This is an app stub purely for sending tasks as a client.
"""
from celery import Celery

from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable

set_is_ee_based_on_env_variable()


def get_app() -> Celery:
    from onyx.background.celery.apps.client import celery_app

    return celery_app


app = get_app()
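
The deleted stub above exists purely to enqueue tasks by name, without importing the task functions. The same pattern in plain Celery terms (broker URL and task name below are placeholders, not onyx's real values):

from celery import Celery

client = Celery("client", broker="redis://localhost:6379/0")

# send_task enqueues by task name; this mirrors how the API routers later
# in this diff fire OnyxCeleryTask.CHECK_FOR_INDEXING.
result = client.send_task(
    "check_for_indexing",
    kwargs={"tenant_id": "public"},
    priority=1,
)
print(result.id)
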
@@ -56,6 +56,7 @@ from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)
from onyx.redis.redis_connector import RedisConnector
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@@ -577,8 +578,11 @@ def _run_indexing(
                    data={
                        "index_attempt_id": index_attempt_id,
                        "cc_pair_id": ctx.cc_pair_id,
                        "current_docs_indexed": document_count,
                        "current_chunks_indexed": chunk_count,
                        "connector_id": ctx.connector_id,
                        "credential_id": ctx.credential_id,
                        "total_docs_indexed": document_count,
                        "total_chunks": chunk_count,
                        "batch_num": batch_num,
                        "source": ctx.source.value,
                    },
                    tenant_id=tenant_id,
@@ -599,15 +603,26 @@ def _run_indexing(
            checkpoint=checkpoint,
        )

        # Add telemetry for completed indexing
        redis_connector = RedisConnector(tenant_id, ctx.cc_pair_id)
        redis_connector_index = redis_connector.new_index(
            index_attempt_start.search_settings_id
        )
        final_progress = redis_connector_index.get_progress() or 0

        optional_telemetry(
            record_type=RecordType.INDEXING_COMPLETE,
            data={
                "index_attempt_id": index_attempt_id,
                "cc_pair_id": ctx.cc_pair_id,
                "connector_id": ctx.connector_id,
                "credential_id": ctx.credential_id,
                "total_docs_indexed": document_count,
                "total_chunks": chunk_count,
                "batch_count": batch_num,
                "time_elapsed_seconds": time.monotonic() - start_time,
                "source": ctx.source.value,
                "redis_progress": final_progress,
            },
            tenant_id=tenant_id,
        )

@@ -43,7 +43,6 @@ from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_me
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import DISABLE_LLM_CHOOSE_SEARCH
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.chat_configs import SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE
from onyx.configs.constants import AGENT_SEARCH_INITIAL_KEY
from onyx.configs.constants import BASIC_KEY
from onyx.configs.constants import MessageType
@@ -693,13 +692,8 @@ def stream_chat_message_objects(
                doc_identifiers=identifier_tuples,
                document_index=document_index,
            )

            # Add a maximum context size in the case of user-selected docs to prevent
            # slight inaccuracies in context window size pruning from causing
            # the entire query to fail
            document_pruning_config = DocumentPruningConfig(
                is_manually_selected_docs=True,
                max_window_percentage=SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE,
                is_manually_selected_docs=True
            )

            # In case the search doc is deleted, just don't include it

@@ -312,14 +312,11 @@ def prune_sections(
    )


def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, int]:
def _merge_doc_chunks(chunks: list[InferenceChunk]) -> InferenceSection:
    assert (
        len(set([chunk.document_id for chunk in chunks])) == 1
    ), "One distinct document must be passed into merge_doc_chunks"

    ADJACENT_CHUNK_SEP = "\n"
    DISTANT_CHUNK_SEP = "\n\n...\n\n"

    # Assuming there are no duplicates by this point
    sorted_chunks = sorted(chunks, key=lambda x: x.chunk_id)

@@ -327,48 +324,33 @@ def _merge_doc_chunks(chunks: list[InferenceChunk]) -> tuple[InferenceSection, i
        chunks, key=lambda x: x.score if x.score is not None else float("-inf")
    )

    added_chars = 0
    merged_content = []
    for i, chunk in enumerate(sorted_chunks):
        if i > 0:
            prev_chunk_id = sorted_chunks[i - 1].chunk_id
            sep = (
                ADJACENT_CHUNK_SEP
                if chunk.chunk_id == prev_chunk_id + 1
                else DISTANT_CHUNK_SEP
            )
            merged_content.append(sep)
            added_chars += len(sep)
            if chunk.chunk_id == prev_chunk_id + 1:
                merged_content.append("\n")
            else:
                merged_content.append("\n\n...\n\n")
        merged_content.append(chunk.content)

    combined_content = "".join(merged_content)

    return (
        InferenceSection(
            center_chunk=center_chunk,
            chunks=sorted_chunks,
            combined_content=combined_content,
        ),
        added_chars,
    return InferenceSection(
        center_chunk=center_chunk,
        chunks=sorted_chunks,
        combined_content=combined_content,
    )


def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
    docs_map: dict[str, dict[int, InferenceChunk]] = defaultdict(dict)
    doc_order: dict[str, int] = {}
    combined_section_lengths: dict[str, int] = defaultdict(lambda: 0)

    # chunk de-duping and doc ordering
    for index, section in enumerate(sections):
        if section.center_chunk.document_id not in doc_order:
            doc_order[section.center_chunk.document_id] = index

        combined_section_lengths[section.center_chunk.document_id] += len(
            section.combined_content
        )

        chunks_map = docs_map[section.center_chunk.document_id]
        for chunk in [section.center_chunk] + section.chunks:
            chunks_map = docs_map[section.center_chunk.document_id]
            existing_chunk = chunks_map.get(chunk.chunk_id)
            if (
                existing_chunk is None
@@ -379,22 +361,8 @@ def _merge_sections(sections: list[InferenceSection]) -> list[InferenceSection]:
                chunks_map[chunk.chunk_id] = chunk

    new_sections = []
    for doc_id, section_chunks in docs_map.items():
        section_chunks_list = list(section_chunks.values())
        merged_section, added_chars = _merge_doc_chunks(chunks=section_chunks_list)

        previous_length = combined_section_lengths[doc_id] + added_chars
        # After merging, ensure the content respects the pruning done earlier. Each
        # combined section is restricted to the sum of the lengths of the sections
        # from the pruning step. Technically the correct approach would be to prune based
        # on tokens AGAIN, but this is a good approximation and worth not adding the
        # tokenization overhead. This could also be fixed if we added a way of removing
        # chunks from sections in the pruning step; at the moment this issue largely
        # exists because we only trim the final section's combined_content.
        merged_section.combined_content = merged_section.combined_content[
            :previous_length
        ]
        new_sections.append(merged_section)
    for section_chunks in docs_map.values():
        new_sections.append(_merge_doc_chunks(chunks=list(section_chunks.values())))

    # Sort by highest score, then by original document order
    # It is now 1 large section per doc, the center chunk being the one with the highest score
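
A standalone sketch of the separator rule in _merge_doc_chunks above: chunks with consecutive ids are glued with a single newline, while gaps get an ellipsis separator (Chunk here is a simplified stand-in for InferenceChunk):

from dataclasses import dataclass

@dataclass
class Chunk:
    chunk_id: int
    content: str

def join_chunks(chunks: list[Chunk]) -> str:
    parts: list[str] = []
    ordered = sorted(chunks, key=lambda c: c.chunk_id)
    for i, chunk in enumerate(ordered):
        if i > 0:
            adjacent = chunk.chunk_id == ordered[i - 1].chunk_id + 1
            parts.append("\n" if adjacent else "\n\n...\n\n")
        parts.append(chunk.content)
    return "".join(parts)

# Chunks 1 and 2 are adjacent; chunk 5 is distant, so it gets the "..." separator.
print(join_chunks([Chunk(1, "a"), Chunk(2, "b"), Chunk(5, "c")]))
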
@@ -16,9 +16,6 @@ MAX_CHUNKS_FED_TO_CHAT = float(os.environ.get("MAX_CHUNKS_FED_TO_CHAT") or 10.0)
# ~3k input, half for docs, half for chat history + prompts
CHAT_TARGET_CHUNK_PERCENTAGE = 512 * 3 / 3072

# Maximum percentage of the context window to fill with selected sections
SELECTED_SECTIONS_MAX_WINDOW_PERCENTAGE = 0.8

# 1 / (1 + DOC_TIME_DECAY * doc-age-in-years), set to 0 to have no decay
# Capped in Vespa at 0.5
DOC_TIME_DECAY = float(

@@ -13,7 +13,6 @@ from typing import TYPE_CHECKING
from typing import TypeVar
from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urljoin
from urllib.parse import urlparse

import requests
@@ -343,14 +342,9 @@ def build_confluence_document_id(
    Returns:
        str: The document id
    """

    # NOTE: urljoin is tricky and will drop the last segment of the base if it doesn't
    # end with "/" because it believes that makes it a file.
    final_url = base_url.rstrip("/") + "/"
    if is_cloud and not final_url.endswith("/wiki/"):
        final_url = urljoin(final_url, "wiki") + "/"
    final_url = urljoin(final_url, content_url.lstrip("/"))
    return final_url
    if is_cloud and not base_url.endswith("/wiki"):
        base_url += "/wiki"
    return f"{base_url}{content_url}"


def datetime_from_string(datetime_string: str) -> datetime:
@@ -460,19 +454,6 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
        logger.warning("HTTPError with `None` as response or as headers")
        raise e

    # Confluence Server returns 403 when rate limited
    if e.response.status_code == 403:
        FORBIDDEN_MAX_RETRY_ATTEMPTS = 7
        FORBIDDEN_RETRY_DELAY = 10
        if attempt < FORBIDDEN_MAX_RETRY_ATTEMPTS:
            logger.warning(
                "403 error. This sometimes happens when we hit "
                f"Confluence rate limits. Retrying in {FORBIDDEN_RETRY_DELAY} seconds..."
            )
            return FORBIDDEN_RETRY_DELAY

        raise e

    if (
        e.response.status_code != 429
        and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
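
The urljoin caveat in the removed comment is easy to demonstrate: without a trailing slash, the last path segment of the base URL is treated as a file name and dropped.

from urllib.parse import urljoin

print(urljoin("https://example.com/wiki", "spaces/DOC"))
# https://example.com/spaces/DOC  ("wiki" is dropped)
print(urljoin("https://example.com/wiki/", "spaces/DOC"))
# https://example.com/wiki/spaces/DOC
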
@@ -445,9 +445,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                logger.warning(
                    f"User '{user_email}' does not have access to the drive APIs."
                )
                # mark this user as done so we don't try to retrieve anything for them
                # again
                curr_stage.stage = DriveRetrievalStage.DONE
                return
            raise
@@ -584,25 +581,6 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
            drive_ids_to_retrieve, checkpoint
        )

        # only process emails that we haven't already completed retrieval for
        non_completed_org_emails = [
            user_email
            for user_email, stage in checkpoint.completion_map.items()
            if stage != DriveRetrievalStage.DONE
        ]

        # don't process too many emails before returning a checkpoint. This is
        # to resolve the case where there are a ton of emails that don't have access
        # to the drive APIs. Without this, we could loop through these emails for
        # more than 3 hours, causing a timeout and stalling progress.
        email_batch_takes_us_to_completion = True
        MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING = 50
        if len(non_completed_org_emails) > MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING:
            non_completed_org_emails = non_completed_org_emails[
                :MAX_EMAILS_TO_PROCESS_BEFORE_CHECKPOINTING
            ]
            email_batch_takes_us_to_completion = False

        user_retrieval_gens = [
            self._impersonate_user_for_retrieval(
                email,
@@ -613,14 +591,10 @@ class GoogleDriveConnector(SlimConnector, CheckpointConnector[GoogleDriveCheckpo
                start,
                end,
            )
            for email in non_completed_org_emails
            for email in all_org_emails
        ]
        yield from parallel_yield(user_retrieval_gens, max_workers=MAX_DRIVE_WORKERS)

        # if there are more emails to process, don't mark as complete
        if not email_batch_takes_us_to_completion:
            return

        remaining_folders = (
            drive_ids_to_retrieve | folder_ids_to_retrieve
        ) - self._retrieved_ids

@@ -5,13 +5,11 @@ from typing import cast

from sqlalchemy.orm import Session

from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
from onyx.chat.prune_and_merge import _merge_sections
from onyx.chat.prune_and_merge import ChunkRange
from onyx.chat.prune_and_merge import merge_chunk_intervals
from onyx.chat.prune_and_merge import prune_and_merge_sections
from onyx.configs.chat_configs import DISABLE_LLM_DOC_RELEVANCE
from onyx.context.search.enums import LLMEvaluationType
from onyx.context.search.enums import QueryFlow
@@ -63,7 +61,6 @@ class SearchPipeline:
        | None = None,
        rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
        prompt_config: PromptConfig | None = None,
        contextual_pruning_config: ContextualPruningConfig | None = None,
    ):
        # NOTE: The Search Request contains a lot of fields that are overrides, many of them can be None
        # and typically are None. The preprocessing will fetch default values to replace these empty overrides.
@@ -80,9 +77,6 @@ class SearchPipeline:
        self.search_settings = get_current_search_settings(db_session)
        self.document_index = get_default_document_index(self.search_settings, None)
        self.prompt_config: PromptConfig | None = prompt_config
        self.contextual_pruning_config: ContextualPruningConfig | None = (
            contextual_pruning_config
        )

        # Preprocessing steps generate this
        self._search_query: SearchQuery | None = None
@@ -227,16 +221,13 @@ class SearchPipeline:

        # If ee is enabled, censor the chunk sections based on user access
        # Otherwise, return the retrieved chunks
        censored_chunks = cast(
            list[InferenceChunk],
            fetch_ee_implementation_or_noop(
                "onyx.external_permissions.post_query_censoring",
                "_post_query_chunk_censoring",
                retrieved_chunks,
            )(
                chunks=retrieved_chunks,
                user=self.user,
            ),
        censored_chunks = fetch_ee_implementation_or_noop(
            "onyx.external_permissions.post_query_censoring",
            "_post_query_chunk_censoring",
            retrieved_chunks,
        )(
            chunks=retrieved_chunks,
            user=self.user,
        )

        above = self.search_query.chunks_above
@@ -429,26 +420,7 @@ class SearchPipeline:
        if self._final_context_sections is not None:
            return self._final_context_sections

        if (
            self.contextual_pruning_config is not None
            and self.prompt_config is not None
        ):
            self._final_context_sections = prune_and_merge_sections(
                sections=self.reranked_sections,
                section_relevance_list=None,
                prompt_config=self.prompt_config,
                llm_config=self.llm.config,
                question=self.search_query.query,
                contextual_pruning_config=self.contextual_pruning_config,
            )

        else:
            logger.error(
                "Contextual pruning or prompt config not set, using default merge"
            )
            self._final_context_sections = _merge_sections(
                sections=self.reranked_sections
            )
        self._final_context_sections = _merge_sections(sections=self.reranked_sections)
        return self._final_context_sections

    @property
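
For context, fetch_ee_implementation_or_noop(module, attr, noop_return) resolves an enterprise-edition implementation at runtime and falls back to a no-op when it is unavailable. A rough sketch of the idea (this is an assumption about the helper's behavior, not onyx's actual implementation):

import importlib
from typing import Any, Callable

def fetch_impl_or_noop(module: str, attr: str, noop_return: Any) -> Callable[..., Any]:
    try:
        return getattr(importlib.import_module(module), attr)
    except (ImportError, AttributeError):
        # EE module unavailable: return a callable that echoes the default back.
        return lambda *args, **kwargs: noop_return
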
@@ -613,19 +613,8 @@ def fetch_connector_credential_pairs(


def resync_cc_pair(
    cc_pair: ConnectorCredentialPair,
    search_settings_id: int,
    db_session: Session,
) -> None:
    """
    Updates state stored in the connector_credential_pair table based on the
    latest index attempt for the given search settings.

    Args:
        cc_pair: ConnectorCredentialPair to resync
        search_settings_id: SearchSettings to use for resync
        db_session: Database session
    """

    def find_latest_index_attempt(
        connector_id: int,
        credential_id: int,
@@ -638,10 +627,11 @@ def resync_cc_pair(
                ConnectorCredentialPair,
                IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
            )
            .join(SearchSettings, IndexAttempt.search_settings_id == SearchSettings.id)
            .filter(
                ConnectorCredentialPair.connector_id == connector_id,
                ConnectorCredentialPair.credential_id == credential_id,
                IndexAttempt.search_settings_id == search_settings_id,
                SearchSettings.status == IndexModelStatus.PRESENT,
            )
        )

@@ -43,8 +43,6 @@ from onyx.utils.logger import setup_logger

logger = setup_logger()

ONE_HOUR_IN_SECONDS = 60 * 60


def check_docs_exist(db_session: Session) -> bool:
    stmt = select(exists(DbDocument))
@@ -609,46 +607,6 @@ def delete_documents_complete__no_commit(
    delete_documents__no_commit(db_session, document_ids)


def delete_all_documents_for_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
    timeout: int = ONE_HOUR_IN_SECONDS,
) -> None:
    """Delete all documents for a given connector credential pair.
    This will delete all documents and their associated data (chunks, feedback, tags, etc.)

    NOTE: a bit inefficient, but it's not a big deal since this is done rarely - only during
    an index swap. If we wanted to make this more efficient, we could use a single delete
    statement + cascade.
    """
    batch_size = 1000
    start_time = time.monotonic()

    while True:
        # Get document IDs in batches
        stmt = (
            select(DocumentByConnectorCredentialPair.id)
            .where(
                DocumentByConnectorCredentialPair.connector_id == connector_id,
                DocumentByConnectorCredentialPair.credential_id == credential_id,
            )
            .limit(batch_size)
        )
        document_ids = db_session.scalars(stmt).all()

        if not document_ids:
            break

        delete_documents_complete__no_commit(
            db_session=db_session, document_ids=list(document_ids)
        )
        db_session.commit()

        if time.monotonic() - start_time > timeout:
            raise RuntimeError("Timeout reached while deleting documents")


def acquire_document_locks(db_session: Session, document_ids: list[str]) -> bool:
    """Acquire locks for the specified documents. Ideally this shouldn't be
    called with large list of document_ids (an exception could be made if the
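
The deletion helper above is an instance of a general pattern: delete in fixed-size batches, commit between batches, and bail out on a wall-clock timeout. Stripped to its skeleton (the callables are illustrative placeholders):

import time
from typing import Callable, Sequence

def delete_in_batches(
    fetch_batch: Callable[[], Sequence[str]],
    delete_batch: Callable[[Sequence[str]], None],
    timeout_s: float = 60 * 60,
) -> None:
    start = time.monotonic()
    while True:
        batch = fetch_batch()
        if not batch:
            break  # nothing left to delete
        delete_batch(batch)
        if time.monotonic() - start > timeout_s:
            raise RuntimeError("Timeout reached while deleting documents")
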
@@ -217,6 +217,7 @@ def mark_attempt_in_progress(
                "index_attempt_id": index_attempt.id,
                "status": IndexingStatus.IN_PROGRESS.value,
                "cc_pair_id": index_attempt.connector_credential_pair_id,
                "search_settings_id": index_attempt.search_settings_id,
            },
        )
    except Exception:
@@ -245,6 +246,9 @@ def mark_attempt_succeeded(
                "index_attempt_id": index_attempt_id,
                "status": IndexingStatus.SUCCESS.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
                "search_settings_id": attempt.search_settings_id,
                "total_docs_indexed": attempt.total_docs_indexed,
                "new_docs_indexed": attempt.new_docs_indexed,
            },
        )
    except Exception:
@@ -273,6 +277,9 @@ def mark_attempt_partially_succeeded(
                "index_attempt_id": index_attempt_id,
                "status": IndexingStatus.COMPLETED_WITH_ERRORS.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
                "search_settings_id": attempt.search_settings_id,
                "total_docs_indexed": attempt.total_docs_indexed,
                "new_docs_indexed": attempt.new_docs_indexed,
            },
        )
    except Exception:
@@ -305,6 +312,10 @@ def mark_attempt_canceled(
                "index_attempt_id": index_attempt_id,
                "status": IndexingStatus.CANCELED.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
                "search_settings_id": attempt.search_settings_id,
                "reason": reason,
                "total_docs_indexed": attempt.total_docs_indexed,
                "new_docs_indexed": attempt.new_docs_indexed,
            },
        )
    except Exception:
@@ -339,6 +350,10 @@ def mark_attempt_failed(
                "index_attempt_id": index_attempt_id,
                "status": IndexingStatus.FAILED.value,
                "cc_pair_id": attempt.connector_credential_pair_id,
                "search_settings_id": attempt.search_settings_id,
                "reason": failure_reason,
                "total_docs_indexed": attempt.total_docs_indexed,
                "new_docs_indexed": attempt.new_docs_indexed,
            },
        )
    except Exception:
@@ -710,25 +725,6 @@ def cancel_indexing_attempts_past_model(
    )


def cancel_indexing_attempts_for_search_settings(
    search_settings_id: int,
    db_session: Session,
) -> None:
    """Stops all indexing attempts that are in progress or not started for
    the specified search settings."""

    db_session.execute(
        update(IndexAttempt)
        .where(
            IndexAttempt.status.in_(
                [IndexingStatus.IN_PROGRESS, IndexingStatus.NOT_STARTED]
            ),
            IndexAttempt.search_settings_id == search_settings_id,
        )
        .values(status=IndexingStatus.FAILED)
    )


def count_unique_cc_pairs_with_successful_index_attempts(
    search_settings_id: int | None,
    db_session: Session,

@@ -703,11 +703,7 @@ class Connector(Base):
    )
    documents_by_connector: Mapped[
        list["DocumentByConnectorCredentialPair"]
    ] = relationship(
        "DocumentByConnectorCredentialPair",
        back_populates="connector",
        passive_deletes=True,
    )
    ] = relationship("DocumentByConnectorCredentialPair", back_populates="connector")

    # synchronize this validation logic with RefreshFrequencySchema etc on front end
    # until we have a centralized validation schema
@@ -761,11 +757,7 @@ class Credential(Base):
    )
    documents_by_credential: Mapped[
        list["DocumentByConnectorCredentialPair"]
    ] = relationship(
        "DocumentByConnectorCredentialPair",
        back_populates="credential",
        passive_deletes=True,
    )
    ] = relationship("DocumentByConnectorCredentialPair", back_populates="credential")

    user: Mapped[User | None] = relationship("User", back_populates="credentials")

@@ -1118,10 +1110,10 @@ class DocumentByConnectorCredentialPair(Base):
    id: Mapped[str] = mapped_column(ForeignKey("document.id"), primary_key=True)
    # TODO: transition this to use the ConnectorCredentialPair id directly
    connector_id: Mapped[int] = mapped_column(
        ForeignKey("connector.id", ondelete="CASCADE"), primary_key=True
        ForeignKey("connector.id"), primary_key=True
    )
    credential_id: Mapped[int] = mapped_column(
        ForeignKey("credential.id", ondelete="CASCADE"), primary_key=True
        ForeignKey("credential.id"), primary_key=True
    )

    # used to better keep track of document counts at a connector level
@@ -1131,10 +1123,10 @@ class DocumentByConnectorCredentialPair(Base):
    has_been_indexed: Mapped[bool] = mapped_column(Boolean)

    connector: Mapped[Connector] = relationship(
        "Connector", back_populates="documents_by_connector", passive_deletes=True
        "Connector", back_populates="documents_by_connector"
    )
    credential: Mapped[Credential] = relationship(
        "Credential", back_populates="documents_by_credential", passive_deletes=True
        "Credential", back_populates="documents_by_credential"
    )

    __table_args__ = (
@@ -1658,8 +1650,8 @@ class Prompt(Base):
    )
    name: Mapped[str] = mapped_column(String)
    description: Mapped[str] = mapped_column(String)
    system_prompt: Mapped[str] = mapped_column(String(length=8000))
    task_prompt: Mapped[str] = mapped_column(String(length=8000))
    system_prompt: Mapped[str] = mapped_column(Text)
    task_prompt: Mapped[str] = mapped_column(Text)
    include_citations: Mapped[bool] = mapped_column(Boolean, default=True)
    datetime_aware: Mapped[bool] = mapped_column(Boolean, default=True)
    # Default prompts are configured via backend during deployment
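
The ondelete="CASCADE" / passive_deletes=True pair being toggled above works as a unit: the database removes child rows when the parent is deleted, and the ORM is told to trust the database instead of loading children to delete them itself. A minimal SQLAlchemy 2.x sketch (model names are illustrative, not the real schema):

from sqlalchemy import ForeignKey
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship

class Base(DeclarativeBase):
    pass

class Parent(Base):
    __tablename__ = "parent"
    id: Mapped[int] = mapped_column(primary_key=True)
    children: Mapped[list["Child"]] = relationship(
        back_populates="parent", passive_deletes=True
    )

class Child(Base):
    __tablename__ = "child"
    id: Mapped[int] = mapped_column(primary_key=True)
    # The database cascades the delete, so the ORM need not load children.
    parent_id: Mapped[int] = mapped_column(ForeignKey("parent.id", ondelete="CASCADE"))
    parent: Mapped[Parent] = relationship(back_populates="children")
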
@@ -37,8 +37,8 @@ from onyx.db.models import UserFile
from onyx.db.models import UserFolder
from onyx.db.models import UserGroup
from onyx.db.notification import create_notification
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
@@ -201,7 +201,7 @@ def create_update_persona(
    create_persona_request: PersonaUpsertRequest,
    user: User | None,
    db_session: Session,
) -> FullPersonaSnapshot:
) -> PersonaSnapshot:
    """Higher level function than upsert_persona, although either is valid to use."""
    # Permission to actually use these is checked later
@@ -271,7 +271,7 @@ def create_update_persona(
        logger.exception("Failed to create persona")
        raise HTTPException(status_code=400, detail=str(e))

    return FullPersonaSnapshot.from_model(persona)
    return PersonaSnapshot.from_model(persona)


def update_persona_shared_users(

@@ -3,9 +3,8 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import KV_REINDEX_KEY
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import resync_cc_pair
from onyx.db.document import delete_all_documents_for_connector_credential_pair
from onyx.db.enums import IndexModelStatus
from onyx.db.index_attempt import cancel_indexing_attempts_for_search_settings
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
    count_unique_cc_pairs_with_successful_index_attempts,
)
@@ -27,50 +26,32 @@ def _perform_index_swap(
    current_search_settings: SearchSettings,
    secondary_search_settings: SearchSettings,
    all_cc_pairs: list[ConnectorCredentialPair],
    cleanup_documents: bool = False,
) -> None:
    """Swap the indices and expire the old one."""
    if len(all_cc_pairs) > 0:
        kv_store = get_kv_store()
        kv_store.store(KV_REINDEX_KEY, False)

        # Expire jobs for the now past index/embedding model
        cancel_indexing_attempts_for_search_settings(
            search_settings_id=current_search_settings.id,
            db_session=db_session,
        )

        # Recount aggregates
        for cc_pair in all_cc_pairs:
            resync_cc_pair(
                cc_pair=cc_pair,
                # sync based on the new search settings
                search_settings_id=secondary_search_settings.id,
                db_session=db_session,
            )

        if cleanup_documents:
            # clean up all DocumentByConnectorCredentialPair / Document rows, since we're
            # doing an instant swap and no documents will exist in the new index.
            for cc_pair in all_cc_pairs:
                delete_all_documents_for_connector_credential_pair(
                    db_session=db_session,
                    connector_id=cc_pair.connector_id,
                    credential_id=cc_pair.credential_id,
                )

    # swap over search settings
    current_search_settings = get_current_search_settings(db_session)
    update_search_settings_status(
        search_settings=current_search_settings,
        new_status=IndexModelStatus.PAST,
        db_session=db_session,
    )

    update_search_settings_status(
        search_settings=secondary_search_settings,
        new_status=IndexModelStatus.PRESENT,
        db_session=db_session,
    )

    if len(all_cc_pairs) > 0:
        kv_store = get_kv_store()
        kv_store.store(KV_REINDEX_KEY, False)

        # Expire jobs for the now past index/embedding model
        cancel_indexing_attempts_past_model(db_session)

        # Recount aggregates
        for cc_pair in all_cc_pairs:
            resync_cc_pair(cc_pair, db_session=db_session)

    # remove the old index from the vector db
    document_index = get_default_document_index(secondary_search_settings, None)
    document_index.ensure_indices_exist(
@@ -107,9 +88,6 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
        current_search_settings=current_search_settings,
        secondary_search_settings=secondary_search_settings,
        all_cc_pairs=all_cc_pairs,
        # clean up all DocumentByConnectorCredentialPair / Document rows, since we're
        # doing an instant swap.
        cleanup_documents=True,
    )
    return current_search_settings

@@ -5,7 +5,6 @@ from datetime import timezone
from onyx.configs.constants import INDEX_SEPARATOR
from onyx.context.search.models import IndexFilters
from onyx.document_index.interfaces import VespaChunkRequest
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import CHUNK_ID
from onyx.document_index.vespa_constants import DOC_UPDATED_AT
from onyx.document_index.vespa_constants import DOCUMENT_ID
@@ -75,10 +74,8 @@ def build_vespa_filters(
        filter_str += f'({TENANT_ID} contains "{filters.tenant_id}") and '

    # ACL filters
    if filters.access_control_list is not None:
        filter_str += _build_or_filters(
            ACCESS_CONTROL_LIST, filters.access_control_list
        )
    # if filters.access_control_list is not None:
    #     filter_str += _build_or_filters(ACCESS_CONTROL_LIST, filters.access_control_list)

    # Source type filters
    source_strs = (

@@ -602,7 +602,7 @@ def get_max_input_tokens(
    )

    if input_toks <= 0:
        return GEN_AI_MODEL_FALLBACK_MAX_TOKENS
        raise RuntimeError("No tokens for input for the LLM given settings")

    return input_toks

@@ -1,4 +1,3 @@
import logging
import sys
import traceback
from collections.abc import AsyncGenerator
@@ -17,7 +16,6 @@ from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from httpx_oauth.clients.google import GoogleOAuth2
from prometheus_fastapi_instrumentator import Instrumentator
from sentry_sdk.integrations.fastapi import FastApiIntegration
from sentry_sdk.integrations.starlette import StarletteIntegration
from sqlalchemy.orm import Session
@@ -104,8 +102,6 @@ from onyx.server.utils import BasicAuthenticationError
from onyx.setup import setup_multitenant_onyx
from onyx.setup import setup_onyx
from onyx.utils.logger import setup_logger
from onyx.utils.logger import setup_uvicorn_logger
from onyx.utils.middleware import add_onyx_request_id_middleware
from onyx.utils.telemetry import get_or_generate_uuid
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
@@ -120,12 +116,6 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

file_handlers = [
    h for h in logger.logger.handlers if isinstance(h, logging.FileHandler)
]

setup_uvicorn_logger(shared_file_handlers=file_handlers)


def validation_exception_handler(request: Request, exc: Exception) -> JSONResponse:
    if not isinstance(exc, RequestValidationError):
@@ -431,14 +421,9 @@ def get_application() -> FastAPI:
    if LOG_ENDPOINT_LATENCY:
        add_latency_logging_middleware(application, logger)

    add_onyx_request_id_middleware(application, "API", logger)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_router_auth(application)

    # Initialize and instrument the app
    Instrumentator().instrument(application).expose(application)

    return application

@@ -49,7 +49,6 @@ PUBLIC_ENDPOINT_SPECS = [
    ("/auth/oauth/callback", {"GET"}),
    # anonymous user on cloud
    ("/tenants/anonymous-user", {"POST"}),
    ("/metrics", {"GET"}),  # added by prometheus_fastapi_instrumentator
]

@@ -21,7 +21,7 @@ from onyx.background.celery.tasks.external_group_syncing.tasks import (
from onyx.background.celery.tasks.pruning.tasks import (
    try_creating_prune_generator_task,
)
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.background.indexing.models import IndexAttemptErrorPydantic
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
@@ -219,7 +219,7 @@ def update_cc_pair_status(
                continue

            # Revoke the task to prevent it from running
            client_app.control.revoke(index_payload.celery_task_id)
            primary_app.control.revoke(index_payload.celery_task_id)

            # If it is running, then signaling for termination will get the
            # watchdog thread to kill the spawned task
@@ -238,7 +238,7 @@ def update_cc_pair_status(
        db_session.commit()

        # this speeds up the start of indexing by firing the check immediately
        client_app.send_task(
        primary_app.send_task(
            OnyxCeleryTask.CHECK_FOR_INDEXING,
            kwargs=dict(tenant_id=tenant_id),
            priority=OnyxCeleryPriority.HIGH,
@@ -376,7 +376,7 @@ def prune_cc_pair(
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_prune_generator_task(
        client_app, cc_pair, db_session, r, tenant_id
        primary_app, cc_pair, db_session, r, tenant_id
    )
    if not payload_id:
        raise HTTPException(
@@ -450,7 +450,7 @@ def sync_cc_pair(
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_permissions_sync_task(
        client_app, cc_pair_id, r, tenant_id
        primary_app, cc_pair_id, r, tenant_id
    )
    if not payload_id:
        raise HTTPException(
@@ -524,7 +524,7 @@ def sync_cc_pair_groups(
        f"{cc_pair.connector.name} connector."
    )
    payload_id = try_creating_external_group_sync_task(
        client_app, cc_pair_id, r, tenant_id
        primary_app, cc_pair_id, r, tenant_id
    )
    if not payload_id:
        raise HTTPException(
@@ -634,7 +634,7 @@ def associate_credential_to_connector(
    )

    # trigger indexing immediately
    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=OnyxCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},

@@ -20,7 +20,7 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
from onyx.configs.constants import DocumentSource
@@ -928,7 +928,7 @@ def create_connector_with_mock_credential(
    )

    # trigger indexing immediately
    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=OnyxCeleryPriority.HIGH,
        kwargs={"tenant_id": tenant_id},
@@ -1314,7 +1314,7 @@ def trigger_indexing_for_cc_pair(
    # run the beat task to pick up the triggers immediately
    priority = OnyxCeleryPriority.HIGHEST if is_user_file else OnyxCeleryPriority.HIGH
    logger.info(f"Sending indexing check task with priority {priority}")
    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_INDEXING,
        priority=priority,
        kwargs={"tenant_id": tenant_id},

@@ -6,7 +6,7 @@ from sqlalchemy.orm import Session

from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.background.celery.versioned_apps.primary import app as primary_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.document_set import check_document_sets_are_public
@@ -52,7 +52,7 @@ def create_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
@@ -85,7 +85,7 @@ def patch_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,
@@ -108,7 +108,7 @@ def delete_document_set(
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))

    client_app.send_task(
    primary_app.send_task(
        OnyxCeleryTask.CHECK_FOR_VESPA_SYNC_TASK,
        kwargs={"tenant_id": tenant_id},
        priority=OnyxCeleryPriority.HIGH,

@@ -43,7 +43,6 @@ from onyx.file_store.models import ChatFileType
 from onyx.secondary_llm_flows.starter_message_creation import (
     generate_starter_messages,
 )
-from onyx.server.features.persona.models import FullPersonaSnapshot
 from onyx.server.features.persona.models import GenerateStarterMessageRequest
 from onyx.server.features.persona.models import ImageGenerationToolStatus
 from onyx.server.features.persona.models import PersonaLabelCreate
@@ -425,8 +424,8 @@ def get_persona(
     persona_id: int,
     user: User | None = Depends(current_limited_user),
     db_session: Session = Depends(get_session),
-) -> FullPersonaSnapshot:
-    return FullPersonaSnapshot.from_model(
+) -> PersonaSnapshot:
+    return PersonaSnapshot.from_model(
         get_persona_by_id(
             persona_id=persona_id,
             user=user,

@@ -91,80 +91,37 @@ class PersonaUpsertRequest(BaseModel):

 class PersonaSnapshot(BaseModel):
     id: int
+    owner: MinimalUserSnapshot | None
     name: str
-    description: str
-    is_public: bool
     is_visible: bool
-    icon_shape: int | None = None
-    icon_color: str | None = None
+    is_public: bool
+    display_priority: int | None
+    description: str
+    num_chunks: float | None
+    llm_relevance_filter: bool
+    llm_filter_extraction: bool
+    llm_model_provider_override: str | None
+    llm_model_version_override: str | None
+    starter_messages: list[StarterMessage] | None
+    builtin_persona: bool
+    prompts: list[PromptSnapshot]
+    tools: list[ToolSnapshot]
+    document_sets: list[DocumentSet]
+    users: list[MinimalUserSnapshot]
+    groups: list[int]
+    icon_color: str | None
+    icon_shape: int | None
     uploaded_image_id: str | None = None
-    user_file_ids: list[int] = Field(default_factory=list)
-    user_folder_ids: list[int] = Field(default_factory=list)
-    display_priority: int | None = None
-    is_default_persona: bool = False
-    builtin_persona: bool = False
-    starter_messages: list[StarterMessage] | None = None
-    tools: list[ToolSnapshot] = Field(default_factory=list)
-    labels: list["PersonaLabelSnapshot"] = Field(default_factory=list)
-    owner: MinimalUserSnapshot | None = None
-    users: list[MinimalUserSnapshot] = Field(default_factory=list)
-    groups: list[int] = Field(default_factory=list)
-    document_sets: list[DocumentSet] = Field(default_factory=list)
-    llm_model_provider_override: str | None = None
-    llm_model_version_override: str | None = None
-    num_chunks: float | None = None
-
-    @classmethod
-    def from_model(cls, persona: Persona) -> "PersonaSnapshot":
-        return PersonaSnapshot(
-            id=persona.id,
-            name=persona.name,
-            description=persona.description,
-            is_public=persona.is_public,
-            is_visible=persona.is_visible,
-            icon_shape=persona.icon_shape,
-            icon_color=persona.icon_color,
-            uploaded_image_id=persona.uploaded_image_id,
-            user_file_ids=[file.id for file in persona.user_files],
-            user_folder_ids=[folder.id for folder in persona.user_folders],
-            display_priority=persona.display_priority,
-            is_default_persona=persona.is_default_persona,
-            builtin_persona=persona.builtin_persona,
-            starter_messages=persona.starter_messages,
-            tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
-            labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
-            owner=(
-                MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
-                if persona.user
-                else None
-            ),
-            users=[
-                MinimalUserSnapshot(id=user.id, email=user.email)
-                for user in persona.users
-            ],
-            groups=[user_group.id for user_group in persona.groups],
-            document_sets=[
-                DocumentSet.from_model(document_set_model)
-                for document_set_model in persona.document_sets
-            ],
-            llm_model_provider_override=persona.llm_model_provider_override,
-            llm_model_version_override=persona.llm_model_version_override,
-            num_chunks=persona.num_chunks,
-        )
-
-
-# Model with full context on perona's internal settings
-# This is used for flows which need to know all settings
-class FullPersonaSnapshot(PersonaSnapshot):
+    is_default_persona: bool
     search_start_date: datetime | None = None
-    prompts: list[PromptSnapshot] = Field(default_factory=list)
-    llm_relevance_filter: bool = False
-    llm_filter_extraction: bool = False
-    labels: list["PersonaLabelSnapshot"] = []
-    user_file_ids: list[int] | None = None
-    user_folder_ids: list[int] | None = None
+    labels: list["PersonaLabelSnapshot"] = []
+    user_file_ids: list[int] | None = None
+    user_folder_ids: list[int] | None = None

     @classmethod
     def from_model(
         cls, persona: Persona, allow_deleted: bool = False
-    ) -> "FullPersonaSnapshot":
+    ) -> "PersonaSnapshot":
         if persona.deleted:
             error_msg = f"Persona with ID {persona.id} has been deleted"
             if not allow_deleted:
@@ -172,32 +129,44 @@ class FullPersonaSnapshot(PersonaSnapshot):
             else:
                 logger.warning(error_msg)

-        return FullPersonaSnapshot(
+        return PersonaSnapshot(
             id=persona.id,
             name=persona.name,
-            description=persona.description,
-            is_public=persona.is_public,
-            is_visible=persona.is_visible,
-            icon_shape=persona.icon_shape,
-            icon_color=persona.icon_color,
-            uploaded_image_id=persona.uploaded_image_id,
-            user_file_ids=[file.id for file in persona.user_files],
-            user_folder_ids=[folder.id for folder in persona.user_folders],
-            display_priority=persona.display_priority,
-            is_default_persona=persona.is_default_persona,
-            builtin_persona=persona.builtin_persona,
-            starter_messages=persona.starter_messages,
-            tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
-            labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
             owner=(
                 MinimalUserSnapshot(id=persona.user.id, email=persona.user.email)
                 if persona.user
                 else None
             ),
-            search_start_date=persona.search_start_date,
-            prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
+            is_visible=persona.is_visible,
+            is_public=persona.is_public,
+            display_priority=persona.display_priority,
+            description=persona.description,
+            num_chunks=persona.num_chunks,
+            llm_relevance_filter=persona.llm_relevance_filter,
+            llm_filter_extraction=persona.llm_filter_extraction,
+            llm_model_provider_override=persona.llm_model_provider_override,
+            llm_model_version_override=persona.llm_model_version_override,
+            starter_messages=persona.starter_messages,
+            builtin_persona=persona.builtin_persona,
+            is_default_persona=persona.is_default_persona,
+            prompts=[PromptSnapshot.from_model(prompt) for prompt in persona.prompts],
+            tools=[ToolSnapshot.from_model(tool) for tool in persona.tools],
+            document_sets=[
+                DocumentSet.from_model(document_set_model)
+                for document_set_model in persona.document_sets
+            ],
+            users=[
+                MinimalUserSnapshot(id=user.id, email=user.email)
+                for user in persona.users
+            ],
+            groups=[user_group.id for user_group in persona.groups],
+            icon_color=persona.icon_color,
+            icon_shape=persona.icon_shape,
+            uploaded_image_id=persona.uploaded_image_id,
+            search_start_date=persona.search_start_date,
+            labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
+            user_file_ids=[file.id for file in persona.user_files],
+            user_folder_ids=[folder.id for folder in persona.user_folders],
         )


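Both sides of the model hunk above rely on the same `from_model` pattern: a classmethod that flattens the SQLAlchemy Persona row into an API-serializable pydantic snapshot, refusing deleted rows unless explicitly allowed. A minimal self-contained sketch of that pattern follows; the two-field models here are illustrative stand-ins, not the Onyx schema.

# Stand-in sketch of the snapshot pattern used above.
from pydantic import BaseModel

class PersonaRow:  # stand-in for the SQLAlchemy Persona model
    def __init__(self, id: int, name: str, deleted: bool = False) -> None:
        self.id, self.name, self.deleted = id, name, deleted

class SnapshotSketch(BaseModel):
    id: int
    name: str

    @classmethod
    def from_model(
        cls, persona: PersonaRow, allow_deleted: bool = False
    ) -> "SnapshotSketch":
        # Mirror the guard shown in the diff: deleted rows are an error
        # unless the caller opts in (e.g. the Slack channel config does).
        if persona.deleted and not allow_deleted:
            raise ValueError(f"Persona with ID {persona.id} has been deleted")
        return cls(id=persona.id, name=persona.name)

print(SnapshotSketch.from_model(PersonaRow(1, "Default Assistant")))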
@@ -10,7 +10,7 @@ from sqlalchemy.orm import Session

 from onyx.auth.users import current_admin_user
 from onyx.auth.users import current_curator_or_admin_user
-from onyx.background.celery.versioned_apps.client import app as client_app
+from onyx.background.celery.versioned_apps.primary import app as primary_app
 from onyx.configs.app_configs import GENERATIVE_MODEL_ACCESS_CHECK_FREQ
 from onyx.configs.constants import DocumentSource
 from onyx.configs.constants import KV_GEN_AI_KEY_CHECK_TIME
@@ -192,7 +192,7 @@ def create_deletion_attempt_for_connector_id(
     db_session.commit()

     # run the beat task to pick up this deletion from the db immediately
-    client_app.send_task(
+    primary_app.send_task(
         OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
         priority=OnyxCeleryPriority.HIGH,
         kwargs={"tenant_id": tenant_id},

@@ -19,7 +19,6 @@ from onyx.db.models import SlackBot as SlackAppModel
 from onyx.db.models import SlackChannelConfig as SlackChannelConfigModel
 from onyx.db.models import User
 from onyx.onyxbot.slack.config import VALID_SLACK_FILTERS
-from onyx.server.features.persona.models import FullPersonaSnapshot
 from onyx.server.features.persona.models import PersonaSnapshot
 from onyx.server.models import FullUserSnapshot
 from onyx.server.models import InvitedUserSnapshot
@@ -246,7 +245,7 @@ class SlackChannelConfig(BaseModel):
         id=slack_channel_config_model.id,
         slack_bot_id=slack_channel_config_model.slack_bot_id,
         persona=(
-            FullPersonaSnapshot.from_model(
+            PersonaSnapshot.from_model(
                 slack_channel_config_model.persona, allow_deleted=True
             )
             if slack_channel_config_model.persona

@@ -117,11 +117,7 @@ def set_new_search_settings(
         search_settings_id=search_settings.id, db_session=db_session
     )
     for cc_pair in get_connector_credential_pairs(db_session):
-        resync_cc_pair(
-            cc_pair=cc_pair,
-            search_settings_id=new_search_settings.id,
-            db_session=db_session,
-        )
+        resync_cc_pair(cc_pair, db_session=db_session)

     db_session.commit()
     return IdReturn(id=new_search_settings.id)

@@ -96,11 +96,7 @@ def setup_onyx(
     )

     for cc_pair in get_connector_credential_pairs(db_session):
-        resync_cc_pair(
-            cc_pair=cc_pair,
-            search_settings_id=search_settings.id,
-            db_session=db_session,
-        )
+        resync_cc_pair(cc_pair, db_session=db_session)

     # Expire all old embedding models indexing attempts, technically redundant
     cancel_indexing_attempts_past_model(db_session)

@@ -376,7 +376,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
             db_session=alternate_db_session or self.db_session,
             prompt_config=self.prompt_config,
             retrieved_sections_callback=retrieved_sections_callback,
-            contextual_pruning_config=self.contextual_pruning_config,
         )

         search_query_info = SearchQueryInfo(
@@ -448,7 +447,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
             db_session=self.db_session,
             bypass_acl=self.bypass_acl,
             prompt_config=self.prompt_config,
-            contextual_pruning_config=self.contextual_pruning_config,
         )

         # Log what we're doing

@@ -13,7 +13,6 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
 from shared_configs.configs import SLACK_CHANNEL_ID
 from shared_configs.configs import TENANT_ID_PREFIX
 from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
-from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR


 logging.addLevelName(logging.INFO + 5, "NOTICE")
@@ -72,14 +71,6 @@ def get_log_level_from_str(log_level_str: str = LOG_LEVEL) -> int:
     return log_level_dict.get(log_level_str.upper(), logging.getLevelName("NOTICE"))


-class OnyxRequestIDFilter(logging.Filter):
-    def filter(self, record: logging.LogRecord) -> bool:
-        from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
-
-        record.request_id = ONYX_REQUEST_ID_CONTEXTVAR.get() or "-"
-        return True
-
-
 class OnyxLoggingAdapter(logging.LoggerAdapter):
     def process(
         self, msg: str, kwargs: MutableMapping[str, Any]
@@ -112,7 +103,6 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
                 msg = f"[CC Pair: {cc_pair_id}] {msg}"

                 break
-
         # Add tenant information if it differs from default
         # This will always be the case for authenticated API requests
         if MULTI_TENANT:
@@ -125,11 +115,6 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
             )
             msg = f"[t:{short_tenant}] {msg}"

-        # request id within a fastapi route
-        fastapi_request_id = ONYX_REQUEST_ID_CONTEXTVAR.get()
-        if fastapi_request_id:
-            msg = f"[{fastapi_request_id}] {msg}"
-
         # For Slack Bot, logs the channel relevant to the request
         channel_id = self.extra.get(SLACK_CHANNEL_ID) if self.extra else None
         if channel_id:
@@ -180,14 +165,6 @@ class ColoredFormatter(logging.Formatter):
         return super().format(record)


-def get_uvicorn_standard_formatter() -> ColoredFormatter:
-    """Returns a standard colored logging formatter."""
-    return ColoredFormatter(
-        "%(asctime)s %(filename)30s %(lineno)4s: [%(request_id)s] %(message)s",
-        datefmt="%m/%d/%Y %I:%M:%S %p",
-    )
-
-
 def get_standard_formatter() -> ColoredFormatter:
     """Returns a standard colored logging formatter."""
     return ColoredFormatter(
@@ -224,6 +201,12 @@ def setup_logger(

     logger.addHandler(handler)

+    uvicorn_logger = logging.getLogger("uvicorn.access")
+    if uvicorn_logger:
+        uvicorn_logger.handlers = []
+        uvicorn_logger.addHandler(handler)
+        uvicorn_logger.setLevel(log_level)
+
     is_containerized = is_running_in_container()
     if LOG_FILE_NAME and (is_containerized or DEV_LOGGING_ENABLED):
         log_levels = ["debug", "info", "notice"]
@@ -242,37 +225,14 @@ def setup_logger(
             file_handler.setFormatter(formatter)
             logger.addHandler(file_handler)

+            if uvicorn_logger:
+                uvicorn_logger.addHandler(file_handler)
+
     logger.notice = lambda msg, *args, **kwargs: logger.log(logging.getLevelName("NOTICE"), msg, *args, **kwargs)  # type: ignore

     return OnyxLoggingAdapter(logger, extra=extra)


-def setup_uvicorn_logger(
-    log_level: int = get_log_level_from_str(),
-    shared_file_handlers: list[logging.FileHandler] | None = None,
-) -> None:
-    uvicorn_logger = logging.getLogger("uvicorn.access")
-    if not uvicorn_logger:
-        return
-
-    formatter = get_uvicorn_standard_formatter()
-
-    handler = logging.StreamHandler()
-    handler.setLevel(log_level)
-    handler.setFormatter(formatter)
-
-    uvicorn_logger.handlers = []
-    uvicorn_logger.addHandler(handler)
-    uvicorn_logger.setLevel(log_level)
-    uvicorn_logger.addFilter(OnyxRequestIDFilter())
-
-    if shared_file_handlers:
-        for fh in shared_file_handlers:
-            uvicorn_logger.addHandler(fh)
-
-    return
-
-
 def print_loggers() -> None:
     """Print information about all loggers. Use to debug logging issues."""
     root_logger = logging.getLogger()

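The logger hunks above fold uvicorn's access-log wiring back into setup_logger and drop the request-ID filter. The shared-handler pattern that remains is plain stdlib logging; a standalone sketch (logger names are illustrative):

# Sketch: one handler serves both the app logger and uvicorn's access
# logger, mirroring the inline wiring added to setup_logger above.
import logging

handler = logging.StreamHandler()
handler.setFormatter(
    logging.Formatter("%(asctime)s %(filename)s %(lineno)s: %(message)s")
)

app_logger = logging.getLogger("onyx")
app_logger.addHandler(handler)
app_logger.setLevel(logging.INFO)

uvicorn_logger = logging.getLogger("uvicorn.access")
uvicorn_logger.handlers = []        # replace uvicorn's default handlers
uvicorn_logger.addHandler(handler)  # share the application handler
uvicorn_logger.setLevel(logging.INFO)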
@@ -1,62 +0,0 @@
-import base64
-import hashlib
-import logging
-import uuid
-from collections.abc import Awaitable
-from collections.abc import Callable
-from datetime import datetime
-from datetime import timezone
-
-from fastapi import FastAPI
-from fastapi import Request
-from fastapi import Response
-
-from shared_configs.contextvars import ONYX_REQUEST_ID_CONTEXTVAR
-
-
-def add_onyx_request_id_middleware(
-    app: FastAPI, prefix: str, logger: logging.LoggerAdapter
-) -> None:
-    @app.middleware("http")
-    async def set_request_id(
-        request: Request, call_next: Callable[[Request], Awaitable[Response]]
-    ) -> Response:
-        """Generate a request hash that can be used to track the lifecycle
-        of a request. The hash is prefixed to help indicated where the request id
-        originated.
-
-        Format is f"{PREFIX}:{ID}" where PREFIX is 3 chars and ID is 8 chars.
-        Total length is 12 chars.
-        """
-
-        onyx_request_id = request.headers.get("X-Onyx-Request-ID")
-        if not onyx_request_id:
-            onyx_request_id = make_randomized_onyx_request_id(prefix)
-
-        ONYX_REQUEST_ID_CONTEXTVAR.set(onyx_request_id)
-        return await call_next(request)
-
-
-def make_randomized_onyx_request_id(prefix: str) -> str:
-    """generates a randomized request id"""
-
-    hash_input = str(uuid.uuid4())
-    return _make_onyx_request_id(prefix, hash_input)
-
-
-def make_structured_onyx_request_id(prefix: str, request_url: str) -> str:
-    """Not used yet, but could be in the future!"""
-    hash_input = f"{request_url}:{datetime.now(timezone.utc)}"
-    return _make_onyx_request_id(prefix, hash_input)
-
-
-def _make_onyx_request_id(prefix: str, hash_input: str) -> str:
-    """helper function to return an id given a string input"""
-    hash_obj = hashlib.md5(hash_input.encode("utf-8"))
-    hash_bytes = hash_obj.digest()[:6]  # Truncate to 6 bytes
-
-    # 6 bytes becomes 8 bytes. we shouldn't need to strip but just in case
-    # NOTE: possible we'll want more input bytes if id's aren't unique enough
-    hash_str = base64.urlsafe_b64encode(hash_bytes).decode("utf-8").rstrip("=")
-    onyx_request_id = f"{prefix}:{hash_str}"
-    return onyx_request_id
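The module deleted above generated per-request IDs of the form PREFIX:ID. Its core scheme is easy to reproduce standalone; this sketch mirrors the deleted _make_onyx_request_id logic, with only the demo prefix invented here:

# Standalone reproduction of the deleted request-id scheme: MD5 of a
# random UUID, truncated to 6 bytes, url-safe base64 encoded (8 chars).
import base64
import hashlib
import uuid

def make_request_id(prefix: str) -> str:
    digest = hashlib.md5(str(uuid.uuid4()).encode("utf-8")).digest()[:6]
    suffix = base64.urlsafe_b64encode(digest).decode("utf-8").rstrip("=")
    return f"{prefix}:{suffix}"

print(make_request_id("API"))  # e.g. "API:kF3mZq2x" -- 3-char prefix, 8-char id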
@@ -39,7 +39,6 @@ class RecordType(str, Enum):
     INDEXING_PROGRESS = "indexing_progress"
     INDEXING_COMPLETE = "indexing_complete"
     PERMISSION_SYNC_PROGRESS = "permission_sync_progress"
     PERMISSION_SYNC_COMPLETE = "permission_sync_complete"
-    INDEX_ATTEMPT_STATUS = "index_attempt_status"


@@ -332,15 +332,14 @@ def wait_on_background(task: TimeoutThread[R]) -> R:
     return task.result


-def _next_or_none(ind: int, gen: Iterator[R]) -> tuple[int, R | None]:
-    return ind, next(gen, None)
+def _next_or_none(ind: int, g: Iterator[R]) -> tuple[int, R | None]:
+    return ind, next(g, None)


 def parallel_yield(gens: list[Iterator[R]], max_workers: int = 10) -> Iterator[R]:
     with ThreadPoolExecutor(max_workers=max_workers) as executor:
         future_to_index: dict[Future[tuple[int, R | None]], int] = {
-            executor.submit(_next_or_none, ind, gen): ind
-            for ind, gen in enumerate(gens)
+            executor.submit(_next_or_none, i, g): i for i, g in enumerate(gens)
         }

         next_ind = len(gens)

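For context on the rename above: parallel_yield interleaves several generators through a thread pool, yielding each item as soon as any producer has one ready. A usage sketch; the import path is inferred from this file's apparent location and may differ:

# Usage sketch for parallel_yield: two slow generators, consumed in
# completion order rather than list order.
import time
from collections.abc import Iterator

from onyx.utils.threadpool_concurrency import parallel_yield  # assumed path

def slow_count(start: int, delay: float) -> Iterator[int]:
    for i in range(start, start + 3):
        time.sleep(delay)  # stand-in for network or disk latency
        yield i

gens = [slow_count(0, 0.05), slow_count(100, 0.02)]
for item in parallel_yield(gens, max_workers=2):
    print(item)  # the faster generator's items tend to arrive first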
@@ -95,5 +95,4 @@ urllib3==2.2.3
 mistune==0.8.4
 sentry-sdk==2.14.0
 prometheus_client==0.21.0
-fastapi-limiter==0.1.6
-prometheus_fastapi_instrumentator==7.1.0
+fastapi-limiter==0.1.6
@@ -15,5 +15,4 @@ uvicorn==0.21.1
 voyageai==0.2.3
 litellm==1.61.16
 sentry-sdk[fastapi,celery,starlette]==2.14.0
-aioboto3==13.4.0
-prometheus_fastapi_instrumentator==7.1.0
+aioboto3==13.4.0
@@ -887,7 +887,6 @@ def main() -> None:
         type=int,
         help="Maximum number of documents to delete (for delete-all-documents)",
     )
-    parser.add_argument("--link", help="Document link (for get_acls filter)")

     args = parser.parse_args()
     vespa_debug = VespaDebugging(args.tenant_id)
@@ -925,11 +924,7 @@ def main() -> None:
     elif args.action == "get_acls":
         if args.cc_pair_id is None:
             parser.error("--cc-pair-id is required for get_acls action")
-
-        if args.link is None:
-            vespa_debug.acls(args.cc_pair_id, args.n)
-        else:
-            vespa_debug.acls_by_link(args.cc_pair_id, args.link)
+        vespa_debug.acls(args.cc_pair_id, args.n)


 if __name__ == "__main__":

@@ -58,7 +58,6 @@ INDEXING_ONLY = os.environ.get("INDEXING_ONLY", "").lower() == "true"

 # The process needs to have this for the log file to write to
 # otherwise, it will not create additional log files
-# This should just be the filename base without extension or path.
 LOG_FILE_NAME = os.environ.get("LOG_FILE_NAME") or "onyx"

 # Enable generating persistent log files for local dev environments

@@ -11,15 +11,6 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
     "current_tenant_id", default=None if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA
 )

-# set by every route in the API server
-INDEXING_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[
-    str | None
-] = contextvars.ContextVar("indexing_request_id", default=None)
-
-# set by every route in the API server
-ONYX_REQUEST_ID_CONTEXTVAR: contextvars.ContextVar[str | None] = contextvars.ContextVar(
-    "onyx_request_id", default=None
-)

 """Utils related to contextvars"""

@@ -34,7 +34,7 @@ def confluence_connector(space: str) -> ConfluenceConnector:
     return connector


-@pytest.mark.parametrize("space", [os.getenv("CONFLUENCE_TEST_SPACE") or "DailyConne"])
+@pytest.mark.parametrize("space", [os.environ["CONFLUENCE_TEST_SPACE"]])
 @patch(
     "onyx.file_processing.extract_file_text.get_unstructured_api_key",
     return_value=None,

@@ -165,18 +165,17 @@ class DocumentManager:
             doc["fields"]["document_id"]: doc["fields"] for doc in retrieved_docs_dict
         }

-        # NOTE(rkuo): too much log spam
-        # Left this here for debugging purposes.
-        # import json
+        import json

-        # print("DEBUGGING DOCUMENTS")
-        # print(retrieved_docs)
-        # for doc in retrieved_docs.values():
-        #     printable_doc = doc.copy()
-        #     print(printable_doc.keys())
-        #     printable_doc.pop("embeddings")
-        #     printable_doc.pop("title_embedding")
-        #     print(json.dumps(printable_doc, indent=2))
+        print("DEBUGGING DOCUMENTS")
+        print(retrieved_docs)
+        for doc in retrieved_docs.values():
+            printable_doc = doc.copy()
+            print(printable_doc.keys())
+            printable_doc.pop("embeddings")
+            printable_doc.pop("title_embedding")
+            print(json.dumps(printable_doc, indent=2))

         for document in cc_pair.documents:
             retrieved_doc = retrieved_docs.get(document.id)

@@ -1,4 +1,3 @@
-import time
 from datetime import datetime
 from datetime import timedelta
 from urllib.parse import urlencode
@@ -192,7 +191,7 @@ class IndexAttemptManager:
         user_performing_action: DATestUser | None = None,
     ) -> None:
         """Wait for an IndexAttempt to complete"""
-        start = time.monotonic()
+        start = datetime.now()
         while True:
             index_attempt = IndexAttemptManager.get_index_attempt_by_id(
                 index_attempt_id=index_attempt_id,
@@ -204,7 +203,7 @@ class IndexAttemptManager:
                 print(f"IndexAttempt {index_attempt_id} completed")
                 return

-            elapsed = time.monotonic() - start
+            elapsed = (datetime.now() - start).total_seconds()
             if elapsed > timeout:
                 raise TimeoutError(
                     f"IndexAttempt {index_attempt_id} did not complete within {timeout} seconds"

@@ -4,7 +4,7 @@ from uuid import uuid4

 import requests

 from onyx.context.search.enums import RecencyBiasSetting
-from onyx.server.features.persona.models import FullPersonaSnapshot
+from onyx.server.features.persona.models import PersonaSnapshot
 from onyx.server.features.persona.models import PersonaUpsertRequest
 from tests.integration.common_utils.constants import API_SERVER_URL
 from tests.integration.common_utils.constants import GENERAL_HEADERS
@@ -181,7 +181,7 @@ class PersonaManager:
     @staticmethod
     def get_all(
         user_performing_action: DATestUser | None = None,
-    ) -> list[FullPersonaSnapshot]:
+    ) -> list[PersonaSnapshot]:
         response = requests.get(
             f"{API_SERVER_URL}/admin/persona",
             headers=user_performing_action.headers
@@ -189,13 +189,13 @@ class PersonaManager:
             else GENERAL_HEADERS,
         )
         response.raise_for_status()
-        return [FullPersonaSnapshot(**persona) for persona in response.json()]
+        return [PersonaSnapshot(**persona) for persona in response.json()]

     @staticmethod
     def get_one(
         persona_id: int,
         user_performing_action: DATestUser | None = None,
-    ) -> list[FullPersonaSnapshot]:
+    ) -> list[PersonaSnapshot]:
         response = requests.get(
             f"{API_SERVER_URL}/persona/{persona_id}",
             headers=user_performing_action.headers
@@ -203,7 +203,7 @@ class PersonaManager:
             else GENERAL_HEADERS,
         )
         response.raise_for_status()
-        return [FullPersonaSnapshot(**response.json())]
+        return [PersonaSnapshot(**response.json())]

     @staticmethod
     def verify(

@@ -4,7 +4,6 @@ from onyx.chat.prune_and_merge import _merge_sections
 from onyx.configs.constants import DocumentSource
 from onyx.context.search.models import InferenceChunk
 from onyx.context.search.models import InferenceSection
-from onyx.context.search.utils import inference_section_from_chunks


 # This large test accounts for all of the following:
@@ -112,7 +111,7 @@ Content 17
         # Sections
         [
             # Document 1, top/middle/bot connected + disconnected section
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_TOP_CHUNK,
                 chunks=[
                     DOC_1_FILLER_1,
@@ -121,8 +120,9 @@ Content 17
                     DOC_1_MID_CHUNK,
                     DOC_1_FILLER_3,
                 ],
+                combined_content="N/A",  # Not used
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_MID_CHUNK,
                 chunks=[
                     DOC_1_FILLER_2,
@@ -131,8 +131,9 @@ Content 17
                     DOC_1_FILLER_3,
                     DOC_1_FILLER_4,
                 ],
+                combined_content="N/A",
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_BOTTOM_CHUNK,
                 chunks=[
                     DOC_1_FILLER_3,
@@ -141,8 +142,9 @@ Content 17
                     DOC_1_FILLER_5,
                     DOC_1_FILLER_6,
                 ],
+                combined_content="N/A",
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_DISCONNECTED,
                 chunks=[
                     DOC_1_FILLER_7,
@@ -151,8 +153,9 @@ Content 17
                     DOC_1_FILLER_9,
                     DOC_1_FILLER_10,
                 ],
+                combined_content="N/A",
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_2_TOP_CHUNK,
                 chunks=[
                     DOC_2_FILLER_1,
@@ -161,8 +164,9 @@ Content 17
                     DOC_2_FILLER_3,
                     DOC_2_BOTTOM_CHUNK,
                 ],
+                combined_content="N/A",
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_2_BOTTOM_CHUNK,
                 chunks=[
                     DOC_2_TOP_CHUNK,
@@ -171,6 +175,7 @@ Content 17
                     DOC_2_FILLER_4,
                     DOC_2_FILLER_5,
                 ],
+                combined_content="N/A",
             ),
         ],
         # Expected Content
@@ -199,13 +204,15 @@ def test_merge_sections(
     (
         # Sections
         [
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_TOP_CHUNK,
                 chunks=[DOC_1_TOP_CHUNK],
+                combined_content="N/A",  # Not used
             ),
-            inference_section_from_chunks(
+            InferenceSection(
                 center_chunk=DOC_1_MID_CHUNK,
                 chunks=[DOC_1_MID_CHUNK],
+                combined_content="N/A",
             ),
         ],
         # Expected Content

@@ -1,208 +0,0 @@
-import os
-from unittest.mock import MagicMock
-from unittest.mock import patch
-
-import pytest
-
-from onyx.configs.constants import DocumentSource
-from onyx.context.search.models import InferenceChunk
-from onyx.db.models import User
-from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
-
-_post_query_chunk_censoring = fetch_ee_implementation_or_noop(
-    "onyx.external_permissions.post_query_censoring", "_post_query_chunk_censoring"
-)
-
-
-@pytest.mark.skipif(
-    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
-    reason="Permissions tests are enterprise only",
-)
-class TestPostQueryChunkCensoring:
-    @pytest.fixture(autouse=True)
-    def setUp(self) -> None:
-        self.mock_user = User(id=1, email="test@example.com")
-        self.mock_chunk_1 = InferenceChunk(
-            document_id="doc1",
-            chunk_id=1,
-            content="chunk1 content",
-            source_type=DocumentSource.SALESFORCE,
-            semantic_identifier="doc1_1",
-            title="doc1",
-            boost=1,
-            recency_bias=1.0,
-            score=0.9,
-            hidden=False,
-            metadata={},
-            match_highlights=[],
-            doc_summary="doc1 summary",
-            chunk_context="doc1 context",
-            updated_at=None,
-            image_file_name=None,
-            source_links={},
-            section_continuation=False,
-            blurb="chunk1",
-        )
-        self.mock_chunk_2 = InferenceChunk(
-            document_id="doc2",
-            chunk_id=2,
-            content="chunk2 content",
-            source_type=DocumentSource.SLACK,
-            semantic_identifier="doc2_2",
-            title="doc2",
-            boost=1,
-            recency_bias=1.0,
-            score=0.8,
-            hidden=False,
-            metadata={},
-            match_highlights=[],
-            doc_summary="doc2 summary",
-            chunk_context="doc2 context",
-            updated_at=None,
-            image_file_name=None,
-            source_links={},
-            section_continuation=False,
-            blurb="chunk2",
-        )
-        self.mock_chunk_3 = InferenceChunk(
-            document_id="doc3",
-            chunk_id=3,
-            content="chunk3 content",
-            source_type=DocumentSource.SALESFORCE,
-            semantic_identifier="doc3_3",
-            title="doc3",
-            boost=1,
-            recency_bias=1.0,
-            score=0.7,
-            hidden=False,
-            metadata={},
-            match_highlights=[],
-            doc_summary="doc3 summary",
-            chunk_context="doc3 context",
-            updated_at=None,
-            image_file_name=None,
-            source_links={},
-            section_continuation=False,
-            blurb="chunk3",
-        )
-        self.mock_chunk_4 = InferenceChunk(
-            document_id="doc4",
-            chunk_id=4,
-            content="chunk4 content",
-            source_type=DocumentSource.SALESFORCE,
-            semantic_identifier="doc4_4",
-            title="doc4",
-            boost=1,
-            recency_bias=1.0,
-            score=0.6,
-            hidden=False,
-            metadata={},
-            match_highlights=[],
-            doc_summary="doc4 summary",
-            chunk_context="doc4 context",
-            updated_at=None,
-            image_file_name=None,
-            source_links={},
-            section_continuation=False,
-            blurb="chunk4",
-        )
-
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
-    )
-    def test_post_query_chunk_censoring_no_user(
-        self, mock_get_sources: MagicMock
-    ) -> None:
-        mock_get_sources.return_value = {DocumentSource.SALESFORCE}
-        chunks = [self.mock_chunk_1, self.mock_chunk_2]
-        result = _post_query_chunk_censoring(chunks, None)
-        assert result == chunks
-
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
-    )
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
-    )
-    def test_post_query_chunk_censoring_salesforce_censored(
-        self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
-    ) -> None:
-        mock_get_sources.return_value = {DocumentSource.SALESFORCE}
-        mock_censor_func_impl = MagicMock(
-            return_value=[self.mock_chunk_1]
-        )  # Only return chunk 1
-        mock_censor_func.__getitem__.return_value = mock_censor_func_impl
-
-        chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
-        result = _post_query_chunk_censoring(chunks, self.mock_user)
-        assert len(result) == 2
-        assert self.mock_chunk_1 in result
-        assert self.mock_chunk_2 in result
-        assert self.mock_chunk_3 not in result
-        mock_censor_func_impl.assert_called_once()
-
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
-    )
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
-    )
-    def test_post_query_chunk_censoring_salesforce_error(
-        self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
-    ) -> None:
-        mock_get_sources.return_value = {DocumentSource.SALESFORCE}
-        mock_censor_func_impl = MagicMock(side_effect=Exception("Censoring error"))
-        mock_censor_func.__getitem__.return_value = mock_censor_func_impl
-
-        chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
-        result = _post_query_chunk_censoring(chunks, self.mock_user)
-        assert len(result) == 1
-        assert self.mock_chunk_2 in result
-        mock_censor_func_impl.assert_called_once()
-
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
-    )
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
-    )
-    def test_post_query_chunk_censoring_no_censoring(
-        self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
-    ) -> None:
-        mock_get_sources.return_value = set()  # No sources to censor
-        mock_censor_func_impl = MagicMock()
-        mock_censor_func.__getitem__.return_value = mock_censor_func_impl
-
-        chunks = [self.mock_chunk_1, self.mock_chunk_2, self.mock_chunk_3]
-        result = _post_query_chunk_censoring(chunks, self.mock_user)
-        assert result == chunks
-        mock_censor_func_impl.assert_not_called()
-
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring._get_all_censoring_enabled_sources"
-    )
-    @patch(
-        "ee.onyx.external_permissions.post_query_censoring.DOC_SOURCE_TO_CHUNK_CENSORING_FUNCTION"
-    )
-    def test_post_query_chunk_censoring_order_maintained(
-        self, mock_censor_func: MagicMock, mock_get_sources: MagicMock
-    ) -> None:
-        mock_get_sources.return_value = {DocumentSource.SALESFORCE}
-        mock_censor_func_impl = MagicMock(
-            return_value=[self.mock_chunk_3, self.mock_chunk_1]
-        )  # Return chunk 3 and 1
-        mock_censor_func.__getitem__.return_value = mock_censor_func_impl
-
-        chunks = [
-            self.mock_chunk_1,
-            self.mock_chunk_2,
-            self.mock_chunk_3,
-            self.mock_chunk_4,
-        ]
-        result = _post_query_chunk_censoring(chunks, self.mock_user)
-        assert len(result) == 3
-        assert result[0] == self.mock_chunk_1
-        assert result[1] == self.mock_chunk_2
-        assert result[2] == self.mock_chunk_3
-        assert self.mock_chunk_4 not in result
-        mock_censor_func_impl.assert_called_once()
@@ -1,270 +0,0 @@
-from datetime import datetime
-from datetime import timedelta
-from datetime import timezone
-
-from onyx.configs.constants import DocumentSource
-from onyx.configs.constants import INDEX_SEPARATOR
-from onyx.context.search.models import IndexFilters
-from onyx.context.search.models import Tag
-from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
-    build_vespa_filters,
-)
-from onyx.document_index.vespa_constants import DOC_UPDATED_AT
-from onyx.document_index.vespa_constants import DOCUMENT_SETS
-from onyx.document_index.vespa_constants import HIDDEN
-from onyx.document_index.vespa_constants import METADATA_LIST
-from onyx.document_index.vespa_constants import SOURCE_TYPE
-from onyx.document_index.vespa_constants import TENANT_ID
-from onyx.document_index.vespa_constants import USER_FILE
-from onyx.document_index.vespa_constants import USER_FOLDER
-from shared_configs.configs import MULTI_TENANT
-
-# Import the function under test
-
-
-class TestBuildVespaFilters:
-    def test_empty_filters(self) -> None:
-        """Test with empty filters object."""
-        filters = IndexFilters(access_control_list=[])
-        result = build_vespa_filters(filters)
-        assert result == f"!({HIDDEN}=true) and "
-
-        # With trailing AND removed
-        result = build_vespa_filters(filters, remove_trailing_and=True)
-        assert result == f"!({HIDDEN}=true)"
-
-    def test_include_hidden(self) -> None:
-        """Test with include_hidden flag."""
-        filters = IndexFilters(access_control_list=[])
-        result = build_vespa_filters(filters, include_hidden=True)
-        assert result == ""  # No filters applied when including hidden
-
-        # With some other filter to ensure proper AND chaining
-        filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
-        result = build_vespa_filters(filters, include_hidden=True)
-        assert result == f'({SOURCE_TYPE} contains "web") and '
-
-    def test_acl(self) -> None:
-        """Test with acls."""
-        # Single ACL
-        filters = IndexFilters(access_control_list=["user1"])
-        result = build_vespa_filters(filters)
-        assert (
-            result
-            == f'!({HIDDEN}=true) and (access_control_list contains "user1") and '
-        )
-
-        # Multiple ACL's
-        filters = IndexFilters(access_control_list=["user2", "group2"])
-        result = build_vespa_filters(filters)
-        assert (
-            result
-            == f'!({HIDDEN}=true) and (access_control_list contains "user2" or access_control_list contains "group2") and '
-        )
-
-    def test_tenant_filter(self) -> None:
-        """Test tenant ID filtering."""
-        # With tenant ID
-        if MULTI_TENANT:
-            filters = IndexFilters(access_control_list=[], tenant_id="tenant1")
-            result = build_vespa_filters(filters)
-            assert (
-                f'!({HIDDEN}=true) and ({TENANT_ID} contains "tenant1") and ' == result
-            )
-
-        # No tenant ID
-        filters = IndexFilters(access_control_list=[], tenant_id=None)
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_source_type_filter(self) -> None:
-        """Test source type filtering."""
-        # Single source type
-        filters = IndexFilters(access_control_list=[], source_type=[DocumentSource.WEB])
-        result = build_vespa_filters(filters)
-        assert f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web") and ' == result
-
-        # Multiple source types
-        filters = IndexFilters(
-            access_control_list=[],
-            source_type=[DocumentSource.WEB, DocumentSource.JIRA],
-        )
-        result = build_vespa_filters(filters)
-        assert (
-            f'!({HIDDEN}=true) and ({SOURCE_TYPE} contains "web" or {SOURCE_TYPE} contains "jira") and '
-            == result
-        )
-
-        # Empty source type list
-        filters = IndexFilters(access_control_list=[], source_type=[])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_tag_filters(self) -> None:
-        """Test tag filtering."""
-        # Single tag
-        filters = IndexFilters(
-            access_control_list=[], tags=[Tag(tag_key="color", tag_value="red")]
-        )
-        result = build_vespa_filters(filters)
-        assert (
-            f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
-            == result
-        )
-
-        # Multiple tags
-        filters = IndexFilters(
-            access_control_list=[],
-            tags=[
-                Tag(tag_key="color", tag_value="red"),
-                Tag(tag_key="size", tag_value="large"),
-            ],
-        )
-        result = build_vespa_filters(filters)
-        expected = (
-            f'!({HIDDEN}=true) and ({METADATA_LIST} contains "color{INDEX_SEPARATOR}red" '
-            f'or {METADATA_LIST} contains "size{INDEX_SEPARATOR}large") and '
-        )
-        assert expected == result
-
-        # Empty tags list
-        filters = IndexFilters(access_control_list=[], tags=[])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_document_sets_filter(self) -> None:
-        """Test document sets filtering."""
-        # Single document set
-        filters = IndexFilters(access_control_list=[], document_set=["set1"])
-        result = build_vespa_filters(filters)
-        assert f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1") and ' == result
-
-        # Multiple document sets
-        filters = IndexFilters(access_control_list=[], document_set=["set1", "set2"])
-        result = build_vespa_filters(filters)
-        assert (
-            f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
-            == result
-        )
-
-        # Empty document sets
-        filters = IndexFilters(access_control_list=[], document_set=[])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_user_file_ids_filter(self) -> None:
-        """Test user file IDs filtering."""
-        # Single user file ID
-        filters = IndexFilters(access_control_list=[], user_file_ids=[123])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and ({USER_FILE} = 123) and " == result
-
-        # Multiple user file IDs
-        filters = IndexFilters(access_control_list=[], user_file_ids=[123, 456])
-        result = build_vespa_filters(filters)
-        assert (
-            f"!({HIDDEN}=true) and ({USER_FILE} = 123 or {USER_FILE} = 456) and "
-            == result
-        )
-
-        # Empty user file IDs
-        filters = IndexFilters(access_control_list=[], user_file_ids=[])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_user_folder_ids_filter(self) -> None:
-        """Test user folder IDs filtering."""
-        # Single user folder ID
-        filters = IndexFilters(access_control_list=[], user_folder_ids=[789])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and ({USER_FOLDER} = 789) and " == result
-
-        # Multiple user folder IDs
-        filters = IndexFilters(access_control_list=[], user_folder_ids=[789, 101])
-        result = build_vespa_filters(filters)
-        assert (
-            f"!({HIDDEN}=true) and ({USER_FOLDER} = 789 or {USER_FOLDER} = 101) and "
-            == result
-        )
-
-        # Empty user folder IDs
-        filters = IndexFilters(access_control_list=[], user_folder_ids=[])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-    def test_time_cutoff_filter(self) -> None:
-        """Test time cutoff filtering."""
-        # With cutoff time
-        cutoff_time = datetime(2023, 1, 1, tzinfo=timezone.utc)
-        filters = IndexFilters(access_control_list=[], time_cutoff=cutoff_time)
-        result = build_vespa_filters(filters)
-        cutoff_secs = int(cutoff_time.timestamp())
-        assert (
-            f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {cutoff_secs}) and " == result
-        )
-
-        # No cutoff time
-        filters = IndexFilters(access_control_list=[], time_cutoff=None)
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
-
-        # Test untimed logic (when cutoff is old enough)
-        old_cutoff = datetime.now(timezone.utc) - timedelta(days=100)
-        filters = IndexFilters(access_control_list=[], time_cutoff=old_cutoff)
-        result = build_vespa_filters(filters)
-        old_cutoff_secs = int(old_cutoff.timestamp())
-        assert (
-            f"!({HIDDEN}=true) and !({DOC_UPDATED_AT} < {old_cutoff_secs}) and "
-            == result
-        )
-
-    def test_combined_filters(self) -> None:
-        """Test combining multiple filter types."""
-        filters = IndexFilters(
-            access_control_list=["user1", "group1"],
-            source_type=[DocumentSource.WEB],
-            tags=[Tag(tag_key="color", tag_value="red")],
-            document_set=["set1"],
-            user_file_ids=[123],
-            user_folder_ids=[789],
-            time_cutoff=datetime(2023, 1, 1, tzinfo=timezone.utc),
-        )
-
-        result = build_vespa_filters(filters)
-
-        # Build expected result piece by piece for readability
-        expected = f"!({HIDDEN}=true) and "
-        expected += (
-            '(access_control_list contains "user1" or '
-            'access_control_list contains "group1") and '
-        )
-        expected += f'({SOURCE_TYPE} contains "web") and '
-        expected += f'({METADATA_LIST} contains "color{INDEX_SEPARATOR}red") and '
-        expected += f'({DOCUMENT_SETS} contains "set1") and '
-        expected += f"({USER_FILE} = 123) and "
-        expected += f"({USER_FOLDER} = 789) and "
-        cutoff_secs = int(datetime(2023, 1, 1, tzinfo=timezone.utc).timestamp())
-        expected += f"!({DOC_UPDATED_AT} < {cutoff_secs}) and "
-
-        assert expected == result
-
-        # With trailing AND removed
-        result_no_trailing = build_vespa_filters(filters, remove_trailing_and=True)
-        assert expected[:-5] == result_no_trailing  # Remove trailing " and "
-
-    def test_empty_or_none_values(self) -> None:
-        """Test with empty or None values in filter lists."""
-        # Empty strings in document set
-        filters = IndexFilters(
-            access_control_list=[], document_set=["set1", "", "set2"]
-        )
-        result = build_vespa_filters(filters)
-        assert (
-            f'!({HIDDEN}=true) and ({DOCUMENT_SETS} contains "set1" or {DOCUMENT_SETS} contains "set2") and '
-            == result
-        )
-
-        # All empty strings in document set
-        filters = IndexFilters(access_control_list=[], document_set=["", ""])
-        result = build_vespa_filters(filters)
-        assert f"!({HIDDEN}=true) and " == result
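The deleted test file above doubles as documentation for build_vespa_filters: each enabled filter becomes one ANDed clause, hidden documents are excluded unless include_hidden=True, and the string keeps a trailing " and " unless remove_trailing_and=True. A worked example distilled from those assertions; the printed string is approximate, since clause spelling follows the runtime values of the vespa_constants imports:

# Distilled from the deleted assertions: ACL and document-set filters
# combine into one ANDed Vespa filter string, with hidden docs excluded.
from onyx.context.search.models import IndexFilters
from onyx.document_index.vespa.shared_utils.vespa_request_builders import (
    build_vespa_filters,
)

filters = IndexFilters(access_control_list=["user1"], document_set=["set1"])
print(build_vespa_filters(filters, remove_trailing_and=True))
# Per the deleted tests, roughly:
#   !(hidden=true) and (access_control_list contains "user1")
#   and (document_sets contains "set1")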
@@ -42,6 +42,7 @@ ENV NEXT_PUBLIC_DISABLE_STREAMING=${NEXT_PUBLIC_DISABLE_STREAMING}
 ARG NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA
 ENV NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA=${NEXT_PUBLIC_NEW_CHAT_DIRECTS_TO_SAME_PERSONA}

+# allow user to specify custom feedback options
 ARG NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS
 ENV NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS=${NEXT_PUBLIC_POSITIVE_PREDEFINED_FEEDBACK_OPTIONS}

web/package-lock.json (generated, 12200 lines changed): diff suppressed because it is too large.
@@ -93,12 +93,11 @@
   },
   "devDependencies": {
     "@chromatic-com/playwright": "^0.10.2",
     "@playwright/test": "^1.39.0",
     "@tailwindcss/typography": "^0.5.10",
     "@types/chrome": "^0.0.287",
     "@types/jest": "^29.5.14",
-    "chromatic": "^11.25.2",
-    "eslint": "^8.57.1",
+    "eslint": "^8.48.0",
     "eslint-config-next": "^14.1.0",
     "jest": "^29.7.0",
     "prettier": "2.8.8",

@@ -17,7 +17,7 @@ export default function PostHogPageView(): null {
   // Track pageviews
   if (pathname) {
     let url = window.origin + pathname;
-    if (searchParams?.toString()) {
+    if (searchParams.toString()) {
       url = url + `?${searchParams.toString()}`;
     }
     posthog.capture("$pageview", {

@@ -42,7 +42,9 @@ import Link from "next/link";
 import { useRouter, useSearchParams } from "next/navigation";
 import { useEffect, useMemo, useState } from "react";
 import * as Yup from "yup";
-import { FullPersona, PersonaLabel, StarterMessage } from "./interfaces";
+import CollapsibleSection from "./CollapsibleSection";
+import { SuccessfulPersonaUpdateRedirectType } from "./enums";
+import { Persona, PersonaLabel, StarterMessage } from "./interfaces";
 import {
   PersonaUpsertParameters,
   createPersona,
@@ -99,7 +101,6 @@ import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
 import TextView from "@/components/chat/TextView";
 import { MinimalOnyxDocument } from "@/lib/search/interfaces";
 import { TabToggle } from "@/components/ui/TabToggle";
-import { MAX_CHARACTERS_PERSONA_DESCRIPTION } from "@/lib/constants";

 function findSearchTool(tools: ToolSnapshot[]) {
   return tools.find((tool) => tool.in_code_tool_id === SEARCH_TOOL_ID);
@@ -135,7 +136,7 @@ export function AssistantEditor({
   shouldAddAssistantToUserPreferences,
   admin,
 }: {
-  existingPersona?: FullPersona | null;
+  existingPersona?: Persona | null;
   ccPairs: CCPairBasicInfo[];
   documentSets: DocumentSet[];
   user: User | null;
@@ -148,7 +149,7 @@ export function AssistantEditor({
   const { refreshAssistants, isImageGenerationAvailable } = useAssistants();
   const router = useRouter();
   const searchParams = useSearchParams();
-  const isAdminPage = searchParams?.get("admin") === "true";
+  const isAdminPage = searchParams.get("admin") === "true";

   const { popup, setPopup } = usePopup();
   const { labels, refreshLabels, createLabel, updateLabel, deleteLabel } =
@@ -183,6 +184,8 @@ export function AssistantEditor({
     }
   }, [defaultIconShape]);

+  const [isIconDropdownOpen, setIsIconDropdownOpen] = useState(false);
+
   const [removePersonaImage, setRemovePersonaImage] = useState(false);

   const autoStarterMessageEnabled = useMemo(
@@ -239,7 +242,15 @@ export function AssistantEditor({
     enabledToolsMap[tool.id] = personaCurrentToolIds.includes(tool.id);
   });

-  const { selectedFiles, selectedFolders } = useDocumentsContext();
+  const {
+    selectedFiles,
+    selectedFolders,
+    addSelectedFile,
+    removeSelectedFile,
+    addSelectedFolder,
+    removeSelectedFolder,
+    clearSelectedItems,
+  } = useDocumentsContext();

   const [showVisibilityWarning, setShowVisibilityWarning] = useState(false);
@@ -458,14 +469,8 @@ export function AssistantEditor({
       description: Yup.string().required(
         "Must provide a description for the Assistant"
       ),
-      system_prompt: Yup.string().max(
-        MAX_CHARACTERS_PERSONA_DESCRIPTION,
-        "Instructions must be less than 5000000 characters"
-      ),
-      task_prompt: Yup.string().max(
-        MAX_CHARACTERS_PERSONA_DESCRIPTION,
-        "Reminders must be less than 5000000 characters"
-      ),
+      system_prompt: Yup.string(),
+      task_prompt: Yup.string(),
       is_public: Yup.boolean().required(),
       document_set_ids: Yup.array().of(Yup.number()),
       num_chunks: Yup.number().nullable(),

@@ -18,37 +18,35 @@ export interface Prompt {
   datetime_aware: boolean;
   default_prompt: boolean;
 }

 export interface Persona {
   id: number;
   name: string;
-  description: string;
-  is_public: boolean;
+  search_start_date: Date | null;
+  owner: MinimalUserSnapshot | null;
   is_visible: boolean;
+  is_public: boolean;
+  display_priority: number | null;
+  description: string;
+  document_sets: DocumentSet[];
+  prompts: Prompt[];
+  tools: ToolSnapshot[];
+  num_chunks?: number;
+  llm_relevance_filter?: boolean;
+  llm_filter_extraction?: boolean;
+  llm_model_provider_override?: string;
+  llm_model_version_override?: string;
+  starter_messages: StarterMessage[] | null;
+  builtin_persona: boolean;
+  is_default_persona: boolean;
+  users: MinimalUserSnapshot[];
+  groups: number[];
   icon_shape?: number;
   icon_color?: string;
   uploaded_image_id?: string;
+  labels?: PersonaLabel[];
   user_file_ids: number[];
   user_folder_ids: number[];
-  display_priority: number | null;
-  is_default_persona: boolean;
-  builtin_persona: boolean;
-  starter_messages: StarterMessage[] | null;
-  tools: ToolSnapshot[];
-  labels?: PersonaLabel[];
-  owner: MinimalUserSnapshot | null;
-  users: MinimalUserSnapshot[];
-  groups: number[];
-  document_sets: DocumentSet[];
-  llm_model_provider_override?: string;
-  llm_model_version_override?: string;
-  num_chunks?: number;
 }

-export interface FullPersona extends Persona {
-  search_start_date: Date | null;
-  prompts: Prompt[];
-  llm_relevance_filter?: boolean;
-  llm_filter_extraction?: boolean;
-}

 export interface PersonaLabel {

@@ -331,3 +331,28 @@ export function providersContainImageGeneratingSupport(
 ) {
   return providers.some((provider) => provider.provider === "openai");
 }
+
+// Default fallback persona for when we must display a persona
+// but assistant has access to none
+export const defaultPersona: Persona = {
+  id: 0,
+  name: "Default Assistant",
+  description: "A default assistant",
+  is_visible: true,
+  is_public: true,
+  builtin_persona: false,
+  is_default_persona: true,
+  users: [],
+  groups: [],
+  document_sets: [],
+  prompts: [],
+  tools: [],
+  starter_messages: null,
+  display_priority: null,
+  search_start_date: null,
+  owner: null,
+  icon_shape: 50910,
+  icon_color: "#FF6F6F",
+  user_file_ids: [],
+  user_folder_ids: [],
+};

@@ -302,17 +302,11 @@ export default function AddConnector({
       ...connector_specific_config
     } = values;

-    // Apply special transforms according to application logic
+    // Apply transforms from connectors.ts configuration
     const transformedConnectorSpecificConfig = Object.entries(
       connector_specific_config
     ).reduce(
       (acc, [key, value]) => {
-        // Filter out empty strings from arrays
-        if (Array.isArray(value)) {
-          value = (value as any[]).filter(
-            (item) => typeof item !== "string" || item.trim() !== ""
-          );
-        }
         const matchingConfigValue = configuration.values.find(
           (configValue) => configValue.name === key
         );

@@ -26,8 +26,8 @@ export default function OAuthCallbackPage() {
   );

   // Extract query parameters
-  const code = searchParams?.get("code");
-  const state = searchParams?.get("state");
+  const code = searchParams.get("code");
+  const state = searchParams.get("state");

   const pathname = usePathname();
   const connector = pathname?.split("/")[3];

@@ -4,6 +4,7 @@ import { useEffect, useState } from "react";
 import { usePathname, useRouter, useSearchParams } from "next/navigation";
 import { AdminPageTitle } from "@/components/admin/Title";
+import { Button } from "@/components/ui/button";
 import Title from "@/components/ui/title";
 import { KeyIcon } from "@/components/icons/icons";
 import { getSourceMetadata, isValidSource } from "@/lib/sources";
 import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
@@ -73,7 +74,7 @@ export default function OAuthFinalizePage() {
   >([]);

   // Extract query parameters
-  const credentialParam = searchParams?.get("credential");
+  const credentialParam = searchParams.get("credential");
   const credential = credentialParam ? parseInt(credentialParam, 10) : NaN;
   const pathname = usePathname();
   const connector = pathname?.split("/")[3];
@@ -84,7 +85,7 @@ export default function OAuthFinalizePage() {
   // connector (url segment)= "google-drive"
   // sourceType (for looking up metadata) = "google_drive"

-  if (isNaN(credential) || !connector) {
+  if (isNaN(credential)) {
     setStatusMessage("Improperly formed OAuth finalization request.");
     setStatusDetails("Invalid or missing credential id.");
     setIsError(true);

@@ -487,6 +487,11 @@ export default function EmbeddingForm() {
  };

  const handleReIndex = async () => {
    console.log("handleReIndex");
    console.log(selectedProvider);
    console.log(advancedEmbeddingDetails);
    console.log(rerankingDetails);
    console.log(reindexType);
    if (!selectedProvider) {
      return;
    }
@@ -23,8 +23,8 @@ const ResetPasswordPage: React.FC = () => {
  const { popup, setPopup } = usePopup();
  const [isWorking, setIsWorking] = useState(false);
  const searchParams = useSearchParams();
  const token = searchParams?.get("token");
  const tenantId = searchParams?.get(TENANT_ID_COOKIE_NAME);
  const token = searchParams.get("token");
  const tenantId = searchParams.get(TENANT_ID_COOKIE_NAME);
  // Keep search param same name as cookie for simplicity

  useEffect(() => {

@@ -15,9 +15,9 @@ export function Verify({ user }: { user: User | null }) {
  const [error, setError] = useState("");

  const verify = useCallback(async () => {
    const token = searchParams?.get("token");
    const token = searchParams.get("token");
    const firstUser =
      searchParams?.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
      searchParams.get("first_user") && NEXT_PUBLIC_CLOUD_ENABLED;
    if (!token) {
      setError(
        "Missing verification token. Try requesting a new verification email."
@@ -196,9 +196,7 @@ export function ChatPage({
    setCurrentMessageFiles,
  } = useDocumentsContext();

  const defaultAssistantIdRaw = searchParams?.get(
    SEARCH_PARAM_NAMES.PERSONA_ID
  );
  const defaultAssistantIdRaw = searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID);
  const defaultAssistantId = defaultAssistantIdRaw
    ? parseInt(defaultAssistantIdRaw)
    : undefined;

@@ -254,8 +252,8 @@ export function ChatPage({
  );

  const { user, isAdmin } = useUser();
  const slackChatId = searchParams?.get("slackChatId");
  const existingChatIdRaw = searchParams?.get("chatId");
  const slackChatId = searchParams.get("slackChatId");
  const existingChatIdRaw = searchParams.get("chatId");

  const [showHistorySidebar, setShowHistorySidebar] = useState(false);

@@ -277,7 +275,7 @@ export function ChatPage({

  const processSearchParamsAndSubmitMessage = (searchParamsString: string) => {
    const newSearchParams = new URLSearchParams(searchParamsString);
    const message = newSearchParams?.get("user-prompt");
    const message = newSearchParams.get("user-prompt");

    filterManager.buildFiltersFromQueryString(
      newSearchParams.toString(),

@@ -286,7 +284,7 @@ export function ChatPage({
      tags
    );

    const fileDescriptorString = newSearchParams?.get(SEARCH_PARAM_NAMES.FILES);
    const fileDescriptorString = newSearchParams.get(SEARCH_PARAM_NAMES.FILES);
    const overrideFileDescriptors: FileDescriptor[] = fileDescriptorString
      ? JSON.parse(decodeURIComponent(fileDescriptorString))
      : [];

@@ -326,7 +324,7 @@ export function ChatPage({
      : undefined
  );
  // Gather default temperature settings
  const search_param_temperature = searchParams?.get(
  const search_param_temperature = searchParams.get(
    SEARCH_PARAM_NAMES.TEMPERATURE
  );

@@ -553,7 +551,7 @@ export function ChatPage({
      if (
        newMessageHistory.length === 1 &&
        !submitOnLoadPerformed.current &&
        searchParams?.get(SEARCH_PARAM_NAMES.SEEDED) === "true"
        searchParams.get(SEARCH_PARAM_NAMES.SEEDED) === "true"
      ) {
        submitOnLoadPerformed.current = true;
        const seededMessage = newMessageHistory[0].message;

@@ -574,11 +572,11 @@ export function ChatPage({

    initialSessionFetch();
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [existingChatSessionId, searchParams?.get(SEARCH_PARAM_NAMES.PERSONA_ID)]);
  }, [existingChatSessionId, searchParams.get(SEARCH_PARAM_NAMES.PERSONA_ID)]);

  useEffect(() => {
    const userFolderId = searchParams?.get(SEARCH_PARAM_NAMES.USER_FOLDER_ID);
    const allMyDocuments = searchParams?.get(
    const userFolderId = searchParams.get(SEARCH_PARAM_NAMES.USER_FOLDER_ID);
    const allMyDocuments = searchParams.get(
      SEARCH_PARAM_NAMES.ALL_MY_DOCUMENTS
    );

@@ -601,14 +599,14 @@ export function ChatPage({
    }
  }, [
    userFolders,
    searchParams?.get(SEARCH_PARAM_NAMES.USER_FOLDER_ID),
    searchParams?.get(SEARCH_PARAM_NAMES.ALL_MY_DOCUMENTS),
    searchParams.get(SEARCH_PARAM_NAMES.USER_FOLDER_ID),
    searchParams.get(SEARCH_PARAM_NAMES.ALL_MY_DOCUMENTS),
    addSelectedFolder,
    clearSelectedItems,
  ]);

  const [message, setMessage] = useState(
    searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) || ""
    searchParams.get(SEARCH_PARAM_NAMES.USER_PROMPT) || ""
  );

  const [completeMessageDetail, setCompleteMessageDetail] = useState<

@@ -1050,7 +1048,7 @@ export function ChatPage({

  // Equivalent to `loadNewPageLogic`
  useEffect(() => {
    if (searchParams?.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD)) {
    if (searchParams.get(SEARCH_PARAM_NAMES.SEND_ON_LOAD)) {
      processSearchParamsAndSubmitMessage(searchParams.toString());
    }
  }, [searchParams, router]);

@@ -1233,7 +1231,7 @@ export function ChatPage({
    const isNewSession = chatSessionIdRef.current === null;

    const searchParamBasedChatSessionName =
      searchParams?.get(SEARCH_PARAM_NAMES.TITLE) || null;
      searchParams.get(SEARCH_PARAM_NAMES.TITLE) || null;

    if (isNewSession) {
      currChatSessionId = await createChatSession(

@@ -1383,7 +1381,7 @@ export function ChatPage({
            regenerationRequest?.parentMessage.messageId ||
            lastSuccessfulMessageId,
          chatSessionId: currChatSessionId,
          promptId: null,
          promptId: liveAssistant?.prompts[0]?.id || 0,
          filters: buildFilters(
            filterManager.selectedSources,
            filterManager.selectedDocumentSets,

@@ -1411,11 +1409,11 @@ export function ChatPage({
          modelVersion:
            modelOverride?.modelName ||
            llmManager.currentLlm.modelName ||
            searchParams?.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
            searchParams.get(SEARCH_PARAM_NAMES.MODEL_VERSION) ||
            undefined,
          temperature: llmManager.temperature || undefined,
          systemPromptOverride:
            searchParams?.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
            searchParams.get(SEARCH_PARAM_NAMES.SYSTEM_PROMPT) || undefined,
          useExistingUserMessage: isSeededChat,
          useLanggraph:
            settings?.settings.pro_search_enabled &&
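Several of the search params touched above drive "seeded" chats: a URL can carry a prompt and a send-on-load flag that the page submits automatically. A rough sketch of composing such a URL (the "user-prompt" and "send-on-load" keys appear in this diff; the helper and base path are illustrative):

// Compose a URL that pre-fills and auto-submits a prompt on load.
function buildSeededChatUrl(basePath: string, prompt: string): string {
  const params = new URLSearchParams();
  params.set("user-prompt", prompt);
  params.set("send-on-load", "true");
  return `${basePath}?${params.toString()}`;
}

// buildSeededChatUrl("/chat", "Summarize my docs")
//   -> "/chat?user-prompt=Summarize+my+docs&send-on-load=true"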
@@ -4,62 +4,177 @@ import {
  LlmDescriptor,
  useLlmManager,
} from "@/lib/hooks";
import { StringOrNumberOption } from "@/components/Dropdown";

import { Persona } from "@/app/admin/assistants/interfaces";
import { destructureValue } from "@/lib/llm/utils";
import { destructureValue, getFinalLLM, structureValue } from "@/lib/llm/utils";
import { useState } from "react";
import { Hoverable } from "@/components/Hoverable";
import { Popover } from "@/components/popover/Popover";
import { IconType } from "react-icons";
import { FiRefreshCw } from "react-icons/fi";
import LLMPopover from "./input/LLMPopover";
import { FiRefreshCw, FiCheck } from "react-icons/fi";

export default function RegenerateOption({
  selectedAssistant,
  regenerate,
  overriddenModel,
export function RegenerateDropdown({
  options,
  selected,
  onSelect,
  side,
  maxHeight,
  alternate,
  onDropdownVisibleChange,
}: {
  selectedAssistant: Persona;
  regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
  overriddenModel?: string;
  alternate?: string;
  options: StringOrNumberOption[];
  selected: string | null;
  onSelect: (value: string | number | null) => void;
  includeDefault?: boolean;
  side?: "top" | "right" | "bottom" | "left";
  maxHeight?: string;
  onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
  const { llmProviders } = useChatContext();
  const llmManager = useLlmManager(llmProviders);
  const [isOpen, setIsOpen] = useState(false);

  const toggleDropdownVisible = (isVisible: boolean) => {
    setIsOpen(isVisible);
    onDropdownVisibleChange(isVisible);
  };

  const Dropdown = (
    <div className="overflow-y-auto border border-neutral-800 py-2 min-w-fit bg-neutral-50 dark:bg-neutral-900 rounded-md shadow-lg">
      <div className="mb-1 flex items-center justify-between px-4 pt-2">
        <span className="text-sm text-neutral-600 dark:text-neutral-400">
          Regenerate with
        </span>
      </div>
      {options.map((option) => (
        <div
          key={option.value}
          role="menuitem"
          className={`flex items-center m-1.5 p-1.5 text-sm cursor-pointer focus-visible:outline-0 group relative hover:bg-neutral-200 dark:hover:bg-neutral-800 rounded-md my-0 px-3 mx-2 gap-2.5 py-3 !pr-3 ${
            option.value === selected
              ? "bg-neutral-200 dark:bg-neutral-800"
              : ""
          }`}
          onClick={() => onSelect(option.value)}
        >
          <div className="flex grow items-center justify-between gap-2">
            <div>
              <div className="flex items-center gap-3">
                <div>{getDisplayNameForModel(option.name)}</div>
              </div>
            </div>
          </div>
          {option.value === selected && (
            <FiCheck className="text-neutral-700 dark:text-neutral-300" />
          )}
        </div>
      ))}
    </div>
  );

  return (
    <LLMPopover
      llmManager={llmManager}
      llmProviders={llmProviders}
      requiresImageGeneration={false}
      currentAssistant={selectedAssistant}
      currentModelName={overriddenModel}
      trigger={
    <Popover
      open={isOpen}
      onOpenChange={toggleDropdownVisible}
      content={
        <div onClick={() => toggleDropdownVisible(!isOpen)}>
          {!overriddenModel ? (
          {!alternate ? (
            <Hoverable size={16} icon={FiRefreshCw as IconType} />
          ) : (
            <Hoverable
              size={16}
              icon={FiRefreshCw as IconType}
              hoverText={getDisplayNameForModel(overriddenModel)}
              hoverText={getDisplayNameForModel(alternate)}
            />
          )}
        </div>
      }
      onSelect={(value) => {
        const { name, provider, modelName } = destructureValue(value as string);
        regenerate({
          name: name,
          provider: provider,
          modelName: modelName,
        });
      }}
      popover={Dropdown}
      align="start"
      side={side}
      sideOffset={5}
      triggerMaxWidth
    />
  );
}

export default function RegenerateOption({
  selectedAssistant,
  regenerate,
  overriddenModel,
  onHoverChange,
  onDropdownVisibleChange,
}: {
  selectedAssistant: Persona;
  regenerate: (modelOverRide: LlmDescriptor) => Promise<void>;
  overriddenModel?: string;
  onHoverChange: (isHovered: boolean) => void;
  onDropdownVisibleChange: (isVisible: boolean) => void;
}) {
  const { llmProviders } = useChatContext();
  const llmManager = useLlmManager(llmProviders);

  const [_, llmName] = getFinalLLM(llmProviders, selectedAssistant, null);

  const llmOptionsByProvider: {
    [provider: string]: { name: string; value: string }[];
  } = {};
  const uniqueModelNames = new Set<string>();

  llmProviders.forEach((llmProvider) => {
    if (!llmOptionsByProvider[llmProvider.provider]) {
      llmOptionsByProvider[llmProvider.provider] = [];
    }

    (llmProvider.display_model_names || llmProvider.model_names).forEach(
      (modelName) => {
        if (!uniqueModelNames.has(modelName)) {
          uniqueModelNames.add(modelName);
          llmOptionsByProvider[llmProvider.provider].push({
            name: modelName,
            value: structureValue(
              llmProvider.name,
              llmProvider.provider,
              modelName
            ),
          });
        }
      }
    );
  });

  const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
    ([provider, options]) => [...options]
  );

  const currentModelName =
    llmManager?.currentLlm.modelName ||
    (selectedAssistant
      ? selectedAssistant.llm_model_version_override || llmName
      : llmName);

  return (
    <div
      className="group flex items-center relative"
      onMouseEnter={() => onHoverChange(true)}
      onMouseLeave={() => onHoverChange(false)}
    >
      <RegenerateDropdown
        onDropdownVisibleChange={onDropdownVisibleChange}
        alternate={overriddenModel}
        options={llmOptions}
        selected={currentModelName}
        onSelect={(value) => {
          const { name, provider, modelName } = destructureValue(
            value as string
          );
          regenerate({
            name: name,
            provider: provider,
            modelName: modelName,
          });
        }}
      />
    </div>
  );
}
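Both components round-trip model descriptors through structureValue/destructureValue. A plausible sketch of that pairing, assuming the value is a simple delimited string (the real encoding lives in @/lib/llm/utils and may differ):

// Illustrative only: pack and unpack an LLM descriptor as "name__provider__model".
function structureValueSketch(
  name: string,
  provider: string,
  modelName: string
): string {
  return `${name}__${provider}__${modelName}`;
}

function destructureValueSketch(value: string): {
  name: string;
  provider: string;
  modelName: string;
} {
  const [name, provider, modelName] = value.split("__");
  return { name, provider, modelName };
}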
@@ -6,7 +6,7 @@ import { Persona } from "@/app/admin/assistants/interfaces";
import LLMPopover from "./LLMPopover";
import { InputPrompt } from "@/app/chat/interfaces";

import { FilterManager, getDisplayNameForModel, LlmManager } from "@/lib/hooks";
import { FilterManager, LlmManager } from "@/lib/hooks";
import { useChatContext } from "@/components/context/ChatContext";
import { ChatFileType, FileDescriptor } from "../interfaces";
import {

@@ -38,7 +38,6 @@ import { useUser } from "@/components/user/UserProvider";
import { useDocumentSelection } from "../useDocumentSelection";
import { AgenticToggle } from "./AgenticToggle";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { getProviderIcon } from "@/app/admin/configuration/llm/interfaces";
import { LoadingIndicator } from "react-select/dist/declarations/src/components/indicators";
import { FidgetSpinner } from "react-loader-spinner";
import { LoadingAnimation } from "@/components/Loading";

@@ -800,27 +799,6 @@ export function ChatInputBar({
            llmManager={llmManager}
            requiresImageGeneration={false}
            currentAssistant={selectedAssistant}
            trigger={
              <button
                className="dark:text-white text-black focus:outline-none"
                data-testid="llm-popover-trigger"
              >
                <ChatInputOption
                  minimize
                  toggle
                  flexPriority="stiff"
                  name={getDisplayNameForModel(
                    llmManager?.currentLlm.modelName || "Models"
                  )}
                  Icon={getProviderIcon(
                    llmManager?.currentLlm.provider || "anthropic",
                    llmManager?.currentLlm.modelName ||
                      "claude-3-5-sonnet-20240620"
                  )}
                  tooltipContent="Switch models"
                />
              </button>
            }
          />

          {retrievalEnabled && (
@@ -1,9 +1,16 @@
import React, { useState, useEffect, useCallback, useMemo } from "react";
import React, {
  useState,
  useEffect,
  useCallback,
  useLayoutEffect,
  useMemo,
} from "react";
import {
  Popover,
  PopoverContent,
  PopoverTrigger,
} from "@/components/ui/popover";
import { ChatInputOption } from "./ChatInputOption";
import { getDisplayNameForModel } from "@/lib/hooks";
import {
  checkLLMSupportsImageInput,

@@ -28,16 +35,12 @@ import { FiAlertTriangle } from "react-icons/fi";
import { Slider } from "@/components/ui/slider";
import { useUser } from "@/components/user/UserProvider";
import { TruncatedText } from "@/components/ui/truncatedText";
import { ChatInputOption } from "./ChatInputOption";

interface LLMPopoverProps {
  llmProviders: LLMProviderDescriptor[];
  llmManager: LlmManager;
  requiresImageGeneration?: boolean;
  currentAssistant?: Persona;
  trigger?: React.ReactElement;
  onSelect?: (value: string) => void;
  currentModelName?: string;
}

export default function LLMPopover({

@@ -45,69 +48,70 @@ export default function LLMPopover({
  llmManager,
  requiresImageGeneration,
  currentAssistant,
  trigger,
  onSelect,
  currentModelName,
}: LLMPopoverProps) {
  const [isOpen, setIsOpen] = useState(false);
  const { user } = useUser();

  // Memoize the options to prevent unnecessary recalculations
  const { llmOptions, defaultProvider, defaultModelDisplayName } =
    useMemo(() => {
      const llmOptionsByProvider: {
        [provider: string]: {
          name: string;
          value: string;
          icon: React.FC<{ size?: number; className?: string }>;
        }[];
      } = {};
  const {
    llmOptionsByProvider,
    llmOptions,
    defaultProvider,
    defaultModelDisplayName,
  } = useMemo(() => {
    const llmOptionsByProvider: {
      [provider: string]: {
        name: string;
        value: string;
        icon: React.FC<{ size?: number; className?: string }>;
      }[];
    } = {};

      const uniqueModelNames = new Set<string>();
    const uniqueModelNames = new Set<string>();

      llmProviders.forEach((llmProvider) => {
        if (!llmOptionsByProvider[llmProvider.provider]) {
          llmOptionsByProvider[llmProvider.provider] = [];
        }
    llmProviders.forEach((llmProvider) => {
      if (!llmOptionsByProvider[llmProvider.provider]) {
        llmOptionsByProvider[llmProvider.provider] = [];
      }

        (llmProvider.display_model_names || llmProvider.model_names).forEach(
          (modelName) => {
            if (!uniqueModelNames.has(modelName)) {
              uniqueModelNames.add(modelName);
              llmOptionsByProvider[llmProvider.provider].push({
                name: modelName,
                value: structureValue(
                  llmProvider.name,
                  llmProvider.provider,
                  modelName
                ),
                icon: getProviderIcon(llmProvider.provider, modelName),
              });
            }
      (llmProvider.display_model_names || llmProvider.model_names).forEach(
        (modelName) => {
          if (!uniqueModelNames.has(modelName)) {
            uniqueModelNames.add(modelName);
            llmOptionsByProvider[llmProvider.provider].push({
              name: modelName,
              value: structureValue(
                llmProvider.name,
                llmProvider.provider,
                modelName
              ),
              icon: getProviderIcon(llmProvider.provider, modelName),
            });
          }
        );
      });

      const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
        ([provider, options]) => [...options]
        }
      );
    });

      const defaultProvider = llmProviders.find(
        (llmProvider) => llmProvider.is_default_provider
      );
    const llmOptions = Object.entries(llmOptionsByProvider).flatMap(
      ([provider, options]) => [...options]
    );

      const defaultModelName = defaultProvider?.default_model_name;
      const defaultModelDisplayName = defaultModelName
        ? getDisplayNameForModel(defaultModelName)
        : null;
    const defaultProvider = llmProviders.find(
      (llmProvider) => llmProvider.is_default_provider
    );

      return {
        llmOptionsByProvider,
        llmOptions,
        defaultProvider,
        defaultModelDisplayName,
      };
    }, [llmProviders]);
    const defaultModelName = defaultProvider?.default_model_name;
    const defaultModelDisplayName = defaultModelName
      ? getDisplayNameForModel(defaultModelName)
      : null;

    return {
      llmOptionsByProvider,
      llmOptions,
      defaultProvider,
      defaultModelDisplayName,
    };
  }, [llmProviders]);

  const [localTemperature, setLocalTemperature] = useState(
    llmManager.temperature ?? 0.5

@@ -131,34 +135,32 @@ export default function LLMPopover({

  // Memoize trigger content to prevent rerendering
  const triggerContent = useMemo(
    trigger
      ? () => trigger
      : () => (
          <button
            className="dark:text-[#fff] text-[#000] focus:outline-none"
            data-testid="llm-popover-trigger"
          >
            <ChatInputOption
              minimize
              toggle
              flexPriority="stiff"
              name={getDisplayNameForModel(
                llmManager?.currentLlm.modelName ||
                  defaultModelDisplayName ||
                  "Models"
              )}
              Icon={getProviderIcon(
                llmManager?.currentLlm.provider ||
                  defaultProvider?.provider ||
                  "anthropic",
                llmManager?.currentLlm.modelName ||
                  defaultProvider?.default_model_name ||
                  "claude-3-5-sonnet-20240620"
              )}
              tooltipContent="Switch models"
            />
          </button>
        ),
    () => (
      <button
        className="dark:text-[#fff] text-[#000] focus:outline-none"
        data-testid="llm-popover-trigger"
      >
        <ChatInputOption
          minimize
          toggle
          flexPriority="stiff"
          name={getDisplayNameForModel(
            llmManager?.currentLlm.modelName ||
              defaultModelDisplayName ||
              "Models"
          )}
          Icon={getProviderIcon(
            llmManager?.currentLlm.provider ||
              defaultProvider?.provider ||
              "anthropic",
            llmManager?.currentLlm.modelName ||
              defaultProvider?.default_model_name ||
              "claude-3-5-sonnet-20240620"
          )}
          tooltipContent="Switch models"
        />
      </button>
    ),
    [defaultModelDisplayName, defaultProvider, llmManager?.currentLlm]
  );
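The refactor above keeps the provider/model option list inside a single useMemo keyed on llmProviders, so the list is rebuilt only when the providers array changes. A stripped-down sketch of the same pattern (the ProviderLike shape is simplified; the real LLMProviderDescriptor has more fields):

import { useMemo } from "react";

interface ProviderLike {
  provider: string;
  model_names: string[];
}

function useModelOptions(llmProviders: ProviderLike[]) {
  // Recomputed only when the providers array identity changes.
  return useMemo(
    () =>
      llmProviders.flatMap((p) =>
        p.model_names.map((modelName) => ({ provider: p.provider, modelName }))
      ),
    [llmProviders]
  );
}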
@@ -176,14 +178,12 @@ export default function LLMPopover({
          <button
            key={index}
            className={`w-full flex items-center gap-x-2 px-3 py-2 text-sm text-left hover:bg-background-100 dark:hover:bg-neutral-800 transition-colors duration-150 ${
              (currentModelName || llmManager.currentLlm.modelName) ===
              name
              llmManager.currentLlm.modelName === name
                ? "bg-background-100 dark:bg-neutral-900 text-text"
                : "text-text-darker"
            }`}
            onClick={() => {
              llmManager.updateCurrentLlm(destructureValue(value));
              onSelect?.(value);
              setIsOpen(false);
            }}
          >
@@ -668,7 +668,7 @@ const PARAMS_TO_SKIP = [
];

export function buildChatUrl(
  existingSearchParams: ReadonlyURLSearchParams | null,
  existingSearchParams: ReadonlyURLSearchParams,
  chatSessionId: string | null,
  personaId: number | null,
  search?: boolean

@@ -685,7 +685,7 @@ export function buildChatUrl(
    finalSearchParams.push(`${SEARCH_PARAM_NAMES.PERSONA_ID}=${personaId}`);
  }

  existingSearchParams?.forEach((value, key) => {
  existingSearchParams.forEach((value, key) => {
    if (!PARAMS_TO_SKIP.includes(key)) {
      finalSearchParams.push(`${key}=${value}`);
    }
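buildChatUrl now requires a non-null ReadonlyURLSearchParams, so callers can pass the hook's result straight through. A rough usage sketch (the component, session id, and rendered output are illustrative):

import { useSearchParams } from "next/navigation";

function ChatLink() {
  const searchParams = useSearchParams();
  // Carry current params over to a new chat session URL; no null check needed.
  const chatUrl = buildChatUrl(searchParams, "abc-123", null);
  return <a href={chatUrl}>Resume chat</a>;
}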
@@ -719,7 +719,7 @@ export async function uploadFilesForChat(
  return [responseJson.files as FileDescriptor[], null];
}

export function useScrollonStream({
export async function useScrollonStream({
  chatState,
  scrollableDivRef,
  scrollDist,

@@ -817,5 +817,5 @@ export function useScrollonStream({
      });
    }
  }
  }, [chatState, distance, scrollDist, scrollableDivRef, enableAutoScroll]);
  }, [chatState, distance, scrollDist, scrollableDivRef]);
}
@@ -178,6 +178,7 @@ export const AgenticMessage = ({
  const [isViewingInitialAnswer, setIsViewingInitialAnswer] = useState(true);

  const [canShowResponse, setCanShowResponse] = useState(isComplete);
  const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
  const [isRegenerateDropdownVisible, setIsRegenerateDropdownVisible] =
    useState(false);

@@ -596,6 +597,7 @@ export const AgenticMessage = ({
                        onDropdownVisibleChange={
                          setIsRegenerateDropdownVisible
                        }
                        onHoverChange={setIsRegenerateHovered}
                        selectedAssistant={currentPersona!}
                        regenerate={regenerate}
                        overriddenModel={overriddenModel}

@@ -611,10 +613,16 @@ export const AgenticMessage = ({
                  absolute -bottom-5
                  z-10
                  invisible ${
                    (isHovering || settings?.isMobile) && "!visible"
                    (isHovering ||
                      isRegenerateHovered ||
                      settings?.isMobile) &&
                    "!visible"
                  }
                  opacity-0 ${
                    (isHovering || settings?.isMobile) && "!opacity-100"
                    (isHovering ||
                      isRegenerateHovered ||
                      settings?.isMobile) &&
                    "!opacity-100"
                  }
                  translate-y-2 ${
                    (isHovering || settings?.isMobile) &&

@@ -689,6 +697,7 @@ export const AgenticMessage = ({
                        }
                        regenerate={regenerate}
                        overriddenModel={overriddenModel}
                        onHoverChange={setIsRegenerateHovered}
                      />
                    </CustomTooltip>
                  )}
@@ -301,6 +301,7 @@ export const AIMessage = ({

  const finalContent = processContent(content as string);

  const [isRegenerateHovered, setIsRegenerateHovered] = useState(false);
  const [isRegenerateDropdownVisible, setIsRegenerateDropdownVisible] =
    useState(false);
  const { isHovering, trackedElementRef, hoverElementRef } = useMouseTracking();

@@ -727,6 +728,7 @@ export const AIMessage = ({
                        onDropdownVisibleChange={
                          setIsRegenerateDropdownVisible
                        }
                        onHoverChange={setIsRegenerateHovered}
                        selectedAssistant={currentPersona!}
                        regenerate={regenerate}
                        overriddenModel={overriddenModel}

@@ -742,10 +744,16 @@ export const AIMessage = ({
                  absolute -bottom-5
                  z-10
                  invisible ${
                    (isHovering || settings?.isMobile) && "!visible"
                    (isHovering ||
                      isRegenerateHovered ||
                      settings?.isMobile) &&
                    "!visible"
                  }
                  opacity-0 ${
                    (isHovering || settings?.isMobile) && "!opacity-100"
                    (isHovering ||
                      isRegenerateHovered ||
                      settings?.isMobile) &&
                    "!opacity-100"
                  }
                  flex md:flex-row gap-x-0.5 bg-background-125/40 -mx-1.5 p-1.5 rounded-lg
                `}

@@ -810,6 +818,7 @@ export const AIMessage = ({
                        }
                        regenerate={regenerate}
                        overriddenModel={overriddenModel}
                        onHoverChange={setIsRegenerateHovered}
                      />
                    </CustomTooltip>
                  )}
@@ -23,10 +23,8 @@ export const SEARCH_PARAM_NAMES = {
  SEND_ON_LOAD: "send-on-load",
};

export function shouldSubmitOnLoad(
  searchParams: ReadonlyURLSearchParams | null
) {
  const rawSubmitOnLoad = searchParams?.get(SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD);
export function shouldSubmitOnLoad(searchParams: ReadonlyURLSearchParams) {
  const rawSubmitOnLoad = searchParams.get(SEARCH_PARAM_NAMES.SUBMIT_ON_LOAD);
  if (rawSubmitOnLoad === "true" || rawSubmitOnLoad === "1") {
    return true;
  }
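shouldSubmitOnLoad accepts either "true" or "1" as an affirmative flag. A quick illustration of the check in isolation (the literal key "submit-on-load" is assumed here; only SEND_ON_LOAD's value is shown in this diff):

// Mimics the acceptance check: both "1" and "true" count as affirmative.
const raw = new URLSearchParams("submit-on-load=1").get("submit-on-load");
const shouldSubmit = raw === "true" || raw === "1"; // -> true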
@@ -9,6 +9,11 @@ import { redirect } from "next/navigation";
import { BackendChatSession } from "../../interfaces";
import { SharedChatDisplay } from "./SharedChatDisplay";
import { Persona } from "@/app/admin/assistants/interfaces";
import {
  FetchAssistantsResponse,
  fetchAssistantsSS,
} from "@/lib/assistants/fetchAssistantsSS";
import { defaultPersona } from "@/app/admin/assistants/lib";
import { constructMiniFiedPersona } from "@/lib/assistantIconUtils";

async function getSharedChat(chatId: string) {
@@ -2,7 +2,7 @@ import { User } from "@/lib/types";
import { FiPlus, FiX } from "react-icons/fi";
import { SearchMultiSelectDropdown } from "@/components/Dropdown";
import { UsersIcon } from "@/components/icons/icons";
import { Button } from "@/components/ui/button";
import { Button } from "@/components/Button";

interface UserEditorProps {
  selectedUserIds: string[];

@@ -22,7 +22,7 @@ export const AddMemberForm: React.FC<AddMemberFormProps> = ({

  return (
    <Modal
      className="max-w-xl overflow-visible"
      className="max-w-xl"
      title="Add New User"
      onOutsideClick={() => onClose()}
    >
@@ -43,13 +43,13 @@ const DropdownOption: React.FC<DropdownOptionProps> = ({

  if (href) {
    return (
      <a
      <Link
        href={href}
        target={openInNewTab ? "_blank" : undefined}
        rel={openInNewTab ? "noopener noreferrer" : undefined}
      >
        {content}
      </a>
      </Link>
    );
  } else {
    return <div onClick={onClick}>{content}</div>;

@@ -104,7 +104,7 @@ export function UserDropdown({

  // Construct the current URL
  const currentUrl = `${pathname}${
    searchParams?.toString() ? `?${searchParams.toString()}` : ""
    searchParams.toString() ? `?${searchParams.toString()}` : ""
  }`;

  // Encode the current URL to use as a redirect parameter
@@ -59,8 +59,8 @@ export function ClientLayout({
  const { llmProviders } = useChatContext();
  const { popup, setPopup } = usePopup();
  if (
    (pathname && pathname.startsWith("/admin/connectors")) ||
    (pathname && pathname.startsWith("/admin/embeddings"))
    pathname.startsWith("/admin/connectors") ||
    pathname.startsWith("/admin/embeddings")
  ) {
    return <>{children}</>;
  }
@@ -76,7 +76,7 @@ export const ChatProvider: React.FC<{
      const { sessions } = await response.json();
      setChatSessions(sessions);

      const currentSessionId = searchParams?.get("chatId");
      const currentSessionId = searchParams.get("chatId");
      if (
        currentSessionId &&
        !sessions.some(

@@ -34,7 +34,7 @@ export const EmbeddingFormProvider: React.FC<{
  const pathname = usePathname();

  // Initialize formStep based on the URL parameter
  const initialStep = parseInt(searchParams?.get("step") || "0", 10);
  const initialStep = parseInt(searchParams.get("step") || "0", 10);
  const [formStep, setFormStep] = useState(initialStep);
  const [formValues, setFormValues] = useState<Record<string, any>>({});

@@ -56,10 +56,8 @@ export const EmbeddingFormProvider: React.FC<{

  useEffect(() => {
    // Update URL when formStep changes
    const updatedSearchParams = new URLSearchParams(
      searchParams?.toString() || ""
    );
    const existingStep = updatedSearchParams?.get("step");
    const updatedSearchParams = new URLSearchParams(searchParams.toString());
    const existingStep = updatedSearchParams.get("step");
    updatedSearchParams.set("step", formStep.toString());
    const newUrl = `${pathname}?${updatedSearchParams.toString()}`;

@@ -72,7 +70,7 @@ export const EmbeddingFormProvider: React.FC<{

  // Update formStep when URL changes
  useEffect(() => {
    const stepFromUrl = parseInt(searchParams?.get("step") || "0", 10);
    const stepFromUrl = parseInt(searchParams.get("step") || "0", 10);
    if (stepFromUrl !== formStep) {
      setFormStep(stepFromUrl);
    }
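Both form providers keep a wizard step in two-way sync with a "step" query param. A condensed sketch of the idea as a reusable hook (the hook itself is illustrative; using router.replace as the navigation call is an assumption, since the diff only shows the URL being built):

import { useEffect } from "react";

function useStepParamSync(
  searchParams: URLSearchParams,
  pathname: string,
  formStep: number,
  setFormStep: (step: number) => void,
  replace: (url: string) => void
) {
  // URL -> state: adopt the URL's step when it changes.
  useEffect(() => {
    const stepFromUrl = parseInt(searchParams.get("step") || "0", 10);
    if (stepFromUrl !== formStep) {
      setFormStep(stepFromUrl);
    }
  }, [searchParams]);

  // state -> URL: rewrite the "step" param when the step changes.
  useEffect(() => {
    const params = new URLSearchParams(searchParams.toString());
    params.set("step", formStep.toString());
    replace(`${pathname}?${params.toString()}`);
  }, [formStep]);
}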
@@ -34,7 +34,7 @@ export const FormProvider: React.FC<{
  const pathname = usePathname();

  // Initialize formStep based on the URL parameter
  const initialStep = parseInt(searchParams?.get("step") || "0", 10);
  const initialStep = parseInt(searchParams.get("step") || "0", 10);
  const [formStep, setFormStep] = useState(initialStep);
  const [formValues, setFormValues] = useState<Record<string, any>>({});

@@ -56,10 +56,8 @@ export const FormProvider: React.FC<{

  useEffect(() => {
    // Update URL when formStep changes
    const updatedSearchParams = new URLSearchParams(
      searchParams?.toString() || ""
    );
    const existingStep = updatedSearchParams?.get("step");
    const updatedSearchParams = new URLSearchParams(searchParams.toString());
    const existingStep = updatedSearchParams.get("step");
    updatedSearchParams.set("step", formStep.toString());
    const newUrl = `${pathname}?${updatedSearchParams.toString()}`;

@@ -71,7 +69,7 @@ export const FormProvider: React.FC<{
  }, [formStep, router, pathname, searchParams]);

  useEffect(() => {
    const stepFromUrl = parseInt(searchParams?.get("step") || "0", 10);
    const stepFromUrl = parseInt(searchParams.get("step") || "0", 10);
    if (stepFromUrl !== formStep) {
      setFormStep(stepFromUrl);
    }
@@ -35,7 +35,7 @@ export const HealthCheckBanner = () => {
  useEffect(() => {
    if (userError && userError.status === 403) {
      logout().then(() => {
        if (!pathname?.includes("/auth")) {
        if (!pathname.includes("/auth")) {
          setShowLoggedOutModal(true);
        }
      });

@@ -61,7 +61,7 @@ export const HealthCheckBanner = () => {
      expirationTimeoutRef.current = setTimeout(() => {
        setExpired(true);

        if (!pathname?.includes("/auth")) {
        if (!pathname.includes("/auth")) {
          setShowLoggedOutModal(true);
        }
      }, timeUntilExpire);

@@ -205,7 +205,7 @@ export const HealthCheckBanner = () => {
  }

  if (error instanceof RedirectError || expired) {
    if (!pathname?.includes("/auth")) {
    if (!pathname.includes("/auth")) {
      setShowLoggedOutModal(true);
    }
    return null;
@@ -19,12 +19,12 @@ function setWelcomeFlowComplete() {
  Cookies.set(COMPLETED_WELCOME_FLOW_COOKIE, "true", { expires: 365 });
}

export function CompletedWelcomeFlowDummyComponent() {
export function _CompletedWelcomeFlowDummyComponent() {
  setWelcomeFlowComplete();
  return null;
}

export function WelcomeModal({ user }: { user: User | null }) {
export function _WelcomeModal({ user }: { user: User | null }) {
  const router = useRouter();

  const [providerOptions, setProviderOptions] = useState<

@@ -1,6 +1,6 @@
import {
  CompletedWelcomeFlowDummyComponent,
  WelcomeModal as WelcomeModalComponent,
  _CompletedWelcomeFlowDummyComponent,
  _WelcomeModal,
} from "./WelcomeModal";
import { COMPLETED_WELCOME_FLOW_COOKIE } from "./constants";
import { User } from "@/lib/types";

@@ -24,8 +24,8 @@ export function WelcomeModal({
}) {
  const hasCompletedWelcomeFlow = hasCompletedWelcomeFlowSS(requestCookies);
  if (hasCompletedWelcomeFlow) {
    return <CompletedWelcomeFlowDummyComponent />;
    return <_CompletedWelcomeFlowDummyComponent />;
  }

  return <WelcomeModalComponent user={user} />;
  return <_WelcomeModal user={user} />;
}
@@ -31,13 +31,13 @@ export function NewTeamModal() {
  const { setPopup } = usePopup();

  useEffect(() => {
    const hasNewTeamParam = searchParams?.has("new_team");
    const hasNewTeamParam = searchParams.has("new_team");
    if (hasNewTeamParam) {
      setShowNewTeamModal(true);
      fetchTenantInfo();

      // Remove the new_team parameter from the URL without page reload
      const newParams = new URLSearchParams(searchParams?.toString() || "");
      const newParams = new URLSearchParams(searchParams.toString());
      newParams.delete("new_team");
      const newUrl =
        window.location.pathname +

@@ -16,7 +16,7 @@ export const usePopupFromQuery = (messages: PopupMessages) => {
    const searchParams = new URLSearchParams(window.location.search);

    // Get the value for search param with key "message"
    const messageValue = searchParams?.get("message");
    const messageValue = searchParams.get("message");

    // Check if any key from messages object is present in search params
    if (messageValue && messageValue in messages) {
@@ -148,7 +148,7 @@ function usePaginatedFetch<T extends PaginatedType>({
  // Updates the URL with the current page number
  const updatePageUrl = useCallback(
    (page: number) => {
      if (currentPath && searchParams) {
      if (currentPath) {
        const params = new URLSearchParams(searchParams);
        params.set("page", page.toString());
        router.replace(`${currentPath}?${params.toString()}`, {
@@ -167,7 +167,9 @@ export const constructMiniFiedPersona = (
    display_priority: 0,
    description: "",
    document_sets: [],
    prompts: [],
    tools: [],
    search_start_date: null,
    owner: null,
    starter_messages: null,
    builtin_persona: false,
@@ -1,4 +1,4 @@
import { FullPersona, Persona } from "@/app/admin/assistants/interfaces";
import { Persona } from "@/app/admin/assistants/interfaces";
import { CCPairBasicInfo, DocumentSet, User } from "../types";
import { getCurrentUserSS } from "../userSS";
import { fetchSS } from "../utilsSS";

@@ -18,7 +18,7 @@ export async function fetchAssistantEditorInfoSS(
      documentSets: DocumentSet[];
      llmProviders: LLMProviderView[];
      user: User | null;
      existingPersona: FullPersona | null;
      existingPersona: Persona | null;
      tools: ToolSnapshot[];
    },
    null,

@@ -94,7 +94,7 @@ export async function fetchAssistantEditorInfoSS(
  }

  const existingPersona = personaResponse
    ? ((await personaResponse.json()) as FullPersona)
    ? ((await personaResponse.json()) as Persona)
    : null;

  let error: string | null = null;
@@ -1333,10 +1333,10 @@ export function createConnectorValidationSchema(
): Yup.ObjectSchema<Record<string, any>> {
  const configuration = connectorConfigs[connector];

  const object = Yup.object().shape({
  return Yup.object().shape({
    access_type: Yup.string().required("Access Type is required"),
    name: Yup.string().required("Connector Name is required"),
    ...[...configuration.values, ...configuration.advanced_values].reduce(
    ...configuration.values.reduce(
      (acc, field) => {
        let schema: any =
          field.type === "select"

@@ -1363,8 +1363,6 @@ export function createConnectorValidationSchema(
    pruneFreq: Yup.number().min(0, "Prune frequency must be non-negative"),
    refreshFreq: Yup.number().min(0, "Refresh frequency must be non-negative"),
  });

  return object;
}

export const defaultPruneFreqDays = 30; // 30 days
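The schema above spreads per-field rules from the connector configuration into one Yup object. A self-contained sketch of that dynamic-schema pattern (the field descriptors here are invented for illustration):

import * as Yup from "yup";

// Hypothetical field descriptors, standing in for configuration.values.
const fields = [
  { name: "base_url", required: true },
  { name: "api_token", required: false },
];

const schema = Yup.object().shape({
  name: Yup.string().required("Connector Name is required"),
  // Fold each configured field into its own validation rule.
  ...fields.reduce(
    (acc, field) => {
      const rule = Yup.string();
      acc[field.name] = field.required
        ? rule.required(`${field.name} is required`)
        : rule;
      return acc;
    },
    {} as Record<string, Yup.StringSchema>
  ),
});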
Some files were not shown because too many files have changed in this diff.