mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-19 16:55:46 +00:00
Compare commits
16 Commits
fix_openap
...
error_ux
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8d563c4a42 | ||
|
|
5cf24c74fd | ||
|
|
52d3432056 | ||
|
|
a3e3d83b7e | ||
|
|
4dc88ca037 | ||
|
|
11e7e1c4d6 | ||
|
|
f2d74ce540 | ||
|
|
25389c5120 | ||
|
|
ad0721ecd8 | ||
|
|
426a8842ae | ||
|
|
a98dcbc7de | ||
|
|
6f389dc100 | ||
|
|
d56177958f | ||
|
|
0e42ae9024 | ||
|
|
ce2b4de245 | ||
|
|
a515aa78d2 |
@@ -18,12 +18,13 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
# op.execute(
|
||||
# """
|
||||
# CREATE INDEX idx_chat_message_message_lower
|
||||
# ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
# """
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -18,10 +16,8 @@ logger = setup_logger()
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -35,24 +31,19 @@ def perform_ttl_management_task(
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -2,6 +2,7 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -141,6 +157,12 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -157,11 +179,16 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
items=minimal_chat_sessions,
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -172,6 +199,12 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -193,6 +226,9 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -203,6 +239,12 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -213,6 +255,9 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
if not MULTI_TENANT or not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
|
||||
@@ -2,6 +2,8 @@ import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -13,6 +15,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
@@ -150,8 +153,9 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
@@ -173,7 +177,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
@@ -239,13 +243,13 @@ def send_user_email_invite(
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
tenant_id: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
|
||||
@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@@ -420,7 +420,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
user = cast(User, await self.user_db.get_by_email(account_email))
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
@@ -553,7 +553,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
@@ -131,9 +132,9 @@ def on_task_postrun(
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_deletion_status(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
|
||||
@@ -109,9 +109,7 @@ def revoke_tasks_blocking_deletion(
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -224,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
@@ -345,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
tenant_id: str, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -500,7 +498,7 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -540,7 +538,7 @@ def validate_connector_deletion_fences(
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
|
||||
@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
@@ -410,7 +410,6 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -510,7 +509,7 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
@@ -585,7 +584,7 @@ def update_external_document_permissions_task(
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -632,7 +631,7 @@ def validate_permission_sync_fences(
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
@@ -842,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
|
||||
@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
@@ -392,7 +392,6 @@ def connector_external_group_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -494,7 +493,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
@@ -526,7 +525,7 @@ def validate_external_group_sync_fences(
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
|
||||
@@ -182,7 +182,7 @@ class SimpleJobResult:
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
@@ -598,7 +598,7 @@ def connector_indexing_task(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
@@ -311,7 +311,7 @@ def validate_indexing_fence(
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -442,7 +442,7 @@ def try_creating_indexing_task(
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class Metric(BaseModel):
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
def emit(self, tenant_id: str) -> None:
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
@@ -656,7 +656,7 @@ def build_job_id(
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
|
||||
@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -521,7 +521,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -615,7 +615,7 @@ def validate_pruning_fences(
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
|
||||
@@ -32,7 +32,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
|
||||
@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
@@ -297,7 +297,8 @@ def cloud_beat_task_generator(
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] | list[None] = []
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -325,6 +326,8 @@ def cloud_beat_task_generator(
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -344,6 +347,7 @@ def cloud_beat_task_generator(
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
|
||||
@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -523,9 +523,7 @@ def monitor_document_set_taskset(
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
) -> bool:
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -67,7 +68,6 @@ def _get_connector_runner(
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
@@ -86,7 +86,6 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
@@ -241,7 +240,7 @@ def _check_failure_threshold(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -388,7 +387,6 @@ def _run_indexing(
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
@@ -681,7 +679,7 @@ def _run_indexing(
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
@@ -701,7 +699,7 @@ def run_indexing_entrypoint(
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
if MULTI_TENANT:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
|
||||
@@ -213,6 +213,12 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ANONYMIZED = "anonymized"
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -164,13 +163,9 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
@@ -184,7 +179,6 @@ def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
@@ -216,7 +210,6 @@ def validate_ccpair_for_user(
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -203,8 +198,6 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -43,12 +45,15 @@ def _extract_sections_basic(
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
if mime_type not in supported_file_types:
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
@@ -109,7 +114,53 @@ def _extract_sections_basic(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@@ -128,6 +179,8 @@ def _extract_sections_basic(
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
@@ -141,6 +194,8 @@ def _extract_sections_basic(
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@@ -170,7 +225,11 @@ def _extract_sections_basic(
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
@@ -5,6 +5,10 @@ from typing import Any
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ def insert_api_key(
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key = generate_api_key(tenant_id)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
||||
@@ -258,11 +258,11 @@ class SqlEngine:
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
return [POSTGRES_DEFAULT_SCHEMA]
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
result = session.execute(
|
||||
@@ -417,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
"""
|
||||
|
||||
@@ -100,9 +100,14 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
# Vespa's Document API.
|
||||
def get_document_chunk_ids(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
|
||||
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if tenant_id and MULTI_TENANT:
|
||||
if MULTI_TENANT:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
@@ -43,7 +43,7 @@ class IndexBatchParams:
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
@@ -270,9 +270,7 @@ class Updatable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
|
||||
@@ -468,9 +468,7 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
@@ -618,7 +616,7 @@ class VespaIndex(DocumentIndex):
|
||||
doc_id: str,
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
@@ -661,7 +659,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
total_chunks_deleted = 0
|
||||
|
||||
@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -317,8 +317,8 @@ def index_doc_batch(
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
|
||||
@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str | None = None
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
|
||||
@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
setup_onyx(db_session, None)
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
|
||||
@@ -410,7 +410,7 @@ def _build_qa_response_blocks(
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
@@ -482,7 +482,7 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
|
||||
@@ -151,7 +151,7 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
|
||||
@@ -123,13 +123,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
self.tenant_ids: Set[str] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
|
||||
# Store Redis lock objects here so we can release them properly
|
||||
self.redis_locks: Dict[str | None, Lock] = {}
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
@@ -193,7 +193,7 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def _manage_clients_per_tenant(
|
||||
self, db_session: Session, tenant_id: str | None, bot: SlackBot
|
||||
self, db_session: Session, tenant_id: str, bot: SlackBot
|
||||
) -> None:
|
||||
"""
|
||||
- If the tokens are missing or empty, close the socket client and remove them.
|
||||
@@ -385,7 +385,7 @@ class SlackbotHandler:
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def _remove_tenant(self, tenant_id: str | None) -> None:
|
||||
def _remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
@@ -415,7 +415,7 @@ class SlackbotHandler:
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
@@ -912,7 +912,7 @@ def create_process_slack_event() -> (
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
@@ -570,7 +570,7 @@ def read_slack_thread(
|
||||
|
||||
|
||||
def slack_usage_report(
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str
|
||||
) -> None:
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
@@ -663,9 +663,7 @@ def get_feedback_visibility() -> FeedbackVisibility:
|
||||
|
||||
|
||||
class TenantSocketModeClient(SocketModeClient):
|
||||
def __init__(
|
||||
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
|
||||
):
|
||||
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tenant_id = tenant_id
|
||||
self.slack_bot_id = slack_bot_id
|
||||
|
||||
@@ -16,10 +16,10 @@ class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
# documents that should be skipped
|
||||
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""We can limit the number of tasks generated here, which is useful to prevent
|
||||
one tenant from overwhelming the sync queue.
|
||||
|
||||
@@ -39,8 +39,8 @@ class RedisConnectorDelete:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,12 +52,12 @@ class RedisConnectorIndex:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPrune:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class RedisConnectorStop:
|
||||
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
|
||||
TIMEOUT_TTL = 300
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
document set up to date over multiple batches.
|
||||
|
||||
@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: str):
|
||||
self._tenant_id: str = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
user group up to date over multiple batches.
|
||||
|
||||
@@ -37,13 +37,15 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_indexable_chunks(
|
||||
preprocessed_docs: list[dict],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
|
||||
ids_to_documents = {}
|
||||
chunks = []
|
||||
@@ -86,7 +88,7 @@ def _create_indexable_chunks(
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=preprocessed_doc["title_embedding"],
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
@@ -111,7 +113,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -902,7 +902,6 @@ def create_connector_with_mock_credential(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -100,13 +99,11 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
|
||||
@@ -4,7 +4,9 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
@@ -49,9 +51,10 @@ class Settings(BaseModel):
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
auto_scroll: bool | None = False
|
||||
query_history_type: QueryHistoryType | None = None
|
||||
|
||||
|
||||
class UserSettings(Settings):
|
||||
notifications: list[Notification]
|
||||
needs_reindexing: bool
|
||||
tenant_id: str | None = None
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -45,6 +46,7 @@ def load_settings() -> Settings:
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def setup_onyx(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema
|
||||
|
||||
@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
|
||||
def search_for_document(
|
||||
index_name: str,
|
||||
document_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
max_hits: int | None = 10,
|
||||
) -> List[Dict[str, Any]]:
|
||||
yql_query = f"select * from sources {index_name}"
|
||||
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(
|
||||
|
||||
class VespaDebugging:
|
||||
# Class for managing Vespa debugging actions.
|
||||
def __init__(self, tenant_id: str | None = None):
|
||||
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
|
||||
self.tenant_id = tenant_id
|
||||
self.index_name = get_index_name(self.tenant_id)
|
||||
|
||||
def sample_document_counts(self) -> None:
|
||||
@@ -603,7 +603,7 @@ class VespaDebugging:
|
||||
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
|
||||
|
||||
def search_for_document(
|
||||
self, document_id: str | None = None, tenant_id: str | None = None
|
||||
self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
) -> List[Dict[str, Any]]:
|
||||
return search_for_document(self.index_name, document_id, tenant_id)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Modify sys.path
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -74,7 +75,7 @@ def _unsafe_deletion(
|
||||
for document in documents:
|
||||
document_index.delete_single(
|
||||
doc_id=document.id,
|
||||
tenant_id=None,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
chunk_count=document.chunk_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# makes it so `PYTHONPATH=.` is not required when running this script
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -96,7 +97,9 @@ def main() -> None:
|
||||
try:
|
||||
print(f"Deleting document {doc_id} in Vespa")
|
||||
chunks_deleted = vespa_index.delete_single(
|
||||
doc_id, tenant_id=None, chunk_count=document.chunk_count
|
||||
doc_id,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
chunk_count=document.chunk_count,
|
||||
)
|
||||
if chunks_deleted > 0:
|
||||
print(
|
||||
|
||||
@@ -18,5 +18,7 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
|
||||
def get_current_tenant_id() -> str:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id is None:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
raise RuntimeError("Tenant ID is not set. This should never happen.")
|
||||
return tenant_id
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_confluence_connector_basic(
|
||||
assert len(txt_doc.sections) == 1
|
||||
assert txt_doc.sections[0].text == "small"
|
||||
assert txt_doc.primary_owners
|
||||
assert txt_doc.primary_owners[0].email == "chris@danswer.ai"
|
||||
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
|
||||
assert (
|
||||
txt_doc.sections[0].link
|
||||
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
|
||||
|
||||
89
web/package-lock.json
generated
89
web/package-lock.json
generated
@@ -70,6 +70,8 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
@@ -11741,6 +11743,54 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
|
||||
"integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
|
||||
},
|
||||
"node_modules/hast-util-sanitize": {
|
||||
"version": "5.0.2",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-sanitize/-/hast-util-sanitize-5.0.2.tgz",
|
||||
"integrity": "sha512-3yTWghByc50aGS7JlGhk61SPenfE/p1oaFeNwkOOyrscaOkMGrcW9+Cy/QAIOBpZxP1yqDIzFMR0+Np0i0+usg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@ungap/structured-clone": "^1.0.0",
|
||||
"unist-util-position": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-html/-/hast-util-to-html-9.0.5.tgz",
|
||||
"integrity": "sha512-OguPdidb+fbHQSU4Q4ZiLKnzWo8Wwsf5bZfbvu7//a9oTYoqD/fWpe96NuHkoS9h0ccGOTe0C4NGXdtS0iObOw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@types/unist": "^3.0.0",
|
||||
"ccount": "^2.0.0",
|
||||
"comma-separated-tokens": "^2.0.0",
|
||||
"hast-util-whitespace": "^3.0.0",
|
||||
"html-void-elements": "^3.0.0",
|
||||
"mdast-util-to-hast": "^13.0.0",
|
||||
"property-information": "^7.0.0",
|
||||
"space-separated-tokens": "^2.0.0",
|
||||
"stringify-entities": "^4.0.0",
|
||||
"zwitch": "^2.0.4"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html/node_modules/property-information": {
|
||||
"version": "7.0.0",
|
||||
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.0.0.tgz",
|
||||
"integrity": "sha512-7D/qOz/+Y4X/rzSB6jKxKUsQnphO046ei8qxG59mtM3RG3DHgTK81HrxrmoDVINJb8NKT5ZsRbwHvQ6B68Iyhg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-jsx-runtime": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz",
|
||||
@@ -11919,6 +11969,16 @@
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/html-void-elements": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
|
||||
"integrity": "sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/html-webpack-plugin": {
|
||||
"version": "5.6.3",
|
||||
"resolved": "https://registry.npmjs.org/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz",
|
||||
@@ -19125,6 +19185,35 @@
|
||||
"unist-util-visit": "^5.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-sanitize": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/rehype-sanitize/-/rehype-sanitize-6.0.0.tgz",
|
||||
"integrity": "sha512-CsnhKNsyI8Tub6L4sm5ZFsme4puGfc6pYylvXo1AeqaGbjOYyzNv3qZPwvs0oMJ39eryyeOdmxwUIo94IpEhqg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-sanitize": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-stringify": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/rehype-stringify/-/rehype-stringify-10.0.1.tgz",
|
||||
"integrity": "sha512-k9ecfXHmIPuFVI61B9DeLPN0qFHfawM6RsuX48hoqlaKSF61RskNjSm1lI8PhBEM0MRdLxVVm4WmTqJQccH9mA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-to-html": "^9.0.0",
|
||||
"unified": "^11.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/relateurl": {
|
||||
"version": "0.2.7",
|
||||
"resolved": "https://registry.npmjs.org/relateurl/-/relateurl-0.2.7.tgz",
|
||||
|
||||
@@ -73,6 +73,8 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
|
||||
@@ -4,6 +4,12 @@ export enum ApplicationStatus {
|
||||
ACTIVE = "active",
|
||||
}
|
||||
|
||||
export enum QueryHistoryType {
|
||||
DISABLED = "disabled",
|
||||
ANONYMIZED = "anonymized",
|
||||
NORMAL = "normal",
|
||||
}
|
||||
|
||||
export interface Settings {
|
||||
anonymous_user_enabled: boolean;
|
||||
maximum_chat_retention_days: number | null;
|
||||
@@ -14,6 +20,7 @@ export interface Settings {
|
||||
application_status: ApplicationStatus;
|
||||
auto_scroll: boolean;
|
||||
temperature_override_enabled: boolean;
|
||||
query_history_type: QueryHistoryType;
|
||||
}
|
||||
|
||||
export enum NotificationType {
|
||||
|
||||
@@ -56,6 +56,7 @@ import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
use,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
@@ -131,18 +132,12 @@ import {
|
||||
|
||||
import { getSourceMetadata } from "@/lib/sources";
|
||||
import { UserSettingsModal } from "./modal/UserSettingsModal";
|
||||
import { AlignStartVertical } from "lucide-react";
|
||||
import { AgenticMessage } from "./message/AgenticMessage";
|
||||
import AssistantModal from "../assistants/mine/AssistantModal";
|
||||
import {
|
||||
OperatingSystem,
|
||||
useOperatingSystem,
|
||||
useSidebarShortcut,
|
||||
} from "@/lib/browserUtilities";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { useSidebarShortcut } from "@/lib/browserUtilities";
|
||||
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
|
||||
import { MessageChannel } from "node:worker_threads";
|
||||
import { ChatSearchModal } from "./chat_search/ChatSearchModal";
|
||||
import { ErrorBanner } from "./message/Resubmit";
|
||||
|
||||
const TEMP_USER_MESSAGE_ID = -1;
|
||||
const TEMP_ASSISTANT_MESSAGE_ID = -2;
|
||||
@@ -893,24 +888,6 @@ export function ChatPage({
|
||||
);
|
||||
const scrollDist = useRef<number>(0);
|
||||
|
||||
const updateScrollTracking = () => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > 500);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (scrollableDiv) {
|
||||
scrollableDiv.addEventListener("scroll", updateScrollTracking);
|
||||
return () => {
|
||||
scrollableDiv.removeEventListener("scroll", updateScrollTracking);
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleInputResize = () => {
|
||||
setTimeout(() => {
|
||||
if (
|
||||
@@ -962,33 +939,12 @@ export function ChatPage({
|
||||
if (isVisible) return;
|
||||
|
||||
// Check if all messages are currently rendered
|
||||
if (currentVisibleRange.end < messageHistory.length) {
|
||||
// Update visible range to include the last messages
|
||||
updateCurrentVisibleRange({
|
||||
start: Math.max(
|
||||
0,
|
||||
messageHistory.length -
|
||||
(currentVisibleRange.end - currentVisibleRange.start)
|
||||
),
|
||||
end: messageHistory.length,
|
||||
mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
|
||||
});
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
|
||||
// Wait for the state update and re-render before scrolling
|
||||
setTimeout(() => {
|
||||
endDivRef.current?.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
setHasPerformedInitialScroll(true);
|
||||
}, 100);
|
||||
} else {
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
|
||||
setHasPerformedInitialScroll(true);
|
||||
}
|
||||
setHasPerformedInitialScroll(true);
|
||||
}, 50);
|
||||
|
||||
// Reset waitForScrollRef after 1.5 seconds
|
||||
@@ -1009,11 +965,6 @@ export function ChatPage({
|
||||
handleInputResize();
|
||||
}, [message]);
|
||||
|
||||
// tracks scrolling
|
||||
useEffect(() => {
|
||||
updateScrollTracking();
|
||||
}, [messageHistory]);
|
||||
|
||||
// used for resizing of the document sidebar
|
||||
const masterFlexboxRef = useRef<HTMLDivElement>(null);
|
||||
const [maxDocumentSidebarWidth, setMaxDocumentSidebarWidth] = useState<
|
||||
@@ -1212,6 +1163,7 @@ export function ChatPage({
|
||||
navigatingAway.current = false;
|
||||
let frozenSessionId = currentSessionId();
|
||||
updateCanContinue(false, frozenSessionId);
|
||||
setUncaughtError(null);
|
||||
|
||||
// Mark that we've sent a message for this session in the current page load
|
||||
markSessionMessageSent(frozenSessionId);
|
||||
@@ -1362,6 +1314,7 @@ export function ChatPage({
|
||||
let isStreamingQuestions = true;
|
||||
let includeAgentic = false;
|
||||
let secondLevelMessageId: number | null = null;
|
||||
let isAgentic: boolean = false;
|
||||
|
||||
let initialFetchDetails: null | {
|
||||
user_message_id: number;
|
||||
@@ -1524,6 +1477,9 @@ export function ChatPage({
|
||||
second_level_generating = true;
|
||||
}
|
||||
}
|
||||
if (Object.hasOwn(packet, "is_agentic")) {
|
||||
isAgentic = (packet as any).is_agentic;
|
||||
}
|
||||
|
||||
if (Object.hasOwn(packet, "refined_answer_improvement")) {
|
||||
isImprovement = (packet as RefinedAnswerImprovement)
|
||||
@@ -1557,6 +1513,7 @@ export function ChatPage({
|
||||
);
|
||||
} else if (Object.hasOwn(packet, "sub_question")) {
|
||||
updateChatState("toolBuilding", frozenSessionId);
|
||||
isAgentic = true;
|
||||
is_generating = true;
|
||||
sub_questions = constructSubQuestions(
|
||||
sub_questions,
|
||||
@@ -1757,6 +1714,7 @@ export function ChatPage({
|
||||
sub_questions: sub_questions,
|
||||
second_level_generating: second_level_generating,
|
||||
agentic_docs: agenticDocs,
|
||||
is_agentic: isAgentic,
|
||||
},
|
||||
...(includeAgentic
|
||||
? [
|
||||
@@ -1977,122 +1935,6 @@ export function ChatPage({
|
||||
|
||||
// Virtualization + Scrolling related effects and functions
|
||||
const scrollInitialized = useRef(false);
|
||||
interface VisibleRange {
|
||||
start: number;
|
||||
end: number;
|
||||
mostVisibleMessageId: number | null;
|
||||
}
|
||||
|
||||
const [visibleRange, setVisibleRange] = useState<
|
||||
Map<string | null, VisibleRange>
|
||||
>(() => {
|
||||
const initialRange: VisibleRange = {
|
||||
start: 0,
|
||||
end: BUFFER_COUNT,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
return new Map([[chatSessionIdRef.current, initialRange]]);
|
||||
});
|
||||
|
||||
// Function used to update current visible range. Only method for updating `visibleRange` state.
|
||||
const updateCurrentVisibleRange = (
|
||||
newRange: VisibleRange,
|
||||
forceUpdate?: boolean
|
||||
) => {
|
||||
if (
|
||||
scrollInitialized.current &&
|
||||
visibleRange.get(loadedIdSessionRef.current) == undefined &&
|
||||
!forceUpdate
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setVisibleRange((prevState) => {
|
||||
const newState = new Map(prevState);
|
||||
newState.set(loadedIdSessionRef.current, newRange);
|
||||
return newState;
|
||||
});
|
||||
};
|
||||
|
||||
// Set first value for visibleRange state on page load / refresh.
|
||||
const initializeVisibleRange = () => {
|
||||
const upToDatemessageHistory = buildLatestMessageChain(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
|
||||
if (!scrollInitialized.current && upToDatemessageHistory.length > 0) {
|
||||
const newEnd = Math.max(upToDatemessageHistory.length, BUFFER_COUNT);
|
||||
const newStart = Math.max(0, newEnd - BUFFER_COUNT);
|
||||
const newMostVisibleMessageId =
|
||||
upToDatemessageHistory[newEnd - 1]?.messageId;
|
||||
|
||||
updateCurrentVisibleRange(
|
||||
{
|
||||
start: newStart,
|
||||
end: newEnd,
|
||||
mostVisibleMessageId: newMostVisibleMessageId,
|
||||
},
|
||||
true
|
||||
);
|
||||
scrollInitialized.current = true;
|
||||
}
|
||||
};
|
||||
|
||||
const updateVisibleRangeBasedOnScroll = () => {
|
||||
if (!scrollInitialized.current) return;
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (!scrollableDiv) return;
|
||||
|
||||
const viewportHeight = scrollableDiv.clientHeight;
|
||||
let mostVisibleMessageIndex = -1;
|
||||
|
||||
messageHistory.forEach((message, index) => {
|
||||
const messageElement = document.getElementById(
|
||||
`message-${message.messageId}`
|
||||
);
|
||||
if (messageElement) {
|
||||
const rect = messageElement.getBoundingClientRect();
|
||||
const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
|
||||
if (isVisible && index > mostVisibleMessageIndex) {
|
||||
mostVisibleMessageIndex = index;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (mostVisibleMessageIndex !== -1) {
|
||||
const startIndex = Math.max(0, mostVisibleMessageIndex - BUFFER_COUNT);
|
||||
const endIndex = Math.min(
|
||||
messageHistory.length,
|
||||
mostVisibleMessageIndex + BUFFER_COUNT + 1
|
||||
);
|
||||
|
||||
updateCurrentVisibleRange({
|
||||
start: startIndex,
|
||||
end: endIndex,
|
||||
mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
initializeVisibleRange();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [router, messageHistory]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
|
||||
const handleScroll = () => {
|
||||
updateVisibleRangeBasedOnScroll();
|
||||
};
|
||||
|
||||
scrollableDiv?.addEventListener("scroll", handleScroll);
|
||||
|
||||
return () => {
|
||||
scrollableDiv?.removeEventListener("scroll", handleScroll);
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [messageHistory]);
|
||||
|
||||
const imageFileInMessageHistory = useMemo(() => {
|
||||
return messageHistory
|
||||
@@ -2102,11 +1944,6 @@ export function ChatPage({
|
||||
);
|
||||
}, [messageHistory]);
|
||||
|
||||
const currentVisibleRange = visibleRange.get(currentSessionId()) || {
|
||||
start: 0,
|
||||
end: 0,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
useSendMessageToParent();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -2146,6 +1983,15 @@ export function ChatPage({
|
||||
|
||||
const currentPersona = alternativeAssistant || liveAssistant;
|
||||
|
||||
const HORIZON_DISTANCE = 800;
|
||||
const handleScroll = useCallback(() => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > HORIZON_DISTANCE);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleSlackChatRedirect = async () => {
|
||||
if (!slackChatId) return;
|
||||
@@ -2217,6 +2063,26 @@ export function ChatPage({
|
||||
const [sharedChatSession, setSharedChatSession] =
|
||||
useState<ChatSession | null>();
|
||||
|
||||
const handleResubmitLastMessage = () => {
|
||||
// Grab the last user-type message
|
||||
const lastUserMsg = messageHistory
|
||||
.slice()
|
||||
.reverse()
|
||||
.find((m) => m.type === "user");
|
||||
if (!lastUserMsg) {
|
||||
setPopup({
|
||||
message: "No previously-submitted user message found.",
|
||||
type: "error",
|
||||
});
|
||||
return;
|
||||
}
|
||||
// We call onSubmit, passing a `messageOverride`
|
||||
onSubmit({
|
||||
messageIdToResend: lastUserMsg.messageId,
|
||||
messageOverride: lastUserMsg.message,
|
||||
});
|
||||
};
|
||||
|
||||
const showShareModal = (chatSession: ChatSession) => {
|
||||
setSharedChatSession(chatSession);
|
||||
};
|
||||
@@ -2596,6 +2462,7 @@ export function ChatPage({
|
||||
{...getRootProps()}
|
||||
>
|
||||
<div
|
||||
onScroll={handleScroll}
|
||||
className={`w-full h-[calc(100vh-160px)] flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
|
||||
ref={scrollableDivRef}
|
||||
>
|
||||
@@ -2653,18 +2520,7 @@ export function ChatPage({
|
||||
// NOTE: temporarily removing this to fix the scroll bug
|
||||
// (hasPerformedInitialScroll ? "" : "invisible")
|
||||
>
|
||||
{(messageHistory.length < BUFFER_COUNT
|
||||
? messageHistory
|
||||
: messageHistory.slice(
|
||||
currentVisibleRange.start,
|
||||
currentVisibleRange.end
|
||||
)
|
||||
).map((message, fauxIndex) => {
|
||||
const i =
|
||||
messageHistory.length < BUFFER_COUNT
|
||||
? fauxIndex
|
||||
: fauxIndex + currentVisibleRange.start;
|
||||
|
||||
{messageHistory.map((message, i) => {
|
||||
const messageMap = currentMessageMap(
|
||||
completeMessageDetail
|
||||
);
|
||||
@@ -2809,9 +2665,9 @@ export function ChatPage({
|
||||
: null
|
||||
}
|
||||
>
|
||||
{message.sub_questions &&
|
||||
message.sub_questions.length > 0 ? (
|
||||
{message.is_agentic ? (
|
||||
<AgenticMessage
|
||||
resubmit={handleResubmitLastMessage}
|
||||
error={uncaughtError}
|
||||
isStreamingQuestions={
|
||||
message.isStreamingQuestions ?? false
|
||||
@@ -3159,21 +3015,18 @@ export function ChatPage({
|
||||
currentPersona={liveAssistant}
|
||||
messageId={message.messageId}
|
||||
content={
|
||||
<p className="text-red-700 text-sm my-auto">
|
||||
{message.message}
|
||||
{message.stackTrace && (
|
||||
<span
|
||||
onClick={() =>
|
||||
setStackTraceModalContent(
|
||||
message.stackTrace!
|
||||
)
|
||||
}
|
||||
className="ml-2 cursor-pointer underline"
|
||||
>
|
||||
Show stack trace.
|
||||
</span>
|
||||
)}
|
||||
</p>
|
||||
<ErrorBanner
|
||||
resubmit={handleResubmitLastMessage}
|
||||
error={message.message}
|
||||
showStackTrace={
|
||||
message.stackTrace
|
||||
? () =>
|
||||
setStackTraceModalContent(
|
||||
message.stackTrace!
|
||||
)
|
||||
: undefined
|
||||
}
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
|
||||
@@ -103,6 +103,7 @@ export interface Message {
|
||||
overridden_model?: string;
|
||||
stopReason?: StreamStopReason | null;
|
||||
sub_questions?: SubQuestionDetail[] | null;
|
||||
is_agentic?: boolean | null;
|
||||
|
||||
// Streaming only
|
||||
second_level_generating?: boolean;
|
||||
@@ -148,6 +149,7 @@ export interface BackendMessage {
|
||||
comments: any;
|
||||
parentMessageId: number | null;
|
||||
refined_answer_improvement: boolean | null;
|
||||
is_agentic: boolean | null;
|
||||
}
|
||||
|
||||
export interface MessageResponseIDInfo {
|
||||
|
||||
@@ -7,14 +7,9 @@ import React, {
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
|
||||
import remarkGfm from "remark-gfm";
|
||||
@@ -54,6 +49,10 @@ import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
import SubQuestionsDisplay from "./SubQuestionsDisplay";
|
||||
import { StatusRefinement } from "../Refinement";
|
||||
import { copyAll, handleCopy } from "./copyingUtils";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { RefreshCw } from "lucide-react";
|
||||
import { ErrorBanner, Resubmit } from "./Resubmit";
|
||||
|
||||
export const AgenticMessage = ({
|
||||
isStreamingQuestions,
|
||||
@@ -88,7 +87,9 @@ export const AgenticMessage = ({
|
||||
secondLevelSubquestions,
|
||||
toggleDocDisplay,
|
||||
error,
|
||||
resubmit,
|
||||
}: {
|
||||
resubmit?: () => void;
|
||||
isStreamingQuestions: boolean;
|
||||
isGenerating: boolean;
|
||||
docSidebarToggled?: boolean;
|
||||
@@ -312,6 +313,8 @@ export const AgenticMessage = ({
|
||||
[anchorCallback, paragraphCallback, streamedContent]
|
||||
);
|
||||
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const renderedAlternativeMarkdown = useMemo(() => {
|
||||
return (
|
||||
<ReactMarkdown
|
||||
@@ -457,7 +460,6 @@ export const AgenticMessage = ({
|
||||
finalContent.length > 8) ||
|
||||
(files && files.length > 0) ? (
|
||||
<>
|
||||
{/* <FileDisplay files={files || []} /> */}
|
||||
<div className="w-full py-4 flex flex-col gap-4">
|
||||
<div className="flex items-center gap-x-2 px-4">
|
||||
<div className="text-black text-lg font-medium">
|
||||
@@ -492,7 +494,11 @@ export const AgenticMessage = ({
|
||||
|
||||
<div className="px-4">
|
||||
{typeof content === "string" ? (
|
||||
<div className="overflow-x-visible !text-sm max-w-content-max">
|
||||
<div
|
||||
onCopy={(e) => handleCopy(e, markdownRef)}
|
||||
ref={markdownRef}
|
||||
className="overflow-x-visible !text-sm max-w-content-max"
|
||||
>
|
||||
{isViewingInitialAnswer
|
||||
? renderedMarkdown
|
||||
: renderedAlternativeMarkdown}
|
||||
@@ -501,9 +507,7 @@ export const AgenticMessage = ({
|
||||
content
|
||||
)}
|
||||
{error && (
|
||||
<p className="mt-2 text-red-700 text-sm my-auto">
|
||||
{error}
|
||||
</p>
|
||||
<ErrorBanner error={error} resubmit={resubmit} />
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
@@ -511,15 +515,13 @@ export const AgenticMessage = ({
|
||||
) : isComplete ? (
|
||||
error && (
|
||||
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
|
||||
{error}
|
||||
<ErrorBanner error={error} resubmit={resubmit} />
|
||||
</p>
|
||||
)
|
||||
) : (
|
||||
<>
|
||||
{error && (
|
||||
<p className="mt-2 mx-4 text-red-700 text-sm my-auto">
|
||||
{error}
|
||||
</p>
|
||||
<ErrorBanner error={error} resubmit={resubmit} />
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
@@ -558,7 +560,16 @@ export const AgenticMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
copyAllFn={() =>
|
||||
copyAll(
|
||||
(isViewingInitialAnswer
|
||||
? finalContent
|
||||
: finalAlternativeContent) as string,
|
||||
markdownRef
|
||||
)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
<HoverableIcon
|
||||
@@ -644,7 +655,16 @@ export const AgenticMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
copyAllFn={() =>
|
||||
copyAll(
|
||||
(isViewingInitialAnswer
|
||||
? finalContent
|
||||
: finalAlternativeContent) as string,
|
||||
markdownRef
|
||||
)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
|
||||
@@ -16,15 +16,16 @@ import React, {
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import { unified } from "unified";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { SearchSummary } from "./SearchSummary";
|
||||
import {
|
||||
markdownToHtml,
|
||||
getMarkdownForSelection,
|
||||
} from "@/app/chat/message/codeUtils";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkParse from "remark-parse";
|
||||
import remarkRehype from "remark-rehype";
|
||||
import rehypeSanitize from "rehype-sanitize";
|
||||
import rehypeStringify from "rehype-stringify";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
@@ -69,6 +70,7 @@ import { SourceCard } from "./SourcesDisplay";
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
import { copyAll, handleCopy } from "./copyingUtils";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -364,34 +366,24 @@ export const AIMessage = ({
|
||||
}),
|
||||
[anchorCallback, paragraphCallback, finalContent]
|
||||
);
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Process selection copying with HTML formatting
|
||||
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
if (typeof finalContent !== "string") {
|
||||
return finalContent;
|
||||
}
|
||||
|
||||
// Create a hidden div with the HTML content for copying
|
||||
const htmlContent = markdownToHtml(finalContent);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
style={{
|
||||
position: "absolute",
|
||||
left: "-9999px",
|
||||
display: "none",
|
||||
}}
|
||||
dangerouslySetInnerHTML={{ __html: htmlContent }}
|
||||
/>
|
||||
<ReactMarkdown
|
||||
className="prose dark:prose-invert max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
</>
|
||||
<ReactMarkdown
|
||||
className="prose dark:prose-invert max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
}, [finalContent, markdownComponents]);
|
||||
|
||||
@@ -535,64 +527,9 @@ export const AIMessage = ({
|
||||
{typeof content === "string" ? (
|
||||
<div className="overflow-x-visible max-w-content-max">
|
||||
<div
|
||||
contentEditable="true"
|
||||
suppressContentEditableWarning
|
||||
ref={markdownRef}
|
||||
className="focus:outline-none cursor-text select-text"
|
||||
style={{
|
||||
MozUserModify: "read-only",
|
||||
WebkitUserModify: "read-only",
|
||||
}}
|
||||
onCopy={(e) => {
|
||||
e.preventDefault();
|
||||
const selection = window.getSelection();
|
||||
const selectedPlainText =
|
||||
selection?.toString() || "";
|
||||
if (!selectedPlainText) {
|
||||
// If no text is selected, copy the full content
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[
|
||||
typeof content === "string"
|
||||
? markdownToHtml(content)
|
||||
: contentStr,
|
||||
],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([contentStr], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
return;
|
||||
}
|
||||
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const markdownText = getMarkdownForSelection(
|
||||
contentStr,
|
||||
selectedPlainText
|
||||
);
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[markdownToHtml(markdownText)],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([selectedPlainText], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
}}
|
||||
onCopy={(e) => handleCopy(e, markdownRef)}
|
||||
>
|
||||
{renderedMarkdown}
|
||||
</div>
|
||||
@@ -643,13 +580,8 @@ export const AIMessage = ({
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
copyAllFn={() =>
|
||||
copyAll(finalContent as string, markdownRef)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
@@ -734,13 +666,8 @@ export const AIMessage = ({
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
copyAllFn={() =>
|
||||
copyAll(finalContent as string, markdownRef)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
|
||||
58
web/src/app/chat/message/Resubmit.tsx
Normal file
58
web/src/app/chat/message/Resubmit.tsx
Normal file
@@ -0,0 +1,58 @@
|
||||
import { Alert, AlertDescription, AlertTitle } from "@/components/ui/alert";
|
||||
import { AlertCircle } from "lucide-react";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { RefreshCw } from "lucide-react";
|
||||
|
||||
interface ResubmitProps {
|
||||
resubmit: () => void;
|
||||
}
|
||||
|
||||
export const Resubmit: React.FC<ResubmitProps> = ({ resubmit }) => {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center gap-y-2 mt-4">
|
||||
<p className="text-sm text-neutral-700 dark:text-neutral-300">
|
||||
There was an error with the response.
|
||||
</p>
|
||||
<Button
|
||||
onClick={resubmit}
|
||||
variant="agent"
|
||||
size="sm"
|
||||
className="flex items-center gap-2 text-white font-medium py-2 px-4 rounded"
|
||||
>
|
||||
<RefreshCw className="w-4 h-4" />
|
||||
Regenerate
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export const ErrorBanner = ({
|
||||
error,
|
||||
showStackTrace,
|
||||
resubmit,
|
||||
}: {
|
||||
error: string;
|
||||
showStackTrace?: () => void;
|
||||
resubmit?: () => void;
|
||||
}) => {
|
||||
return (
|
||||
<div className="text-red-700 mt-4 text-sm my-auto">
|
||||
<Alert variant="broken">
|
||||
<AlertCircle className="h-4 w-4" />
|
||||
<AlertTitle>Error</AlertTitle>
|
||||
<AlertDescription className="flex gap-x-2">
|
||||
{error}
|
||||
{showStackTrace && (
|
||||
<span
|
||||
className="text-red-600 hover:text-red-800 cursor-pointer underline"
|
||||
onClick={showStackTrace}
|
||||
>
|
||||
Show stack trace
|
||||
</span>
|
||||
)}
|
||||
</AlertDescription>
|
||||
</Alert>
|
||||
{resubmit && <Resubmit resubmit={resubmit} />}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -24,6 +24,7 @@ import { CodeBlock } from "./CodeBlock";
|
||||
import { CheckIcon, ChevronDown } from "lucide-react";
|
||||
import { PHASE_MIN_MS, useStreamingMessages } from "./StreamingMessages";
|
||||
import { CirclingArrowIcon } from "@/components/icons/icons";
|
||||
import { handleCopy } from "./copyingUtils";
|
||||
|
||||
export const StatusIndicator = ({ status }: { status: ToggleState }) => {
|
||||
return (
|
||||
@@ -292,6 +293,7 @@ const SubQuestionDisplay: React.FC<{
|
||||
}
|
||||
}, [currentlyClosed]);
|
||||
|
||||
const analysisRef = useRef<HTMLDivElement>(null);
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
return (
|
||||
<ReactMarkdown
|
||||
@@ -428,7 +430,11 @@ const SubQuestionDisplay: React.FC<{
|
||||
/>
|
||||
</div>
|
||||
{analysisToggled && (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<div
|
||||
ref={analysisRef}
|
||||
onCopy={(e) => handleCopy(e, analysisRef)}
|
||||
className="flex flex-wrap gap-2"
|
||||
>
|
||||
{renderedMarkdown}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
import { markdownToHtml, parseMarkdownToSegments } from "../codeUtils";
|
||||
|
||||
describe("markdownToHtml", () => {
|
||||
test("converts bold text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is **bold** text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("converts italic text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is *italic* text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _italic_ text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles mixed bold and italic", () => {
|
||||
expect(markdownToHtml("This is **bold** and *italic* text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ and _italic_ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles text with spaces and special characters", () => {
|
||||
expect(markdownToHtml("This is *as delicious and* tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _as delicious and_ tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles multi-paragraph text with italics", () => {
|
||||
const input =
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it.";
|
||||
expect(markdownToHtml(input)).toBe(
|
||||
"<p>Sure! Here is a sentence with one italicized word:</p>\n<p>The cake was <em>delicious</em> and everyone enjoyed it.</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(markdownToHtml("This is *malformed markdown")).toBe(
|
||||
"<p>This is *malformed markdown</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _also malformed")).toBe(
|
||||
"<p>This is _also malformed</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has **unclosed bold")).toBe(
|
||||
"<p>This has **unclosed bold</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has __unclosed bold")).toBe(
|
||||
"<p>This has __unclosed bold</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(markdownToHtml("")).toBe("");
|
||||
expect(markdownToHtml(" ")).toBe("");
|
||||
expect(markdownToHtml("\n")).toBe("");
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => markdownToHtml(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("parseMarkdownToSegments", () => {
|
||||
test("parses italic text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is *italic* text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "*italic*", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses italic text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is _italic_ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "_italic_", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is **bold** text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "**bold**", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is __bold__ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "__bold__", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses text with spaces and special characters in italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"The cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "The cake was ", raw: "The cake was ", length: 13 },
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses multi-paragraph text with italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{
|
||||
type: "text",
|
||||
text: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
raw: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
length: 65,
|
||||
},
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(() => parseMarkdownToSegments("This is *malformed")).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This is _also malformed")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has **unclosed bold")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has __unclosed bold")
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(parseMarkdownToSegments("")).toEqual([]);
|
||||
expect(parseMarkdownToSegments(" ")).toEqual([
|
||||
{ type: "text", text: " ", raw: " ", length: 1 },
|
||||
]);
|
||||
expect(parseMarkdownToSegments("\n")).toEqual([
|
||||
{ type: "text", text: "\n", raw: "\n", length: 1 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => parseMarkdownToSegments(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
@@ -83,35 +83,6 @@ export const preprocessLaTeX = (content: string) => {
|
||||
return inlineProcessedContent;
|
||||
};
|
||||
|
||||
export const markdownToHtml = (content: string): string => {
|
||||
if (!content || !content.trim()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Basic markdown to HTML conversion for common patterns
|
||||
const processedContent = content
|
||||
.replace(/(\*\*|__)((?:(?!\1).)*?)\1/g, "<strong>$2</strong>") // Bold with ** or __, non-greedy and no nesting
|
||||
.replace(/(\*|_)([^*_\n]+?)\1(?!\*|_)/g, "<em>$2</em>"); // Italic with * or _
|
||||
|
||||
// Handle code blocks and links
|
||||
const withCodeAndLinks = processedContent
|
||||
.replace(/`([^`]+)`/g, "<code>$1</code>") // Inline code
|
||||
.replace(
|
||||
/```(\w*)\n([\s\S]*?)```/g,
|
||||
(_, lang, code) =>
|
||||
`<pre><code class="language-${lang}">${code.trim()}</code></pre>`
|
||||
) // Code blocks
|
||||
.replace(/\[([^\]]+)\]\(([^)]+)\)/g, '<a href="$2">$1</a>'); // Links
|
||||
|
||||
// Handle paragraphs
|
||||
return withCodeAndLinks
|
||||
.split(/\n\n+/)
|
||||
.map((para) => para.trim())
|
||||
.filter((para) => para.length > 0)
|
||||
.map((para) => `<p>${para}</p>`)
|
||||
.join("\n");
|
||||
};
|
||||
|
||||
interface MarkdownSegment {
|
||||
type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
|
||||
text: string; // The visible/plain text
|
||||
|
||||
71
web/src/app/chat/message/copyingUtils.tsx
Normal file
71
web/src/app/chat/message/copyingUtils.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
"use client";
|
||||
import { unified } from "unified";
|
||||
import remarkParse from "remark-parse";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import remarkRehype from "remark-rehype";
|
||||
import rehypePrism from "rehype-prism-plus";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import rehypeSanitize from "rehype-sanitize";
|
||||
import rehypeStringify from "rehype-stringify";
|
||||
|
||||
export const handleCopy = (
|
||||
e: React.ClipboardEvent,
|
||||
markdownRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
// Check if we have a selection
|
||||
const selection = window.getSelection();
|
||||
if (!selection?.rangeCount) return;
|
||||
|
||||
const range = selection.getRangeAt(0);
|
||||
|
||||
// If selection is within our markdown container
|
||||
if (
|
||||
markdownRef.current &&
|
||||
markdownRef.current.contains(range.commonAncestorContainer)
|
||||
) {
|
||||
e.preventDefault();
|
||||
|
||||
// Clone selection to get the HTML
|
||||
const fragment = range.cloneContents();
|
||||
const tempDiv = document.createElement("div");
|
||||
tempDiv.appendChild(fragment);
|
||||
|
||||
// Create clipboard data with both HTML and plain text
|
||||
e.clipboardData.setData("text/html", tempDiv.innerHTML);
|
||||
e.clipboardData.setData("text/plain", selection.toString());
|
||||
}
|
||||
};
|
||||
|
||||
// For copying the entire content
|
||||
export const copyAll = (
|
||||
content: string,
|
||||
markdownRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
if (!markdownRef.current || typeof content !== "string") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert markdown to HTML using unified ecosystem
|
||||
unified()
|
||||
.use(remarkParse)
|
||||
.use(remarkGfm)
|
||||
.use(remarkMath)
|
||||
.use(remarkRehype)
|
||||
.use(rehypePrism, { ignoreMissing: true })
|
||||
.use(rehypeKatex)
|
||||
.use(rehypeSanitize)
|
||||
.use(rehypeStringify)
|
||||
.process(content)
|
||||
.then((file: any) => {
|
||||
const htmlContent = String(file);
|
||||
|
||||
// Create clipboard data
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob([htmlContent], { type: "text/html" }),
|
||||
"text/plain": new Blob([content], { type: "text/plain" }),
|
||||
});
|
||||
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
});
|
||||
};
|
||||
@@ -4,33 +4,36 @@ import { CheckmarkIcon, CopyMessageIcon } from "./icons/icons";
|
||||
|
||||
export function CopyButton({
|
||||
content,
|
||||
copyAllFn,
|
||||
onClick,
|
||||
}: {
|
||||
content?: string | { html: string; plainText: string };
|
||||
content?: string;
|
||||
copyAllFn?: () => void;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
const [copyClicked, setCopyClicked] = useState(false);
|
||||
|
||||
const copyToClipboard = async (
|
||||
content: string | { html: string; plainText: string }
|
||||
) => {
|
||||
const copyToClipboard = async () => {
|
||||
try {
|
||||
// If copyAllFn is provided, use it instead of the default behavior
|
||||
if (copyAllFn) {
|
||||
await copyAllFn();
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to original behavior if no copyAllFn is provided
|
||||
if (!content) return;
|
||||
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[typeof content === "string" ? content : content.html],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob(
|
||||
[typeof content === "string" ? content : content.plainText],
|
||||
{ type: "text/plain" }
|
||||
),
|
||||
"text/html": new Blob([content], { type: "text/html" }),
|
||||
"text/plain": new Blob([content], { type: "text/plain" }),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
} catch (err) {
|
||||
// Fallback to basic text copy if HTML copy fails
|
||||
await navigator.clipboard.writeText(
|
||||
typeof content === "string" ? content : content.plainText
|
||||
);
|
||||
if (content) {
|
||||
await navigator.clipboard.writeText(content);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -38,9 +41,7 @@ export function CopyButton({
|
||||
<HoverableIcon
|
||||
icon={copyClicked ? <CheckmarkIcon /> : <CopyMessageIcon />}
|
||||
onClick={() => {
|
||||
if (content) {
|
||||
copyToClipboard(content);
|
||||
}
|
||||
copyToClipboard();
|
||||
onClick && onClick();
|
||||
|
||||
setCopyClicked(true);
|
||||
|
||||
@@ -359,18 +359,25 @@ export function ClientLayout({
|
||||
),
|
||||
link: "/admin/performance/usage",
|
||||
},
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<DatabaseIconSkeleton
|
||||
className="text-text-700"
|
||||
size={18}
|
||||
/>
|
||||
<div className="ml-1">Query History</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/performance/query-history",
|
||||
},
|
||||
...(settings?.settings.query_history_type !==
|
||||
"disabled"
|
||||
? [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<DatabaseIconSkeleton
|
||||
className="text-text-700"
|
||||
size={18}
|
||||
/>
|
||||
<div className="ml-1">
|
||||
Query History
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/performance/query-history",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
EnterpriseSettings,
|
||||
ApplicationStatus,
|
||||
Settings,
|
||||
QueryHistoryType,
|
||||
} from "@/app/admin/settings/interfaces";
|
||||
import {
|
||||
CUSTOM_ANALYTICS_ENABLED,
|
||||
@@ -53,6 +54,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
anonymous_user_enabled: false,
|
||||
pro_search_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
|
||||
@@ -8,8 +8,10 @@ const alertVariants = cva(
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
broken:
|
||||
"border-red-500/50 text-red-500 dark:border-red-500 [&>svg]:text-red-500 dark:border-red-900/50 dark:text-red-100 dark:dark:border-red-900 dark:[&>svg]:text-red-700 bg-red-50 dark:bg-red-950",
|
||||
ark: "border-amber-500/50 text-amber-500 dark:border-amber-500 [&>svg]:text-amber-500 dark:border-amber-900/50 dark:text-amber-900 dark:dark:border-amber-900 dark:[&>svg]:text-amber-900 bg-amber-50 dark:bg-amber-950",
|
||||
info: "border-black/50 dark:border-black dark:border-black/50 dark:dark:border-black",
|
||||
|
||||
default:
|
||||
"bg-neutral-50 text-neutral-darker dark:bg-neutral-950 dark:text-text",
|
||||
destructive:
|
||||
|
||||
@@ -9,6 +9,8 @@ const buttonVariants = cva(
|
||||
{
|
||||
variants: {
|
||||
variant: {
|
||||
agent:
|
||||
"bg-agent text-white hover:bg-agent-hovered dark:bg-agent dark:text-white dark:hover:bg-agent/90",
|
||||
success:
|
||||
"bg-green-100 text-green-600 hover:bg-green-500/90 dark:bg-green-700 dark:text-green-100 dark:hover:bg-green-600/90",
|
||||
"success-reverse":
|
||||
|
||||
@@ -24,6 +24,8 @@ async function verifyAdminPageNavigation(
|
||||
console.error(
|
||||
`Failed to find h1 with text "${pageTitle}" for path "${path}"`
|
||||
);
|
||||
// NOTE: This is a temporary measure for debugging the issue
|
||||
console.error(await page.content());
|
||||
throw error;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user