mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
14 Commits
fix_openap
...
user_setti
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0a4ef8ff35 | ||
|
|
a3e3d83b7e | ||
|
|
4dc88ca037 | ||
|
|
11e7e1c4d6 | ||
|
|
f2d74ce540 | ||
|
|
25389c5120 | ||
|
|
ad0721ecd8 | ||
|
|
426a8842ae | ||
|
|
a98dcbc7de | ||
|
|
6f389dc100 | ||
|
|
d56177958f | ||
|
|
0e42ae9024 | ||
|
|
ce2b4de245 | ||
|
|
a515aa78d2 |
@@ -18,12 +18,13 @@ depends_on = None
|
||||
def upgrade() -> None:
|
||||
# Create a basic index on the lowercase message column for direct text matching
|
||||
# Limit to 1500 characters to stay well under the 2856 byte limit of btree version 4
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX idx_chat_message_message_lower
|
||||
ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
"""
|
||||
)
|
||||
# op.execute(
|
||||
# """
|
||||
# CREATE INDEX idx_chat_message_message_lower
|
||||
# ON chat_message (LOWER(substring(message, 1, 1500)))
|
||||
# """
|
||||
# )
|
||||
pass
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -5,11 +5,9 @@ from onyx.background.celery.apps.primary import celery_app
|
||||
from onyx.background.task_utils import build_celery_task_wrapper
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.db.chat import delete_chat_sessions_older_than
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -18,10 +16,8 @@ logger = setup_logger()
|
||||
|
||||
@build_celery_task_wrapper(name_chat_ttl_task)
|
||||
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
|
||||
def perform_ttl_management_task(
|
||||
retention_limit_days: int, *, tenant_id: str | None
|
||||
) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_sessions_older_than(retention_limit_days, db_session)
|
||||
|
||||
|
||||
@@ -35,24 +31,19 @@ def perform_ttl_management_task(
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
def check_ttl_management_task(*, tenant_id: str) -> None:
|
||||
"""Runs periodically to check if any ttl tasks should be run and adds them
|
||||
to the queue"""
|
||||
token = None
|
||||
if MULTI_TENANT and tenant_id is not None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
settings = load_settings()
|
||||
retention_limit_days = settings.maximum_chat_retention_days
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
if should_perform_chat_ttl_check(retention_limit_days, db_session):
|
||||
perform_ttl_management_task.apply_async(
|
||||
kwargs=dict(
|
||||
retention_limit_days=retention_limit_days, tenant_id=tenant_id
|
||||
),
|
||||
)
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
@celery_app.task(
|
||||
@@ -60,9 +51,9 @@ def check_ttl_management_task(*, tenant_id: str | None) -> None:
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
)
|
||||
def autogenerate_usage_report_task(*, tenant_id: str | None) -> None:
|
||||
def autogenerate_usage_report_task(*, tenant_id: str) -> None:
|
||||
"""This generates usage report under the /admin/generate-usage/report endpoint"""
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
create_new_usage_report(
|
||||
db_session=db_session,
|
||||
user_id=None,
|
||||
|
||||
@@ -18,7 +18,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def monitor_usergroup_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
"""This function is likely to move in the worker refactor happening next."""
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
|
||||
@@ -2,6 +2,7 @@ import csv
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from http import HTTPStatus
|
||||
from uuid import UUID
|
||||
|
||||
from fastapi import APIRouter
|
||||
@@ -21,8 +22,10 @@ from ee.onyx.server.query_history.models import QuestionAnswerPairSnapshot
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_display_email
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.configs.constants import QAFeedbackType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.configs.constants import SessionType
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
@@ -35,6 +38,8 @@ from onyx.server.query_and_chat.models import ChatSessionsResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
ONYX_ANONYMIZED_EMAIL = "anonymous@anonymous.invalid"
|
||||
|
||||
|
||||
def fetch_and_process_chat_session_history(
|
||||
db_session: Session,
|
||||
@@ -107,6 +112,17 @@ def get_user_chat_sessions(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionsResponse:
|
||||
# we specifically don't allow this endpoint if "anonymized" since
|
||||
# this is a direct query on the user id
|
||||
if ONYX_QUERY_HISTORY_TYPE in [
|
||||
QueryHistoryType.DISABLED,
|
||||
QueryHistoryType.ANONYMIZED,
|
||||
]:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Per user query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_sessions = get_chat_sessions_by_user(
|
||||
user_id=user_id, deleted=False, db_session=db_session, limit=0
|
||||
@@ -141,6 +157,12 @@ def get_chat_session_history(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[ChatSessionMinimal]:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
page_of_chat_sessions = get_page_of_chat_sessions(
|
||||
page_num=page_num,
|
||||
page_size=page_size,
|
||||
@@ -157,11 +179,16 @@ def get_chat_session_history(
|
||||
feedback_filter=feedback_type,
|
||||
)
|
||||
|
||||
minimal_chat_sessions: list[ChatSessionMinimal] = []
|
||||
|
||||
for chat_session in page_of_chat_sessions:
|
||||
minimal_chat_session = ChatSessionMinimal.from_chat_session(chat_session)
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
minimal_chat_session.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
minimal_chat_sessions.append(minimal_chat_session)
|
||||
|
||||
return PaginatedReturn(
|
||||
items=[
|
||||
ChatSessionMinimal.from_chat_session(chat_session)
|
||||
for chat_session in page_of_chat_sessions
|
||||
],
|
||||
items=minimal_chat_sessions,
|
||||
total_items=total_filtered_chat_sessions_count,
|
||||
)
|
||||
|
||||
@@ -172,6 +199,12 @@ def get_chat_session_admin(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionSnapshot:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
try:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session_id,
|
||||
@@ -193,6 +226,9 @@ def get_chat_session_admin(
|
||||
f"Could not create snapshot for chat session with id '{chat_session_id}'",
|
||||
)
|
||||
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
return snapshot
|
||||
|
||||
|
||||
@@ -203,6 +239,12 @@ def get_query_history_as_csv(
|
||||
end: datetime | None = None,
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> StreamingResponse:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.DISABLED:
|
||||
raise HTTPException(
|
||||
status_code=HTTPStatus.FORBIDDEN,
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
@@ -213,6 +255,9 @@ def get_query_history_as_csv(
|
||||
|
||||
question_answer_pairs: list[QuestionAnswerPairSnapshot] = []
|
||||
for chat_session_snapshot in complete_chat_session_history:
|
||||
if ONYX_QUERY_HISTORY_TYPE == QueryHistoryType.ANONYMIZED:
|
||||
chat_session_snapshot.user_email = ONYX_ANONYMIZED_EMAIL
|
||||
|
||||
question_answer_pairs.extend(
|
||||
QuestionAnswerPairSnapshot.from_chat_session_snapshot(chat_session_snapshot)
|
||||
)
|
||||
|
||||
@@ -28,7 +28,7 @@ def get_tenant_id_for_email(email: str) -> str:
|
||||
|
||||
|
||||
def user_owns_a_tenant(email: str) -> bool:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(UserTenantMapping.email == email)
|
||||
@@ -38,7 +38,7 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
@@ -48,7 +48,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
mappings_to_delete = (
|
||||
db_session.query(UserTenantMapping)
|
||||
@@ -71,7 +71,7 @@ def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
|
||||
def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=None) as db_session:
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
|
||||
@@ -10,6 +10,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import API_KEY_HASH_ROUNDS
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
|
||||
_API_KEY_HEADER_NAME = "Authorization"
|
||||
@@ -35,8 +36,7 @@ class ApiKeyDescriptor(BaseModel):
|
||||
|
||||
|
||||
def generate_api_key(tenant_id: str | None = None) -> str:
|
||||
# For backwards compatibility, if no tenant_id, generate old style key
|
||||
if not tenant_id:
|
||||
if not MULTI_TENANT or not tenant_id:
|
||||
return _API_KEY_PREFIX + secrets.token_urlsafe(_API_KEY_LEN)
|
||||
|
||||
encoded_tenant = quote(tenant_id) # URL encode the tenant ID
|
||||
|
||||
@@ -2,6 +2,8 @@ import smtplib
|
||||
from datetime import datetime
|
||||
from email.mime.multipart import MIMEMultipart
|
||||
from email.mime.text import MIMEText
|
||||
from email.utils import formatdate
|
||||
from email.utils import make_msgid
|
||||
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import EMAIL_FROM
|
||||
@@ -13,6 +15,7 @@ from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import TENANT_ID_COOKIE_NAME
|
||||
from onyx.db.models import User
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
HTML_EMAIL_TEMPLATE = """\
|
||||
<!DOCTYPE html>
|
||||
@@ -150,8 +153,9 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
part_text = MIMEText(text_body, "plain")
|
||||
part_html = MIMEText(html_body, "html")
|
||||
@@ -173,7 +177,7 @@ def send_subscription_cancellation_email(user_email: str) -> None:
|
||||
subject = "Your Onyx Subscription Has Been Canceled"
|
||||
heading = "Subscription Canceled"
|
||||
message = (
|
||||
"<p>We’re sorry to see you go.</p>"
|
||||
"<p>We're sorry to see you go.</p>"
|
||||
"<p>Your subscription has been canceled and will end on your next billing date.</p>"
|
||||
"<p>If you change your mind, you can always come back!</p>"
|
||||
)
|
||||
@@ -239,13 +243,13 @@ def send_user_email_invite(
|
||||
def send_forgot_password_email(
|
||||
user_email: str,
|
||||
token: str,
|
||||
tenant_id: str,
|
||||
mail_from: str = EMAIL_FROM,
|
||||
tenant_id: str | None = None,
|
||||
) -> None:
|
||||
# Builds a forgot password email with or without fancy HTML
|
||||
subject = "Onyx Forgot Password"
|
||||
link = f"{WEB_DOMAIN}/auth/reset-password?token={token}"
|
||||
if tenant_id:
|
||||
if MULTI_TENANT:
|
||||
link += f"&{TENANT_ID_COOKIE_NAME}={tenant_id}"
|
||||
message = f"<p>Click the following link to reset your password:</p><p>{link}</p>"
|
||||
html_content = build_html_email("Reset Your Password", message)
|
||||
|
||||
@@ -214,7 +214,7 @@ def verify_email_is_invited(email: str) -> None:
|
||||
raise PermissionError("User not on allowed user whitelist")
|
||||
|
||||
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str | None = None) -> None:
|
||||
def verify_email_in_whitelist(email: str, tenant_id: str) -> None:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
if not get_user_by_email(email, db_session):
|
||||
verify_email_is_invited(email)
|
||||
@@ -420,7 +420,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
except exceptions.UserNotExists:
|
||||
try:
|
||||
# Attempt to get user by email
|
||||
user = await self.get_by_email(account_email)
|
||||
user = cast(User, await self.user_db.get_by_email(account_email))
|
||||
if not associate_by_email:
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
@@ -553,7 +553,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
async_return_default_schema,
|
||||
)(email=user.email)
|
||||
|
||||
send_forgot_password_email(user.email, token, tenant_id=tenant_id)
|
||||
send_forgot_password_email(user.email, tenant_id=tenant_id, token=token)
|
||||
|
||||
async def on_after_request_verify(
|
||||
self, user: User, token: str, request: Optional[Request] = None
|
||||
|
||||
@@ -2,6 +2,7 @@ import logging
|
||||
import multiprocessing
|
||||
import time
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import sentry_sdk
|
||||
from celery import Task
|
||||
@@ -131,9 +132,9 @@ def on_task_postrun(
|
||||
# Get tenant_id directly from kwargs- each celery task has a tenant_id kwarg
|
||||
if not kwargs:
|
||||
logger.error(f"Task {task.name} (ID: {task_id}) is missing kwargs")
|
||||
tenant_id = None
|
||||
tenant_id = POSTGRES_DEFAULT_SCHEMA
|
||||
else:
|
||||
tenant_id = kwargs.get("tenant_id")
|
||||
tenant_id = cast(str, kwargs.get("tenant_id", POSTGRES_DEFAULT_SCHEMA))
|
||||
|
||||
task_logger.debug(
|
||||
f"Task {task.name} (ID: {task_id}) completed with state: {state} "
|
||||
|
||||
@@ -34,7 +34,7 @@ def _get_deletion_status(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> TaskQueueState | None:
|
||||
"""We no longer store TaskQueueState in the DB for a deletion attempt.
|
||||
This function populates TaskQueueState by just checking redis.
|
||||
@@ -67,7 +67,7 @@ def get_deletion_attempt_snapshot(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str,
|
||||
) -> DeletionAttemptSnapshot | None:
|
||||
deletion_task = _get_deletion_status(
|
||||
connector_id, credential_id, db_session, tenant_id
|
||||
|
||||
@@ -109,9 +109,7 @@ def revoke_tasks_blocking_deletion(
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_connector_deletion_task(
|
||||
self: Task, *, tenant_id: str | None
|
||||
) -> bool | None:
|
||||
def check_for_connector_deletion_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -224,7 +222,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
cc_pair_id: int,
|
||||
db_session: Session,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Note that syncing can still be required even if the number of sync tasks generated is zero.
|
||||
@@ -345,7 +343,7 @@ def try_generate_document_cc_pair_cleanup_tasks(
|
||||
|
||||
|
||||
def monitor_connector_deletion_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis
|
||||
tenant_id: str, key_bytes: bytes, r: Redis
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -500,7 +498,7 @@ def monitor_connector_deletion_taskset(
|
||||
|
||||
|
||||
def validate_connector_deletion_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -540,7 +538,7 @@ def validate_connector_deletion_fences(
|
||||
|
||||
|
||||
def validate_connector_deletion_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
r: Redis,
|
||||
|
||||
@@ -221,7 +221,7 @@ def try_creating_permissions_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns a randomized payload id on success.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -320,7 +320,7 @@ def try_creating_permissions_sync_task(
|
||||
def connector_permission_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
Permission sync task that handles document permission syncing for a given connector credential pair
|
||||
@@ -410,7 +410,6 @@ def connector_permission_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -510,7 +509,7 @@ def connector_permission_sync_generator_task(
|
||||
)
|
||||
def update_external_document_permissions_task(
|
||||
self: Task,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
serialized_doc_external_access: dict,
|
||||
source_string: str,
|
||||
connector_id: int,
|
||||
@@ -585,7 +584,7 @@ def update_external_document_permissions_task(
|
||||
|
||||
|
||||
def validate_permission_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -632,7 +631,7 @@ def validate_permission_sync_fences(
|
||||
|
||||
|
||||
def validate_permission_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
queued_tasks: set[str],
|
||||
reserved_tasks: set[str],
|
||||
@@ -842,7 +841,7 @@ class PermissionSyncCallback(IndexingHeartbeatInterface):
|
||||
|
||||
|
||||
def monitor_ccpair_permissions_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
|
||||
@@ -123,7 +123,7 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_external_group_sync(self: Task, *, tenant_id: str) -> bool | None:
|
||||
# we need to use celery's redis client to access its redis data
|
||||
# (which lives on a different db number)
|
||||
r = get_redis_client()
|
||||
@@ -220,7 +220,7 @@ def try_creating_external_group_sync_task(
|
||||
app: Celery,
|
||||
cc_pair_id: int,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
|
||||
Returns None if no syncing is required."""
|
||||
@@ -306,7 +306,7 @@ def try_creating_external_group_sync_task(
|
||||
def connector_external_group_sync_generator_task(
|
||||
self: Task,
|
||||
cc_pair_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""
|
||||
External group sync task for a given connector credential pair
|
||||
@@ -392,7 +392,6 @@ def connector_external_group_sync_generator_task(
|
||||
cc_pair.connector.id,
|
||||
cc_pair.credential.id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
enforce_creation=False,
|
||||
)
|
||||
if not created:
|
||||
@@ -494,7 +493,7 @@ def connector_external_group_sync_generator_task(
|
||||
|
||||
|
||||
def validate_external_group_sync_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
celery_app: Celery,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
@@ -526,7 +525,7 @@ def validate_external_group_sync_fences(
|
||||
|
||||
|
||||
def validate_external_group_sync_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
|
||||
@@ -182,7 +182,7 @@ class SimpleJobResult:
|
||||
|
||||
|
||||
class ConnectorIndexingContext(BaseModel):
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
cc_pair_id: int
|
||||
search_settings_id: int
|
||||
index_attempt_id: int
|
||||
@@ -210,7 +210,7 @@ class ConnectorIndexingLogBuilder:
|
||||
|
||||
|
||||
def monitor_ccpair_indexing_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
# if the fence doesn't exist, there's nothing to do
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
@@ -358,7 +358,7 @@ def monitor_ccpair_indexing_taskset(
|
||||
soft_time_limit=300,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_indexing(self: Task, *, tenant_id: str | None) -> int | None:
|
||||
def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
|
||||
"""a lightweight task used to kick off indexing tasks.
|
||||
Occcasionally does some validation of existing state to clear up error conditions"""
|
||||
|
||||
@@ -598,7 +598,7 @@ def connector_indexing_task(
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
is_ee: bool,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Indexing task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -890,7 +890,7 @@ def connector_indexing_proxy_task(
|
||||
index_attempt_id: int,
|
||||
cc_pair_id: int,
|
||||
search_settings_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""celery out of process task execution strategy is pool=prefork, but it uses fork,
|
||||
and forking is inherently unstable.
|
||||
@@ -1170,7 +1170,7 @@ def connector_indexing_proxy_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
)
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str | None) -> None:
|
||||
def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
"""Clean up old checkpoints that are older than 7 days."""
|
||||
locked = False
|
||||
redis_client = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -187,7 +187,7 @@ class IndexingCallback(IndexingCallbackBase):
|
||||
|
||||
|
||||
def validate_indexing_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
r_celery: Redis,
|
||||
@@ -311,7 +311,7 @@ def validate_indexing_fence(
|
||||
|
||||
|
||||
def validate_indexing_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
lock_beat: RedisLock,
|
||||
@@ -442,7 +442,7 @@ def try_creating_indexing_task(
|
||||
reindex: bool,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
"""Checks for any conditions that should block the indexing task from being
|
||||
created, then creates the task.
|
||||
|
||||
@@ -59,7 +59,7 @@ def _process_model_list_response(model_list_json: Any) -> list[str]:
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_llm_model_update(self: Task, *, tenant_id: str) -> bool | None:
|
||||
if not LLM_MODEL_UPDATE_API_URL:
|
||||
raise ValueError("LLM model update API URL not configured")
|
||||
|
||||
|
||||
@@ -91,7 +91,7 @@ class Metric(BaseModel):
|
||||
}
|
||||
task_logger.info(json.dumps(data))
|
||||
|
||||
def emit(self, tenant_id: str | None) -> None:
|
||||
def emit(self, tenant_id: str) -> None:
|
||||
# Convert value to appropriate type based on the input value
|
||||
bool_value = None
|
||||
float_value = None
|
||||
@@ -656,7 +656,7 @@ def build_job_id(
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_background_processes(self: Task, *, tenant_id: str) -> None:
|
||||
"""Collect and emit metrics about background processes.
|
||||
This task runs periodically to gather metrics about:
|
||||
- Queue lengths for different Celery queues
|
||||
@@ -864,7 +864,7 @@ def cloud_monitor_celery_queues(
|
||||
|
||||
|
||||
@shared_task(name=OnyxCeleryTask.MONITOR_CELERY_QUEUES, ignore_result=True, bind=True)
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str | None) -> None:
|
||||
def monitor_celery_queues(self: Task, *, tenant_id: str) -> None:
|
||||
return monitor_celery_queues_helper(self)
|
||||
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.db.engine import get_session_with_current_tenant
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str | None) -> int:
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int:
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
|
||||
@@ -114,7 +114,7 @@ def _is_pruning_due(cc_pair: ConnectorCredentialPair) -> bool:
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_pruning(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_pruning(self: Task, *, tenant_id: str) -> bool | None:
|
||||
r = get_redis_client()
|
||||
r_replica = get_redis_replica_client()
|
||||
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
|
||||
@@ -211,7 +211,7 @@ def try_creating_prune_generator_task(
|
||||
cc_pair: ConnectorCredentialPair,
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> str | None:
|
||||
"""Checks for any conditions that should block the pruning generator task from being
|
||||
created, then creates the task.
|
||||
@@ -333,7 +333,7 @@ def connector_pruning_generator_task(
|
||||
cc_pair_id: int,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
"""connector pruning task. For a cc pair, this task pulls all document IDs from the source
|
||||
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
|
||||
@@ -521,7 +521,7 @@ def connector_pruning_generator_task(
|
||||
|
||||
|
||||
def monitor_ccpair_pruning_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
|
||||
@@ -567,7 +567,7 @@ def monitor_ccpair_pruning_taskset(
|
||||
|
||||
|
||||
def validate_pruning_fences(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
r: Redis,
|
||||
r_replica: Redis,
|
||||
r_celery: Redis,
|
||||
@@ -615,7 +615,7 @@ def validate_pruning_fences(
|
||||
|
||||
|
||||
def validate_pruning_fence(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
key_bytes: bytes,
|
||||
reserved_tasks: set[str],
|
||||
queued_tasks: set[str],
|
||||
|
||||
@@ -32,7 +32,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
return self.index.delete_single(
|
||||
@@ -50,7 +50,7 @@ class RetryDocumentIndex:
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
|
||||
@@ -76,7 +76,7 @@ def document_by_cc_pair_cleanup_task(
|
||||
document_id: str,
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""A lightweight subtask used to clean up document to cc pair relationships.
|
||||
Created by connection deletion and connector pruning parent tasks."""
|
||||
@@ -297,7 +297,8 @@ def cloud_beat_task_generator(
|
||||
return None
|
||||
|
||||
last_lock_time = time.monotonic()
|
||||
tenant_ids: list[str] | list[None] = []
|
||||
tenant_ids: list[str] = []
|
||||
num_processed_tenants = 0
|
||||
|
||||
try:
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
@@ -325,6 +326,8 @@ def cloud_beat_task_generator(
|
||||
expires=expires,
|
||||
ignore_result=True,
|
||||
)
|
||||
|
||||
num_processed_tenants += 1
|
||||
except SoftTimeLimitExceeded:
|
||||
task_logger.info(
|
||||
"Soft time limit exceeded, task is being terminated gracefully."
|
||||
@@ -344,6 +347,7 @@ def cloud_beat_task_generator(
|
||||
task_logger.info(
|
||||
f"cloud_beat_task_generator finished: "
|
||||
f"task={task_name} "
|
||||
f"num_processed_tenants={num_processed_tenants} "
|
||||
f"num_tenants={len(tenant_ids)} "
|
||||
f"elapsed={time_elapsed:.2f}"
|
||||
)
|
||||
|
||||
@@ -76,7 +76,7 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str | None) -> bool | None:
|
||||
def check_for_vespa_sync_task(self: Task, *, tenant_id: str) -> bool | None:
|
||||
"""Runs periodically to check if any document needs syncing.
|
||||
Generates sets of tasks for Celery if syncing is needed."""
|
||||
|
||||
@@ -208,7 +208,7 @@ def try_generate_stale_document_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
# the fence is up, do nothing
|
||||
|
||||
@@ -284,7 +284,7 @@ def try_generate_document_set_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -361,7 +361,7 @@ def try_generate_user_group_sync_tasks(
|
||||
db_session: Session,
|
||||
r: Redis,
|
||||
lock_beat: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> int | None:
|
||||
lock_beat.reacquire()
|
||||
|
||||
@@ -448,7 +448,7 @@ def monitor_connector_taskset(r: Redis) -> None:
|
||||
|
||||
|
||||
def monitor_document_set_taskset(
|
||||
tenant_id: str | None, key_bytes: bytes, r: Redis, db_session: Session
|
||||
tenant_id: str, key_bytes: bytes, r: Redis, db_session: Session
|
||||
) -> None:
|
||||
fence_key = key_bytes.decode("utf-8")
|
||||
document_set_id_str = RedisDocumentSet.get_id_from_fence_key(fence_key)
|
||||
@@ -523,9 +523,7 @@ def monitor_document_set_taskset(
|
||||
time_limit=LIGHT_TIME_LIMIT,
|
||||
max_retries=3,
|
||||
)
|
||||
def vespa_metadata_sync_task(
|
||||
self: Task, document_id: str, *, tenant_id: str | None
|
||||
) -> bool:
|
||||
def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) -> bool:
|
||||
start = time.monotonic()
|
||||
|
||||
completion_status = OnyxCeleryTaskCompletionStatus.UNDEFINED
|
||||
|
||||
@@ -55,6 +55,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -67,7 +68,6 @@ def _get_connector_runner(
|
||||
batch_size: int,
|
||||
start_time: datetime,
|
||||
end_time: datetime,
|
||||
tenant_id: str | None,
|
||||
leave_connector_active: bool = LEAVE_CONNECTOR_ACTIVE_ON_INITIALIZATION_FAILURE,
|
||||
) -> ConnectorRunner:
|
||||
"""
|
||||
@@ -86,7 +86,6 @@ def _get_connector_runner(
|
||||
input_type=task,
|
||||
connector_specific_config=attempt.connector_credential_pair.connector.connector_specific_config,
|
||||
credential=attempt.connector_credential_pair.credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# validate the connector settings
|
||||
@@ -241,7 +240,7 @@ def _check_failure_threshold(
|
||||
def _run_indexing(
|
||||
db_session: Session,
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -388,7 +387,6 @@ def _run_indexing(
|
||||
batch_size=INDEX_BATCH_SIZE,
|
||||
start_time=window_start,
|
||||
end_time=window_end,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# don't use a checkpoint if we're explicitly indexing from
|
||||
@@ -681,7 +679,7 @@ def _run_indexing(
|
||||
|
||||
def run_indexing_entrypoint(
|
||||
index_attempt_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
connector_credential_pair_id: int,
|
||||
is_ee: bool = False,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
@@ -701,7 +699,7 @@ def run_indexing_entrypoint(
|
||||
attempt = transition_attempt_to_in_progress(index_attempt_id, db_session)
|
||||
|
||||
tenant_str = ""
|
||||
if tenant_id is not None:
|
||||
if MULTI_TENANT:
|
||||
tenant_str = f" for tenant {tenant_id}"
|
||||
|
||||
connector_name = attempt.connector_credential_pair.connector.name
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
|
||||
#####
|
||||
@@ -29,6 +30,9 @@ GENERATIVE_MODEL_ACCESS_CHECK_FREQ = int(
|
||||
) # 1 day
|
||||
DISABLE_GENERATIVE_AI = os.environ.get("DISABLE_GENERATIVE_AI", "").lower() == "true"
|
||||
|
||||
ONYX_QUERY_HISTORY_TYPE = QueryHistoryType(
|
||||
(os.environ.get("ONYX_QUERY_HISTORY_TYPE") or QueryHistoryType.NORMAL.value).lower()
|
||||
)
|
||||
|
||||
#####
|
||||
# Web Configs
|
||||
|
||||
@@ -213,6 +213,12 @@ class AuthType(str, Enum):
|
||||
CLOUD = "cloud"
|
||||
|
||||
|
||||
class QueryHistoryType(str, Enum):
|
||||
DISABLED = "disabled"
|
||||
ANONYMIZED = "anonymized"
|
||||
NORMAL = "normal"
|
||||
|
||||
|
||||
# Special characters for password validation
|
||||
PASSWORD_SPECIAL_CHARS = "!@#$%^&*()_+-=[]{}|;:,.<>?"
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import INTEGRATION_TESTS_MODE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import DocumentSourceRequiringTenantContext
|
||||
from onyx.connectors.airtable.airtable_connector import AirtableConnector
|
||||
from onyx.connectors.asana.connector import AsanaConnector
|
||||
from onyx.connectors.axero.connector import AxeroConnector
|
||||
@@ -164,13 +163,9 @@ def instantiate_connector(
|
||||
input_type: InputType,
|
||||
connector_specific_config: dict[str, Any],
|
||||
credential: Credential,
|
||||
tenant_id: str | None = None,
|
||||
) -> BaseConnector:
|
||||
connector_class = identify_connector_class(source, input_type)
|
||||
|
||||
if source in DocumentSourceRequiringTenantContext:
|
||||
connector_specific_config["tenant_id"] = tenant_id
|
||||
|
||||
connector = connector_class(**connector_specific_config)
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
|
||||
@@ -184,7 +179,6 @@ def validate_ccpair_for_user(
|
||||
connector_id: int,
|
||||
credential_id: int,
|
||||
db_session: Session,
|
||||
tenant_id: str | None,
|
||||
enforce_creation: bool = True,
|
||||
) -> bool:
|
||||
if INTEGRATION_TESTS_MODE:
|
||||
@@ -216,7 +210,6 @@ def validate_ccpair_for_user(
|
||||
input_type=connector.input_type,
|
||||
connector_specific_config=connector.connector_specific_config,
|
||||
credential=credential,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
except ConnectorValidationError as e:
|
||||
raise e
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -27,8 +27,6 @@ from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -165,12 +163,10 @@ class LocalFileConnector(LoadConnector):
|
||||
def __init__(
|
||||
self,
|
||||
file_locations: list[Path | str],
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
) -> None:
|
||||
self.file_locations = [Path(file_location) for file_location in file_locations]
|
||||
self.batch_size = batch_size
|
||||
self.tenant_id = tenant_id
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
@@ -179,9 +175,8 @@ class LocalFileConnector(LoadConnector):
|
||||
|
||||
def load_from_state(self) -> GenerateDocumentsOutput:
|
||||
documents: list[Document] = []
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(self.tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=self.tenant_id) as db_session:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for file_path in self.file_locations:
|
||||
current_datetime = datetime.now(timezone.utc)
|
||||
files = _read_files_and_metadata(
|
||||
@@ -203,8 +198,6 @@ class LocalFileConnector(LoadConnector):
|
||||
if documents:
|
||||
yield documents
|
||||
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
@@ -43,12 +45,15 @@ def _extract_sections_basic(
|
||||
) -> list[Section]:
|
||||
mime_type = file["mimeType"]
|
||||
link = file["webViewLink"]
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
if mime_type not in set(item.value for item in GDriveMimeType):
|
||||
if mime_type not in supported_file_types:
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
@@ -109,7 +114,53 @@ def _extract_sections_basic(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
@@ -128,6 +179,8 @@ def _extract_sections_basic(
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
@@ -141,6 +194,8 @@ def _extract_sections_basic(
|
||||
.decode("utf-8"),
|
||||
)
|
||||
]
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
if mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
@@ -170,7 +225,11 @@ def _extract_sections_basic(
|
||||
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
|
||||
]
|
||||
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except Exception:
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
@@ -5,6 +5,10 @@ from typing import Any
|
||||
class GDriveMimeType(str, Enum):
|
||||
DOC = "application/vnd.google-apps.document"
|
||||
SPREADSHEET = "application/vnd.google-apps.spreadsheet"
|
||||
SPREADSHEET_OPEN_FORMAT = (
|
||||
"application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
)
|
||||
SPREADSHEET_MS_EXCEL = "application/vnd.ms-excel"
|
||||
PDF = "application/pdf"
|
||||
WORD_DOC = "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
PPT = "application/vnd.google-apps.presentation"
|
||||
|
||||
@@ -16,7 +16,6 @@ from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
|
||||
@@ -73,7 +72,7 @@ def insert_api_key(
|
||||
# Get tenant_id from context var (will be default schema for single tenant)
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
api_key = generate_api_key(tenant_id if MULTI_TENANT else None)
|
||||
api_key = generate_api_key(tenant_id)
|
||||
api_key_user_id = uuid.uuid4()
|
||||
|
||||
display_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
|
||||
@@ -258,11 +258,11 @@ class SqlEngine:
|
||||
cls._engine = None
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str] | list[None]:
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
if not MULTI_TENANT:
|
||||
return [None]
|
||||
return [POSTGRES_DEFAULT_SCHEMA]
|
||||
|
||||
with get_session_with_shared_schema() as session:
|
||||
result = session.execute(
|
||||
@@ -417,7 +417,7 @@ def get_session_with_shared_schema() -> Generator[Session, None, None]:
|
||||
|
||||
|
||||
@contextmanager
|
||||
def get_session_with_tenant(*, tenant_id: str | None) -> Generator[Session, None, None]:
|
||||
def get_session_with_tenant(*, tenant_id: str) -> Generator[Session, None, None]:
|
||||
"""
|
||||
Generate a database session for a specific tenant.
|
||||
"""
|
||||
|
||||
@@ -100,9 +100,14 @@ def _add_user_filters(
|
||||
.correlate(Persona)
|
||||
)
|
||||
else:
|
||||
where_clause |= Persona.is_public == True # noqa: E712
|
||||
where_clause &= Persona.is_visible == True # noqa: E712
|
||||
# Group the public persona conditions
|
||||
public_condition = (Persona.is_public == True) & ( # noqa: E712
|
||||
Persona.is_visible == True # noqa: E712
|
||||
)
|
||||
|
||||
where_clause |= public_condition
|
||||
where_clause |= Persona__User.user_id == user.id
|
||||
|
||||
where_clause |= Persona.user_id == user.id
|
||||
|
||||
return stmt.where(where_clause)
|
||||
|
||||
@@ -81,7 +81,7 @@ def translate_boost_count_to_multiplier(boost: int) -> float:
|
||||
# Vespa's Document API.
|
||||
def get_document_chunk_ids(
|
||||
enriched_document_info_list: list[EnrichedDocumentIndexingInfo],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunks_enabled: bool,
|
||||
) -> list[UUID]:
|
||||
doc_chunk_ids = []
|
||||
@@ -139,7 +139,7 @@ def get_uuid_from_chunk_info(
|
||||
*,
|
||||
document_id: str,
|
||||
chunk_id: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
large_chunk_id: int | None = None,
|
||||
) -> UUID:
|
||||
"""NOTE: be VERY carefuly about changing this function. If changed without a migration,
|
||||
@@ -154,7 +154,7 @@ def get_uuid_from_chunk_info(
|
||||
"large_" + str(large_chunk_id) if large_chunk_id is not None else str(chunk_id)
|
||||
)
|
||||
unique_identifier_string = "_".join([doc_str, chunk_index])
|
||||
if tenant_id and MULTI_TENANT:
|
||||
if MULTI_TENANT:
|
||||
unique_identifier_string += "_" + tenant_id
|
||||
|
||||
uuid_value = uuid.uuid5(uuid.NAMESPACE_X500, unique_identifier_string)
|
||||
|
||||
@@ -43,7 +43,7 @@ class IndexBatchParams:
|
||||
|
||||
doc_id_to_previous_chunk_cnt: dict[str, int | None]
|
||||
doc_id_to_new_chunk_cnt: dict[str, int]
|
||||
tenant_id: str | None
|
||||
tenant_id: str
|
||||
large_chunks_enabled: bool
|
||||
|
||||
|
||||
@@ -222,7 +222,7 @@ class Deletable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
"""
|
||||
@@ -249,7 +249,7 @@ class Updatable(abc.ABC):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
@@ -270,9 +270,7 @@ class Updatable(abc.ABC):
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
"""
|
||||
Updates some set of chunks. The document and fields to update are specified in the update
|
||||
requests. Each update request in the list applies its changes to a list of document ids.
|
||||
|
||||
@@ -468,9 +468,7 @@ class VespaIndex(DocumentIndex):
|
||||
failure_msg = f"Failed to update document: {future_to_document_id[future]}"
|
||||
raise requests.HTTPError(failure_msg) from e
|
||||
|
||||
def update(
|
||||
self, update_requests: list[UpdateRequest], *, tenant_id: str | None
|
||||
) -> None:
|
||||
def update(self, update_requests: list[UpdateRequest], *, tenant_id: str) -> None:
|
||||
logger.debug(f"Updating {len(update_requests)} documents in Vespa")
|
||||
|
||||
# Handle Vespa character limitations
|
||||
@@ -618,7 +616,7 @@ class VespaIndex(DocumentIndex):
|
||||
doc_id: str,
|
||||
*,
|
||||
chunk_count: int | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
fields: VespaDocumentFields,
|
||||
) -> int:
|
||||
"""Note: if the document id does not exist, the update will be a no-op and the
|
||||
@@ -661,7 +659,7 @@ class VespaIndex(DocumentIndex):
|
||||
self,
|
||||
doc_id: str,
|
||||
*,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
chunk_count: int | None,
|
||||
) -> int:
|
||||
total_chunks_deleted = 0
|
||||
|
||||
@@ -158,8 +158,8 @@ def index_doc_batch_with_handler(
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
) -> IndexingPipelineResult:
|
||||
try:
|
||||
index_pipeline_result = index_doc_batch(
|
||||
@@ -317,8 +317,8 @@ def index_doc_batch(
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
filter_fnc: Callable[[list[Document]], list[Document]] = filter_documents,
|
||||
) -> IndexingPipelineResult:
|
||||
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
|
||||
@@ -525,9 +525,9 @@ def build_indexing_pipeline(
|
||||
embedder: IndexingEmbedder,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
chunker: Chunker | None = None,
|
||||
ignore_time_skip: bool = False,
|
||||
tenant_id: str | None = None,
|
||||
callback: IndexingHeartbeatInterface | None = None,
|
||||
) -> IndexingPipelineProtocol:
|
||||
"""Builds a pipeline which takes in a list (batch) of docs and indexes them."""
|
||||
|
||||
@@ -84,7 +84,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str | None = None
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
@@ -96,7 +96,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
return cls(
|
||||
|
||||
@@ -219,7 +219,7 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
# If we are multi-tenant, we need to only set up initial public tables
|
||||
with Session(engine) as db_session:
|
||||
setup_onyx(db_session, None)
|
||||
setup_onyx(db_session, POSTGRES_DEFAULT_SCHEMA)
|
||||
else:
|
||||
setup_multitenant_onyx()
|
||||
|
||||
|
||||
@@ -410,7 +410,7 @@ def _build_qa_response_blocks(
|
||||
|
||||
|
||||
def _build_continue_in_web_ui_block(
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
message_id: int | None,
|
||||
) -> Block:
|
||||
if message_id is None:
|
||||
@@ -482,7 +482,7 @@ def build_follow_up_resolved_blocks(
|
||||
|
||||
def build_slack_response_blocks(
|
||||
answer: ChatOnyxBotResponse,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
message_info: SlackMessageInfo,
|
||||
channel_conf: ChannelConfig | None,
|
||||
use_citations: bool,
|
||||
|
||||
@@ -151,7 +151,7 @@ def handle_slack_feedback(
|
||||
user_id_to_post_confirmation: str,
|
||||
channel_id_to_post_confirmation: str,
|
||||
thread_ts_to_post_confirmation: str,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> None:
|
||||
message_id, doc_id, doc_rank = decompose_action_id(feedback_id)
|
||||
|
||||
|
||||
@@ -109,7 +109,7 @@ def handle_message(
|
||||
slack_channel_config: SlackChannelConfig,
|
||||
client: WebClient,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> bool:
|
||||
"""Potentially respond to the user message depending on filters and if an answer was generated
|
||||
|
||||
|
||||
@@ -72,7 +72,7 @@ def handle_regular_answer(
|
||||
channel: str,
|
||||
logger: OnyxLoggingAdapter,
|
||||
feedback_reminder_id: str | None,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
num_retries: int = DANSWER_BOT_NUM_RETRIES,
|
||||
thread_context_percent: float = MAX_THREAD_CONTEXT_PERCENTAGE,
|
||||
should_respond_with_error_msgs: bool = DANSWER_BOT_DISPLAY_ERROR_MSGS,
|
||||
|
||||
@@ -123,13 +123,13 @@ _OFFICIAL_SLACKBOT_USER_ID = "USLACKBOT"
|
||||
class SlackbotHandler:
|
||||
def __init__(self) -> None:
|
||||
logger.info("Initializing SlackbotHandler")
|
||||
self.tenant_ids: Set[str | None] = set()
|
||||
self.tenant_ids: Set[str] = set()
|
||||
# The keys for these dictionaries are tuples of (tenant_id, slack_bot_id)
|
||||
self.socket_clients: Dict[tuple[str | None, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str | None, int], SlackBotTokens] = {}
|
||||
self.socket_clients: Dict[tuple[str, int], TenantSocketModeClient] = {}
|
||||
self.slack_bot_tokens: Dict[tuple[str, int], SlackBotTokens] = {}
|
||||
|
||||
# Store Redis lock objects here so we can release them properly
|
||||
self.redis_locks: Dict[str | None, Lock] = {}
|
||||
self.redis_locks: Dict[str, Lock] = {}
|
||||
|
||||
self.running = True
|
||||
self.pod_id = self.get_pod_id()
|
||||
@@ -193,7 +193,7 @@ class SlackbotHandler:
|
||||
self._shutdown_event.wait(timeout=TENANT_HEARTBEAT_INTERVAL)
|
||||
|
||||
def _manage_clients_per_tenant(
|
||||
self, db_session: Session, tenant_id: str | None, bot: SlackBot
|
||||
self, db_session: Session, tenant_id: str, bot: SlackBot
|
||||
) -> None:
|
||||
"""
|
||||
- If the tokens are missing or empty, close the socket client and remove them.
|
||||
@@ -385,7 +385,7 @@ class SlackbotHandler:
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def _remove_tenant(self, tenant_id: str | None) -> None:
|
||||
def _remove_tenant(self, tenant_id: str) -> None:
|
||||
"""
|
||||
Helper to remove a tenant from `self.tenant_ids` and close any socket clients.
|
||||
(Lock release now happens in `acquire_tenants()`, not here.)
|
||||
@@ -415,7 +415,7 @@ class SlackbotHandler:
|
||||
)
|
||||
|
||||
def start_socket_client(
|
||||
self, slack_bot_id: int, tenant_id: str | None, slack_bot_tokens: SlackBotTokens
|
||||
self, slack_bot_id: int, tenant_id: str, slack_bot_tokens: SlackBotTokens
|
||||
) -> None:
|
||||
socket_client: TenantSocketModeClient = _get_socket_client(
|
||||
slack_bot_tokens, tenant_id, slack_bot_id
|
||||
@@ -912,7 +912,7 @@ def create_process_slack_event() -> (
|
||||
|
||||
|
||||
def _get_socket_client(
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str | None, slack_bot_id: int
|
||||
slack_bot_tokens: SlackBotTokens, tenant_id: str, slack_bot_id: int
|
||||
) -> TenantSocketModeClient:
|
||||
# For more info on how to set this up, checkout the docs:
|
||||
# https://docs.onyx.app/slack_bot_setup
|
||||
|
||||
@@ -570,7 +570,7 @@ def read_slack_thread(
|
||||
|
||||
|
||||
def slack_usage_report(
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str | None
|
||||
action: str, sender_id: str | None, client: WebClient, tenant_id: str
|
||||
) -> None:
|
||||
if DISABLE_TELEMETRY:
|
||||
return
|
||||
@@ -663,9 +663,7 @@ def get_feedback_visibility() -> FeedbackVisibility:
|
||||
|
||||
|
||||
class TenantSocketModeClient(SocketModeClient):
|
||||
def __init__(
|
||||
self, tenant_id: str | None, slack_bot_id: int, *args: Any, **kwargs: Any
|
||||
):
|
||||
def __init__(self, tenant_id: str, slack_bot_id: int, *args: Any, **kwargs: Any):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.tenant_id = tenant_id
|
||||
self.slack_bot_id = slack_bot_id
|
||||
|
||||
@@ -16,10 +16,10 @@ class RedisConnector:
|
||||
"""Composes several classes to simplify interacting with a connector and its
|
||||
associated background tasks / associated redis interactions."""
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
"""id: a connector credential pair id"""
|
||||
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis: redis.Redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
PREFIX = "connectorsync"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
# documents that should be skipped
|
||||
@@ -60,7 +60,7 @@ class RedisConnectorCredentialPair(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""We can limit the number of tasks generated here, which is useful to prevent
|
||||
one tenant from overwhelming the sync queue.
|
||||
|
||||
@@ -39,8 +39,8 @@ class RedisConnectorDelete:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPermissionSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -44,8 +44,8 @@ class RedisConnectorExternalGroupSync:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = 3600
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -52,12 +52,12 @@ class RedisConnectorIndex:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
id: int,
|
||||
search_settings_id: int,
|
||||
redis: redis.Redis,
|
||||
) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.search_settings_id = search_settings_id
|
||||
self.redis = redis
|
||||
|
||||
@@ -52,8 +52,8 @@ class RedisConnectorPrune:
|
||||
ACTIVE_PREFIX = PREFIX + "_active"
|
||||
ACTIVE_TTL = CELERY_PRUNING_LOCK_TIMEOUT * 2
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -13,8 +13,8 @@ class RedisConnectorStop:
|
||||
TIMEOUT_PREFIX = f"{PREFIX}_timeout"
|
||||
TIMEOUT_TTL = 300
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: int, redis: redis.Redis) -> None:
|
||||
self.tenant_id: str = tenant_id
|
||||
self.id: int = id
|
||||
self.redis = redis
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -58,7 +58,7 @@ class RedisDocumentSet(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
document set up to date over multiple batches.
|
||||
|
||||
@@ -14,8 +14,8 @@ class RedisObjectHelper(ABC):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: str):
|
||||
self._tenant_id: str | None = tenant_id
|
||||
def __init__(self, tenant_id: str, id: str):
|
||||
self._tenant_id: str = tenant_id
|
||||
self._id: str = id
|
||||
self.redis = get_redis_client(tenant_id=tenant_id)
|
||||
|
||||
@@ -87,7 +87,7 @@ class RedisObjectHelper(ABC):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""First element should be the number of actual tasks generated, second should
|
||||
be the number of docs that were candidates to be synced for the cc pair.
|
||||
|
||||
@@ -24,7 +24,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
FENCE_PREFIX = PREFIX + "_fence"
|
||||
TASKSET_PREFIX = PREFIX + "_taskset"
|
||||
|
||||
def __init__(self, tenant_id: str | None, id: int) -> None:
|
||||
def __init__(self, tenant_id: str, id: int) -> None:
|
||||
super().__init__(tenant_id, str(id))
|
||||
|
||||
@property
|
||||
@@ -59,7 +59,7 @@ class RedisUserGroup(RedisObjectHelper):
|
||||
db_session: Session,
|
||||
redis_client: Redis,
|
||||
lock: RedisLock,
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[int, int] | None:
|
||||
"""Max tasks is ignored for now until we can build the logic to mark the
|
||||
user group up to date over multiple batches.
|
||||
|
||||
@@ -37,13 +37,15 @@ from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _create_indexable_chunks(
|
||||
preprocessed_docs: list[dict],
|
||||
tenant_id: str | None,
|
||||
tenant_id: str,
|
||||
) -> tuple[list[Document], list[DocMetadataAwareIndexChunk]]:
|
||||
ids_to_documents = {}
|
||||
chunks = []
|
||||
@@ -86,7 +88,7 @@ def _create_indexable_chunks(
|
||||
mini_chunk_embeddings=[],
|
||||
),
|
||||
title_embedding=preprocessed_doc["title_embedding"],
|
||||
tenant_id=tenant_id,
|
||||
tenant_id=tenant_id if MULTI_TENANT else POSTGRES_DEFAULT_SCHEMA,
|
||||
access=default_public_access,
|
||||
document_sets=set(),
|
||||
boost=DEFAULT_BOOST,
|
||||
@@ -111,7 +113,7 @@ def load_processed_docs(cohere_enabled: bool) -> list[dict]:
|
||||
|
||||
|
||||
def seed_initial_documents(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Seed initial documents so users don't have an empty index to start
|
||||
|
||||
@@ -620,7 +620,7 @@ def associate_credential_to_connector(
|
||||
)
|
||||
|
||||
try:
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session, tenant_id)
|
||||
validate_ccpair_for_user(connector_id, credential_id, db_session)
|
||||
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -902,7 +902,6 @@ def create_connector_with_mock_credential(
|
||||
connector_id=connector_id,
|
||||
credential_id=credential_id,
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
response = add_credential_to_connector(
|
||||
db_session=db_session,
|
||||
|
||||
@@ -18,7 +18,6 @@ from onyx.db.credentials import fetch_credentials_by_source_for_user
|
||||
from onyx.db.credentials import fetch_credentials_for_user
|
||||
from onyx.db.credentials import swap_credentials_connector
|
||||
from onyx.db.credentials import update_credential
|
||||
from onyx.db.engine import get_current_tenant_id
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.models import DocumentSource
|
||||
from onyx.db.models import User
|
||||
@@ -100,13 +99,11 @@ def swap_credentials_for_connector(
|
||||
credential_swap_req: CredentialSwapRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
tenant_id: str | None = Depends(get_current_tenant_id),
|
||||
) -> StatusResponse:
|
||||
validate_ccpair_for_user(
|
||||
credential_swap_req.connector_id,
|
||||
credential_swap_req.new_credential_id,
|
||||
db_session,
|
||||
tenant_id,
|
||||
)
|
||||
|
||||
connector_credential_pair = swap_credentials_connector(
|
||||
|
||||
@@ -4,7 +4,9 @@ from enum import Enum
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.db.models import Notification as NotificationDBModel
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
class PageType(str, Enum):
|
||||
@@ -49,9 +51,10 @@ class Settings(BaseModel):
|
||||
|
||||
temperature_override_enabled: bool | None = False
|
||||
auto_scroll: bool | None = False
|
||||
query_history_type: QueryHistoryType | None = None
|
||||
|
||||
|
||||
class UserSettings(Settings):
|
||||
notifications: list[Notification]
|
||||
needs_reindexing: bool
|
||||
tenant_id: str | None = None
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from onyx.configs.app_configs import ONYX_QUERY_HISTORY_TYPE
|
||||
from onyx.configs.constants import KV_SETTINGS_KEY
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
@@ -45,6 +46,7 @@ def load_settings() -> Settings:
|
||||
anonymous_user_enabled = False
|
||||
|
||||
settings.anonymous_user_enabled = anonymous_user_enabled
|
||||
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
|
||||
return settings
|
||||
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ logger = setup_logger()
|
||||
|
||||
|
||||
def setup_onyx(
|
||||
db_session: Session, tenant_id: str | None, cohere_enabled: bool = False
|
||||
db_session: Session, tenant_id: str, cohere_enabled: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Setup Onyx for a particular tenant. In the Single Tenant case, it will set it up for the default schema
|
||||
|
||||
@@ -260,7 +260,7 @@ def get_documents_for_tenant_connector(
|
||||
def search_for_document(
|
||||
index_name: str,
|
||||
document_id: str | None = None,
|
||||
tenant_id: str | None = None,
|
||||
tenant_id: str = POSTGRES_DEFAULT_SCHEMA,
|
||||
max_hits: int | None = 10,
|
||||
) -> List[Dict[str, Any]]:
|
||||
yql_query = f"select * from sources {index_name}"
|
||||
@@ -507,9 +507,9 @@ def get_number_of_chunks_we_think_exist(
|
||||
|
||||
class VespaDebugging:
|
||||
# Class for managing Vespa debugging actions.
|
||||
def __init__(self, tenant_id: str | None = None):
|
||||
def __init__(self, tenant_id: str = POSTGRES_DEFAULT_SCHEMA):
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
self.tenant_id = POSTGRES_DEFAULT_SCHEMA if not tenant_id else tenant_id
|
||||
self.tenant_id = tenant_id
|
||||
self.index_name = get_index_name(self.tenant_id)
|
||||
|
||||
def sample_document_counts(self) -> None:
|
||||
@@ -603,7 +603,7 @@ class VespaDebugging:
|
||||
delete_documents_for_tenant(self.index_name, self.tenant_id, count=count)
|
||||
|
||||
def search_for_document(
|
||||
self, document_id: str | None = None, tenant_id: str | None = None
|
||||
self, document_id: str | None = None, tenant_id: str = POSTGRES_DEFAULT_SCHEMA
|
||||
) -> List[Dict[str, Any]]:
|
||||
return search_for_document(self.index_name, document_id, tenant_id)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.db.document import delete_documents_complete__no_commit
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.search_settings import get_active_search_settings
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# Modify sys.path
|
||||
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||
@@ -74,7 +75,7 @@ def _unsafe_deletion(
|
||||
for document in documents:
|
||||
document_index.delete_single(
|
||||
doc_id=document.id,
|
||||
tenant_id=None,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
chunk_count=document.chunk_count,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.document_index.document_index_utils import get_multipass_config
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
# makes it so `PYTHONPATH=.` is not required when running this script
|
||||
parent_dir = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
|
||||
@@ -96,7 +97,9 @@ def main() -> None:
|
||||
try:
|
||||
print(f"Deleting document {doc_id} in Vespa")
|
||||
chunks_deleted = vespa_index.delete_single(
|
||||
doc_id, tenant_id=None, chunk_count=document.chunk_count
|
||||
doc_id,
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
chunk_count=document.chunk_count,
|
||||
)
|
||||
if chunks_deleted > 0:
|
||||
print(
|
||||
|
||||
@@ -18,5 +18,7 @@ CURRENT_TENANT_ID_CONTEXTVAR: contextvars.ContextVar[
|
||||
def get_current_tenant_id() -> str:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if tenant_id is None:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
raise RuntimeError("Tenant ID is not set. This should never happen.")
|
||||
return tenant_id
|
||||
|
||||
@@ -87,7 +87,7 @@ def test_confluence_connector_basic(
|
||||
assert len(txt_doc.sections) == 1
|
||||
assert txt_doc.sections[0].text == "small"
|
||||
assert txt_doc.primary_owners
|
||||
assert txt_doc.primary_owners[0].email == "chris@danswer.ai"
|
||||
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
|
||||
assert (
|
||||
txt_doc.sections[0].link
|
||||
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
|
||||
|
||||
89
web/package-lock.json
generated
89
web/package-lock.json
generated
@@ -70,6 +70,8 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
@@ -11741,6 +11743,54 @@
|
||||
"resolved": "https://registry.npmjs.org/@types/unist/-/unist-2.0.10.tgz",
|
||||
"integrity": "sha512-IfYcSBWE3hLpBg8+X2SEa8LVkJdJEkT2Ese2aaLs3ptGdVtABxndrMaxuFlQ1qdFf9Q5rDvDpxI3WwgvKFAsQA=="
|
||||
},
|
||||
"node_modules/hast-util-sanitize": {
|
||||
"version": "5.0.2",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-sanitize/-/hast-util-sanitize-5.0.2.tgz",
|
||||
"integrity": "sha512-3yTWghByc50aGS7JlGhk61SPenfE/p1oaFeNwkOOyrscaOkMGrcW9+Cy/QAIOBpZxP1yqDIzFMR0+Np0i0+usg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@ungap/structured-clone": "^1.0.0",
|
||||
"unist-util-position": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html": {
|
||||
"version": "9.0.5",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-html/-/hast-util-to-html-9.0.5.tgz",
|
||||
"integrity": "sha512-OguPdidb+fbHQSU4Q4ZiLKnzWo8Wwsf5bZfbvu7//a9oTYoqD/fWpe96NuHkoS9h0ccGOTe0C4NGXdtS0iObOw==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"@types/unist": "^3.0.0",
|
||||
"ccount": "^2.0.0",
|
||||
"comma-separated-tokens": "^2.0.0",
|
||||
"hast-util-whitespace": "^3.0.0",
|
||||
"html-void-elements": "^3.0.0",
|
||||
"mdast-util-to-hast": "^13.0.0",
|
||||
"property-information": "^7.0.0",
|
||||
"space-separated-tokens": "^2.0.0",
|
||||
"stringify-entities": "^4.0.0",
|
||||
"zwitch": "^2.0.4"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-html/node_modules/property-information": {
|
||||
"version": "7.0.0",
|
||||
"resolved": "https://registry.npmjs.org/property-information/-/property-information-7.0.0.tgz",
|
||||
"integrity": "sha512-7D/qOz/+Y4X/rzSB6jKxKUsQnphO046ei8qxG59mtM3RG3DHgTK81HrxrmoDVINJb8NKT5ZsRbwHvQ6B68Iyhg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/hast-util-to-jsx-runtime": {
|
||||
"version": "2.3.0",
|
||||
"resolved": "https://registry.npmjs.org/hast-util-to-jsx-runtime/-/hast-util-to-jsx-runtime-2.3.0.tgz",
|
||||
@@ -11919,6 +11969,16 @@
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/html-void-elements": {
|
||||
"version": "3.0.0",
|
||||
"resolved": "https://registry.npmjs.org/html-void-elements/-/html-void-elements-3.0.0.tgz",
|
||||
"integrity": "sha512-bEqo66MRXsUGxWHV5IP0PUiAWwoEjba4VCzg0LjFJBpchPaTfyfCKTG6bc5F8ucKec3q5y6qOdGyYTSBEvhCrg==",
|
||||
"license": "MIT",
|
||||
"funding": {
|
||||
"type": "github",
|
||||
"url": "https://github.com/sponsors/wooorm"
|
||||
}
|
||||
},
|
||||
"node_modules/html-webpack-plugin": {
|
||||
"version": "5.6.3",
|
||||
"resolved": "https://registry.npmjs.org/html-webpack-plugin/-/html-webpack-plugin-5.6.3.tgz",
|
||||
@@ -19125,6 +19185,35 @@
|
||||
"unist-util-visit": "^5.0.0"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-sanitize": {
|
||||
"version": "6.0.0",
|
||||
"resolved": "https://registry.npmjs.org/rehype-sanitize/-/rehype-sanitize-6.0.0.tgz",
|
||||
"integrity": "sha512-CsnhKNsyI8Tub6L4sm5ZFsme4puGfc6pYylvXo1AeqaGbjOYyzNv3qZPwvs0oMJ39eryyeOdmxwUIo94IpEhqg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-sanitize": "^5.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/rehype-stringify": {
|
||||
"version": "10.0.1",
|
||||
"resolved": "https://registry.npmjs.org/rehype-stringify/-/rehype-stringify-10.0.1.tgz",
|
||||
"integrity": "sha512-k9ecfXHmIPuFVI61B9DeLPN0qFHfawM6RsuX48hoqlaKSF61RskNjSm1lI8PhBEM0MRdLxVVm4WmTqJQccH9mA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@types/hast": "^3.0.0",
|
||||
"hast-util-to-html": "^9.0.0",
|
||||
"unified": "^11.0.0"
|
||||
},
|
||||
"funding": {
|
||||
"type": "opencollective",
|
||||
"url": "https://opencollective.com/unified"
|
||||
}
|
||||
},
|
||||
"node_modules/relateurl": {
|
||||
"version": "0.2.7",
|
||||
"resolved": "https://registry.npmjs.org/relateurl/-/relateurl-0.2.7.tgz",
|
||||
|
||||
@@ -73,6 +73,8 @@
|
||||
"recharts": "^2.13.1",
|
||||
"rehype-katex": "^7.0.1",
|
||||
"rehype-prism-plus": "^2.0.0",
|
||||
"rehype-sanitize": "^6.0.0",
|
||||
"rehype-stringify": "^10.0.1",
|
||||
"remark-gfm": "^4.0.0",
|
||||
"remark-math": "^6.0.0",
|
||||
"semver": "^7.5.4",
|
||||
|
||||
@@ -4,6 +4,12 @@ export enum ApplicationStatus {
|
||||
ACTIVE = "active",
|
||||
}
|
||||
|
||||
export enum QueryHistoryType {
|
||||
DISABLED = "disabled",
|
||||
ANONYMIZED = "anonymized",
|
||||
NORMAL = "normal",
|
||||
}
|
||||
|
||||
export interface Settings {
|
||||
anonymous_user_enabled: boolean;
|
||||
maximum_chat_retention_days: number | null;
|
||||
@@ -14,6 +20,7 @@ export interface Settings {
|
||||
application_status: ApplicationStatus;
|
||||
auto_scroll: boolean;
|
||||
temperature_override_enabled: boolean;
|
||||
query_history_type: QueryHistoryType;
|
||||
}
|
||||
|
||||
export enum NotificationType {
|
||||
|
||||
@@ -56,6 +56,7 @@ import {
|
||||
Dispatch,
|
||||
SetStateAction,
|
||||
use,
|
||||
useCallback,
|
||||
useContext,
|
||||
useEffect,
|
||||
useLayoutEffect,
|
||||
@@ -893,24 +894,6 @@ export function ChatPage({
|
||||
);
|
||||
const scrollDist = useRef<number>(0);
|
||||
|
||||
const updateScrollTracking = () => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > 500);
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (scrollableDiv) {
|
||||
scrollableDiv.addEventListener("scroll", updateScrollTracking);
|
||||
return () => {
|
||||
scrollableDiv.removeEventListener("scroll", updateScrollTracking);
|
||||
};
|
||||
}
|
||||
}, []);
|
||||
|
||||
const handleInputResize = () => {
|
||||
setTimeout(() => {
|
||||
if (
|
||||
@@ -962,33 +945,12 @@ export function ChatPage({
|
||||
if (isVisible) return;
|
||||
|
||||
// Check if all messages are currently rendered
|
||||
if (currentVisibleRange.end < messageHistory.length) {
|
||||
// Update visible range to include the last messages
|
||||
updateCurrentVisibleRange({
|
||||
start: Math.max(
|
||||
0,
|
||||
messageHistory.length -
|
||||
(currentVisibleRange.end - currentVisibleRange.start)
|
||||
),
|
||||
end: messageHistory.length,
|
||||
mostVisibleMessageId: currentVisibleRange.mostVisibleMessageId,
|
||||
});
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
|
||||
// Wait for the state update and re-render before scrolling
|
||||
setTimeout(() => {
|
||||
endDivRef.current?.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
setHasPerformedInitialScroll(true);
|
||||
}, 100);
|
||||
} else {
|
||||
// If all messages are already rendered, scroll immediately
|
||||
endDivRef.current.scrollIntoView({
|
||||
behavior: fast ? "auto" : "smooth",
|
||||
});
|
||||
|
||||
setHasPerformedInitialScroll(true);
|
||||
}
|
||||
setHasPerformedInitialScroll(true);
|
||||
}, 50);
|
||||
|
||||
// Reset waitForScrollRef after 1.5 seconds
|
||||
@@ -1009,11 +971,6 @@ export function ChatPage({
|
||||
handleInputResize();
|
||||
}, [message]);
|
||||
|
||||
// tracks scrolling
|
||||
useEffect(() => {
|
||||
updateScrollTracking();
|
||||
}, [messageHistory]);
|
||||
|
||||
// used for resizing of the document sidebar
|
||||
const masterFlexboxRef = useRef<HTMLDivElement>(null);
|
||||
const [maxDocumentSidebarWidth, setMaxDocumentSidebarWidth] = useState<
|
||||
@@ -1977,122 +1934,6 @@ export function ChatPage({
|
||||
|
||||
// Virtualization + Scrolling related effects and functions
|
||||
const scrollInitialized = useRef(false);
|
||||
interface VisibleRange {
|
||||
start: number;
|
||||
end: number;
|
||||
mostVisibleMessageId: number | null;
|
||||
}
|
||||
|
||||
const [visibleRange, setVisibleRange] = useState<
|
||||
Map<string | null, VisibleRange>
|
||||
>(() => {
|
||||
const initialRange: VisibleRange = {
|
||||
start: 0,
|
||||
end: BUFFER_COUNT,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
return new Map([[chatSessionIdRef.current, initialRange]]);
|
||||
});
|
||||
|
||||
// Function used to update current visible range. Only method for updating `visibleRange` state.
|
||||
const updateCurrentVisibleRange = (
|
||||
newRange: VisibleRange,
|
||||
forceUpdate?: boolean
|
||||
) => {
|
||||
if (
|
||||
scrollInitialized.current &&
|
||||
visibleRange.get(loadedIdSessionRef.current) == undefined &&
|
||||
!forceUpdate
|
||||
) {
|
||||
return;
|
||||
}
|
||||
|
||||
setVisibleRange((prevState) => {
|
||||
const newState = new Map(prevState);
|
||||
newState.set(loadedIdSessionRef.current, newRange);
|
||||
return newState;
|
||||
});
|
||||
};
|
||||
|
||||
// Set first value for visibleRange state on page load / refresh.
|
||||
const initializeVisibleRange = () => {
|
||||
const upToDatemessageHistory = buildLatestMessageChain(
|
||||
currentMessageMap(completeMessageDetail)
|
||||
);
|
||||
|
||||
if (!scrollInitialized.current && upToDatemessageHistory.length > 0) {
|
||||
const newEnd = Math.max(upToDatemessageHistory.length, BUFFER_COUNT);
|
||||
const newStart = Math.max(0, newEnd - BUFFER_COUNT);
|
||||
const newMostVisibleMessageId =
|
||||
upToDatemessageHistory[newEnd - 1]?.messageId;
|
||||
|
||||
updateCurrentVisibleRange(
|
||||
{
|
||||
start: newStart,
|
||||
end: newEnd,
|
||||
mostVisibleMessageId: newMostVisibleMessageId,
|
||||
},
|
||||
true
|
||||
);
|
||||
scrollInitialized.current = true;
|
||||
}
|
||||
};
|
||||
|
||||
const updateVisibleRangeBasedOnScroll = () => {
|
||||
if (!scrollInitialized.current) return;
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
if (!scrollableDiv) return;
|
||||
|
||||
const viewportHeight = scrollableDiv.clientHeight;
|
||||
let mostVisibleMessageIndex = -1;
|
||||
|
||||
messageHistory.forEach((message, index) => {
|
||||
const messageElement = document.getElementById(
|
||||
`message-${message.messageId}`
|
||||
);
|
||||
if (messageElement) {
|
||||
const rect = messageElement.getBoundingClientRect();
|
||||
const isVisible = rect.bottom <= viewportHeight && rect.bottom > 0;
|
||||
if (isVisible && index > mostVisibleMessageIndex) {
|
||||
mostVisibleMessageIndex = index;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
if (mostVisibleMessageIndex !== -1) {
|
||||
const startIndex = Math.max(0, mostVisibleMessageIndex - BUFFER_COUNT);
|
||||
const endIndex = Math.min(
|
||||
messageHistory.length,
|
||||
mostVisibleMessageIndex + BUFFER_COUNT + 1
|
||||
);
|
||||
|
||||
updateCurrentVisibleRange({
|
||||
start: startIndex,
|
||||
end: endIndex,
|
||||
mostVisibleMessageId: messageHistory[mostVisibleMessageIndex].messageId,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
initializeVisibleRange();
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [router, messageHistory]);
|
||||
|
||||
useLayoutEffect(() => {
|
||||
const scrollableDiv = scrollableDivRef.current;
|
||||
|
||||
const handleScroll = () => {
|
||||
updateVisibleRangeBasedOnScroll();
|
||||
};
|
||||
|
||||
scrollableDiv?.addEventListener("scroll", handleScroll);
|
||||
|
||||
return () => {
|
||||
scrollableDiv?.removeEventListener("scroll", handleScroll);
|
||||
};
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps
|
||||
}, [messageHistory]);
|
||||
|
||||
const imageFileInMessageHistory = useMemo(() => {
|
||||
return messageHistory
|
||||
@@ -2102,11 +1943,6 @@ export function ChatPage({
|
||||
);
|
||||
}, [messageHistory]);
|
||||
|
||||
const currentVisibleRange = visibleRange.get(currentSessionId()) || {
|
||||
start: 0,
|
||||
end: 0,
|
||||
mostVisibleMessageId: null,
|
||||
};
|
||||
useSendMessageToParent();
|
||||
|
||||
useEffect(() => {
|
||||
@@ -2146,6 +1982,15 @@ export function ChatPage({
|
||||
|
||||
const currentPersona = alternativeAssistant || liveAssistant;
|
||||
|
||||
const HORIZON_DISTANCE = 800;
|
||||
const handleScroll = useCallback(() => {
|
||||
const scrollDistance =
|
||||
endDivRef?.current?.getBoundingClientRect()?.top! -
|
||||
inputRef?.current?.getBoundingClientRect()?.top!;
|
||||
scrollDist.current = scrollDistance;
|
||||
setAboveHorizon(scrollDist.current > HORIZON_DISTANCE);
|
||||
}, []);
|
||||
|
||||
useEffect(() => {
|
||||
const handleSlackChatRedirect = async () => {
|
||||
if (!slackChatId) return;
|
||||
@@ -2596,6 +2441,7 @@ export function ChatPage({
|
||||
{...getRootProps()}
|
||||
>
|
||||
<div
|
||||
onScroll={handleScroll}
|
||||
className={`w-full h-[calc(100vh-160px)] flex flex-col default-scrollbar overflow-y-auto overflow-x-hidden relative`}
|
||||
ref={scrollableDivRef}
|
||||
>
|
||||
@@ -2653,18 +2499,7 @@ export function ChatPage({
|
||||
// NOTE: temporarily removing this to fix the scroll bug
|
||||
// (hasPerformedInitialScroll ? "" : "invisible")
|
||||
>
|
||||
{(messageHistory.length < BUFFER_COUNT
|
||||
? messageHistory
|
||||
: messageHistory.slice(
|
||||
currentVisibleRange.start,
|
||||
currentVisibleRange.end
|
||||
)
|
||||
).map((message, fauxIndex) => {
|
||||
const i =
|
||||
messageHistory.length < BUFFER_COUNT
|
||||
? fauxIndex
|
||||
: fauxIndex + currentVisibleRange.start;
|
||||
|
||||
{messageHistory.map((message, i) => {
|
||||
const messageMap = currentMessageMap(
|
||||
completeMessageDetail
|
||||
);
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import React, { useState, useEffect } from "react";
|
||||
import React, { useState, useEffect, useCallback } from "react";
|
||||
import { InputPrompt } from "@/app/chat/interfaces";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { TrashIcon, PlusIcon } from "@/components/icons/icons";
|
||||
import { MoreVertical, CheckIcon, XIcon } from "lucide-react";
|
||||
import { PlusIcon } from "@/components/icons/icons";
|
||||
import { MoreVertical, XIcon } from "lucide-react";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import Title from "@/components/ui/title";
|
||||
import Text from "@/components/ui/text";
|
||||
@@ -153,114 +153,6 @@ export default function InputPrompts() {
|
||||
}
|
||||
};
|
||||
|
||||
const PromptCard = ({ prompt }: { prompt: InputPrompt }) => {
|
||||
const isEditing = editingPromptId === prompt.id;
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
// Sync local edits with any prompt changes from outside
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveLocal = () => {
|
||||
handleSave(prompt.id, localPrompt, localContent);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setEditingPromptId(null);
|
||||
fetchInputPrompts(); // Revert changes from server
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mx-auto max-w-4xl">
|
||||
<div className="absolute top-4 left-4">
|
||||
@@ -272,13 +164,21 @@ export default function InputPrompts() {
|
||||
<Title>Prompt Shortcuts</Title>
|
||||
<Text>
|
||||
Manage and customize prompt shortcuts for your assistants. Use your
|
||||
prompt shortcuts by starting a new message “/” in chat.
|
||||
prompt shortcuts by starting a new message with "/" in
|
||||
chat.
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{inputPrompts.map((prompt) => (
|
||||
<PromptCard key={prompt.id} prompt={prompt} />
|
||||
<PromptCard
|
||||
key={prompt.id}
|
||||
prompt={prompt}
|
||||
onEdit={handleEdit}
|
||||
onSave={handleSave}
|
||||
onDelete={handleDelete}
|
||||
isEditing={editingPromptId === prompt.id}
|
||||
/>
|
||||
))}
|
||||
|
||||
{isCreatingNew ? (
|
||||
@@ -315,3 +215,129 @@ export default function InputPrompts() {
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface PromptCardProps {
|
||||
prompt: InputPrompt;
|
||||
onEdit: (id: number) => void;
|
||||
onSave: (id: number, prompt: string, content: string) => void;
|
||||
onDelete: (id: number) => void;
|
||||
isEditing: boolean;
|
||||
}
|
||||
|
||||
const PromptCard: React.FC<PromptCardProps> = ({
|
||||
prompt,
|
||||
onEdit,
|
||||
onSave,
|
||||
onDelete,
|
||||
isEditing,
|
||||
}) => {
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = useCallback(
|
||||
(field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
},
|
||||
[]
|
||||
);
|
||||
|
||||
const handleSaveLocal = useCallback(() => {
|
||||
onSave(prompt.id, localPrompt, localContent);
|
||||
}, [prompt.id, localPrompt, localContent, onSave]);
|
||||
|
||||
const isPromptPublic = useCallback((p: InputPrompt): boolean => {
|
||||
return p.is_public;
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
onEdit(0);
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => onEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => onDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
147
web/src/app/chat/input-prompts/PromptCard.tsx
Normal file
147
web/src/app/chat/input-prompts/PromptCard.tsx
Normal file
@@ -0,0 +1,147 @@
|
||||
import { SourceChip } from "../input/ChatInputBar";
|
||||
|
||||
import { useEffect } from "react";
|
||||
import { useState } from "react";
|
||||
import { InputPrompt } from "../interfaces";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import { XIcon } from "@/components/icons/icons";
|
||||
import { Textarea } from "@/components/ui/textarea";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import {
|
||||
DropdownMenu,
|
||||
DropdownMenuContent,
|
||||
DropdownMenuItem,
|
||||
DropdownMenuTrigger,
|
||||
} from "@/components/ui/dropdown-menu";
|
||||
import { MoreVertical } from "lucide-react";
|
||||
|
||||
export const PromptCard = ({
|
||||
prompt,
|
||||
editingPromptId,
|
||||
setEditingPromptId,
|
||||
handleSave,
|
||||
handleDelete,
|
||||
isPromptPublic,
|
||||
handleEdit,
|
||||
fetchInputPrompts,
|
||||
}: {
|
||||
prompt: InputPrompt;
|
||||
editingPromptId: number | null;
|
||||
setEditingPromptId: (id: number | null) => void;
|
||||
handleSave: (id: number, prompt: string, content: string) => void;
|
||||
handleDelete: (id: number) => void;
|
||||
isPromptPublic: (prompt: InputPrompt) => boolean;
|
||||
handleEdit: (id: number) => void;
|
||||
fetchInputPrompts: () => void;
|
||||
}) => {
|
||||
const isEditing = editingPromptId === prompt.id;
|
||||
const [localPrompt, setLocalPrompt] = useState(prompt.prompt);
|
||||
const [localContent, setLocalContent] = useState(prompt.content);
|
||||
|
||||
// Sync local edits with any prompt changes from outside
|
||||
useEffect(() => {
|
||||
setLocalPrompt(prompt.prompt);
|
||||
setLocalContent(prompt.content);
|
||||
}, [prompt, isEditing]);
|
||||
|
||||
const handleLocalEdit = (field: "prompt" | "content", value: string) => {
|
||||
if (field === "prompt") {
|
||||
setLocalPrompt(value);
|
||||
} else {
|
||||
setLocalContent(value);
|
||||
}
|
||||
};
|
||||
|
||||
const handleSaveLocal = () => {
|
||||
handleSave(prompt.id, localPrompt, localContent);
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="border dark:border-none dark:bg-[#333333] rounded-lg p-4 mb-4 relative">
|
||||
{isEditing ? (
|
||||
<>
|
||||
<div className="absolute top-2 right-2">
|
||||
<Button
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
onClick={() => {
|
||||
setEditingPromptId(null);
|
||||
fetchInputPrompts(); // Revert changes from server
|
||||
}}
|
||||
>
|
||||
<XIcon size={14} />
|
||||
</Button>
|
||||
</div>
|
||||
<div className="flex">
|
||||
<div className="flex-grow mr-4">
|
||||
<Textarea
|
||||
value={localPrompt}
|
||||
onChange={(e) => handleLocalEdit("prompt", e.target.value)}
|
||||
className="mb-2 resize-none"
|
||||
placeholder="Prompt"
|
||||
/>
|
||||
<Textarea
|
||||
value={localContent}
|
||||
onChange={(e) => handleLocalEdit("content", e.target.value)}
|
||||
className="resize-vertical min-h-[100px]"
|
||||
placeholder="Content"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-end">
|
||||
<Button onClick={handleSaveLocal}>
|
||||
{prompt.id ? "Save" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="mb-2 flex gap-x-2 ">
|
||||
<p className="font-semibold">{prompt.prompt}</p>
|
||||
{isPromptPublic(prompt) && <SourceChip title="Built-in" />}
|
||||
</div>
|
||||
</TooltipTrigger>
|
||||
{isPromptPublic(prompt) && (
|
||||
<TooltipContent>
|
||||
<p>This is a built-in prompt and cannot be edited</p>
|
||||
</TooltipContent>
|
||||
)}
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<div className="whitespace-pre-wrap">{prompt.content}</div>
|
||||
<div className="absolute top-2 right-2">
|
||||
<DropdownMenu>
|
||||
<DropdownMenuTrigger className="hover:bg-transparent" asChild>
|
||||
<Button
|
||||
className="!hover:bg-transparent"
|
||||
variant="ghost"
|
||||
size="sm"
|
||||
>
|
||||
<MoreVertical size={14} />
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent>
|
||||
{!isPromptPublic(prompt) && (
|
||||
<DropdownMenuItem onClick={() => handleEdit(prompt.id)}>
|
||||
Edit
|
||||
</DropdownMenuItem>
|
||||
)}
|
||||
<DropdownMenuItem onClick={() => handleDelete(prompt.id)}>
|
||||
Delete
|
||||
</DropdownMenuItem>
|
||||
</DropdownMenuContent>
|
||||
</DropdownMenu>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -154,7 +154,6 @@ export const SourceChip = ({
|
||||
gap-x-1
|
||||
h-6
|
||||
${onClick ? "cursor-pointer" : ""}
|
||||
animate-fade-in-scale
|
||||
`}
|
||||
>
|
||||
{icon}
|
||||
|
||||
@@ -7,14 +7,9 @@ import React, {
|
||||
useContext,
|
||||
useEffect,
|
||||
useMemo,
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
|
||||
import remarkGfm from "remark-gfm";
|
||||
@@ -54,6 +49,7 @@ import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
import SubQuestionsDisplay from "./SubQuestionsDisplay";
|
||||
import { StatusRefinement } from "../Refinement";
|
||||
import { copyAll, handleCopy } from "./copyingUtils";
|
||||
|
||||
export const AgenticMessage = ({
|
||||
isStreamingQuestions,
|
||||
@@ -312,6 +308,8 @@ export const AgenticMessage = ({
|
||||
[anchorCallback, paragraphCallback, streamedContent]
|
||||
);
|
||||
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
const renderedAlternativeMarkdown = useMemo(() => {
|
||||
return (
|
||||
<ReactMarkdown
|
||||
@@ -492,7 +490,11 @@ export const AgenticMessage = ({
|
||||
|
||||
<div className="px-4">
|
||||
{typeof content === "string" ? (
|
||||
<div className="overflow-x-visible !text-sm max-w-content-max">
|
||||
<div
|
||||
onCopy={(e) => handleCopy(e, markdownRef)}
|
||||
ref={markdownRef}
|
||||
className="overflow-x-visible !text-sm max-w-content-max"
|
||||
>
|
||||
{isViewingInitialAnswer
|
||||
? renderedMarkdown
|
||||
: renderedAlternativeMarkdown}
|
||||
@@ -558,7 +560,16 @@ export const AgenticMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
copyAllFn={() =>
|
||||
copyAll(
|
||||
(isViewingInitialAnswer
|
||||
? finalContent
|
||||
: finalAlternativeContent) as string,
|
||||
markdownRef
|
||||
)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
<HoverableIcon
|
||||
@@ -644,7 +655,16 @@ export const AgenticMessage = ({
|
||||
)}
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton content={content.toString()} />
|
||||
<CopyButton
|
||||
copyAllFn={() =>
|
||||
copyAll(
|
||||
(isViewingInitialAnswer
|
||||
? finalContent
|
||||
: finalAlternativeContent) as string,
|
||||
markdownRef
|
||||
)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
|
||||
<CustomTooltip showTick line content="Good response">
|
||||
|
||||
@@ -16,15 +16,16 @@ import React, {
|
||||
useRef,
|
||||
useState,
|
||||
} from "react";
|
||||
import { unified } from "unified";
|
||||
import ReactMarkdown from "react-markdown";
|
||||
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { SearchSummary } from "./SearchSummary";
|
||||
import {
|
||||
markdownToHtml,
|
||||
getMarkdownForSelection,
|
||||
} from "@/app/chat/message/codeUtils";
|
||||
import { SkippedSearch } from "./SkippedSearch";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkParse from "remark-parse";
|
||||
import remarkRehype from "remark-rehype";
|
||||
import rehypeSanitize from "rehype-sanitize";
|
||||
import rehypeStringify from "rehype-stringify";
|
||||
import { CopyButton } from "@/components/CopyButton";
|
||||
import { ChatFileType, FileDescriptor, ToolCallMetadata } from "../interfaces";
|
||||
import {
|
||||
@@ -69,6 +70,7 @@ import { SourceCard } from "./SourcesDisplay";
|
||||
import remarkMath from "remark-math";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import "katex/dist/katex.min.css";
|
||||
import { copyAll, handleCopy } from "./copyingUtils";
|
||||
|
||||
const TOOLS_WITH_CUSTOM_HANDLING = [
|
||||
SEARCH_TOOL_NAME,
|
||||
@@ -364,34 +366,24 @@ export const AIMessage = ({
|
||||
}),
|
||||
[anchorCallback, paragraphCallback, finalContent]
|
||||
);
|
||||
const markdownRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Process selection copying with HTML formatting
|
||||
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
if (typeof finalContent !== "string") {
|
||||
return finalContent;
|
||||
}
|
||||
|
||||
// Create a hidden div with the HTML content for copying
|
||||
const htmlContent = markdownToHtml(finalContent);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div
|
||||
style={{
|
||||
position: "absolute",
|
||||
left: "-9999px",
|
||||
display: "none",
|
||||
}}
|
||||
dangerouslySetInnerHTML={{ __html: htmlContent }}
|
||||
/>
|
||||
<ReactMarkdown
|
||||
className="prose dark:prose-invert max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
</>
|
||||
<ReactMarkdown
|
||||
className="prose dark:prose-invert max-w-full text-base"
|
||||
components={markdownComponents}
|
||||
remarkPlugins={[remarkGfm, remarkMath]}
|
||||
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
|
||||
>
|
||||
{finalContent}
|
||||
</ReactMarkdown>
|
||||
);
|
||||
}, [finalContent, markdownComponents]);
|
||||
|
||||
@@ -535,64 +527,9 @@ export const AIMessage = ({
|
||||
{typeof content === "string" ? (
|
||||
<div className="overflow-x-visible max-w-content-max">
|
||||
<div
|
||||
contentEditable="true"
|
||||
suppressContentEditableWarning
|
||||
ref={markdownRef}
|
||||
className="focus:outline-none cursor-text select-text"
|
||||
style={{
|
||||
MozUserModify: "read-only",
|
||||
WebkitUserModify: "read-only",
|
||||
}}
|
||||
onCopy={(e) => {
|
||||
e.preventDefault();
|
||||
const selection = window.getSelection();
|
||||
const selectedPlainText =
|
||||
selection?.toString() || "";
|
||||
if (!selectedPlainText) {
|
||||
// If no text is selected, copy the full content
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[
|
||||
typeof content === "string"
|
||||
? markdownToHtml(content)
|
||||
: contentStr,
|
||||
],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([contentStr], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
return;
|
||||
}
|
||||
|
||||
const contentStr =
|
||||
typeof content === "string"
|
||||
? content
|
||||
: (
|
||||
content as JSX.Element
|
||||
).props?.children?.toString() || "";
|
||||
const markdownText = getMarkdownForSelection(
|
||||
contentStr,
|
||||
selectedPlainText
|
||||
);
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[markdownToHtml(markdownText)],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob([selectedPlainText], {
|
||||
type: "text/plain",
|
||||
}),
|
||||
});
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
}}
|
||||
onCopy={(e) => handleCopy(e, markdownRef)}
|
||||
>
|
||||
{renderedMarkdown}
|
||||
</div>
|
||||
@@ -643,13 +580,8 @@ export const AIMessage = ({
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
copyAllFn={() =>
|
||||
copyAll(finalContent as string, markdownRef)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
@@ -734,13 +666,8 @@ export const AIMessage = ({
|
||||
</div>
|
||||
<CustomTooltip showTick line content="Copy">
|
||||
<CopyButton
|
||||
content={
|
||||
typeof content === "string"
|
||||
? {
|
||||
html: markdownToHtml(content),
|
||||
plainText: content,
|
||||
}
|
||||
: content.toString()
|
||||
copyAllFn={() =>
|
||||
copyAll(finalContent as string, markdownRef)
|
||||
}
|
||||
/>
|
||||
</CustomTooltip>
|
||||
|
||||
@@ -24,6 +24,7 @@ import { CodeBlock } from "./CodeBlock";
|
||||
import { CheckIcon, ChevronDown } from "lucide-react";
|
||||
import { PHASE_MIN_MS, useStreamingMessages } from "./StreamingMessages";
|
||||
import { CirclingArrowIcon } from "@/components/icons/icons";
|
||||
import { handleCopy } from "./copyingUtils";
|
||||
|
||||
export const StatusIndicator = ({ status }: { status: ToggleState }) => {
|
||||
return (
|
||||
@@ -292,6 +293,7 @@ const SubQuestionDisplay: React.FC<{
|
||||
}
|
||||
}, [currentlyClosed]);
|
||||
|
||||
const analysisRef = useRef<HTMLDivElement>(null);
|
||||
const renderedMarkdown = useMemo(() => {
|
||||
return (
|
||||
<ReactMarkdown
|
||||
@@ -428,7 +430,11 @@ const SubQuestionDisplay: React.FC<{
|
||||
/>
|
||||
</div>
|
||||
{analysisToggled && (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
<div
|
||||
ref={analysisRef}
|
||||
onCopy={(e) => handleCopy(e, analysisRef)}
|
||||
className="flex flex-wrap gap-2"
|
||||
>
|
||||
{renderedMarkdown}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,176 +0,0 @@
|
||||
import { markdownToHtml, parseMarkdownToSegments } from "../codeUtils";
|
||||
|
||||
describe("markdownToHtml", () => {
|
||||
test("converts bold text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is **bold** text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("converts italic text with asterisks and underscores", () => {
|
||||
expect(markdownToHtml("This is *italic* text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _italic_ text")).toBe(
|
||||
"<p>This is <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles mixed bold and italic", () => {
|
||||
expect(markdownToHtml("This is **bold** and *italic* text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is __bold__ and _italic_ text")).toBe(
|
||||
"<p>This is <strong>bold</strong> and <em>italic</em> text</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles text with spaces and special characters", () => {
|
||||
expect(markdownToHtml("This is *as delicious and* tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _as delicious and_ tasty")).toBe(
|
||||
"<p>This is <em>as delicious and</em> tasty</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles multi-paragraph text with italics", () => {
|
||||
const input =
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it.";
|
||||
expect(markdownToHtml(input)).toBe(
|
||||
"<p>Sure! Here is a sentence with one italicized word:</p>\n<p>The cake was <em>delicious</em> and everyone enjoyed it.</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(markdownToHtml("This is *malformed markdown")).toBe(
|
||||
"<p>This is *malformed markdown</p>"
|
||||
);
|
||||
expect(markdownToHtml("This is _also malformed")).toBe(
|
||||
"<p>This is _also malformed</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has **unclosed bold")).toBe(
|
||||
"<p>This has **unclosed bold</p>"
|
||||
);
|
||||
expect(markdownToHtml("This has __unclosed bold")).toBe(
|
||||
"<p>This has __unclosed bold</p>"
|
||||
);
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(markdownToHtml("")).toBe("");
|
||||
expect(markdownToHtml(" ")).toBe("");
|
||||
expect(markdownToHtml("\n")).toBe("");
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => markdownToHtml(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
|
||||
describe("parseMarkdownToSegments", () => {
|
||||
test("parses italic text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is *italic* text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "*italic*", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses italic text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is _italic_ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "italic", text: "italic", raw: "_italic_", length: 6 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with asterisks", () => {
|
||||
const segments = parseMarkdownToSegments("This is **bold** text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "**bold**", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses bold text with underscores", () => {
|
||||
const segments = parseMarkdownToSegments("This is __bold__ text");
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "This is ", raw: "This is ", length: 8 },
|
||||
{ type: "bold", text: "bold", raw: "__bold__", length: 4 },
|
||||
{ type: "text", text: " text", raw: " text", length: 5 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses text with spaces and special characters in italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"The cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{ type: "text", text: "The cake was ", raw: "The cake was ", length: 13 },
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("parses multi-paragraph text with italics", () => {
|
||||
const segments = parseMarkdownToSegments(
|
||||
"Sure! Here is a sentence with one italicized word:\n\nThe cake was _delicious_ and everyone enjoyed it."
|
||||
);
|
||||
expect(segments).toEqual([
|
||||
{
|
||||
type: "text",
|
||||
text: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
raw: "Sure! Here is a sentence with one italicized word:\n\nThe cake was ",
|
||||
length: 65,
|
||||
},
|
||||
{ type: "italic", text: "delicious", raw: "_delicious_", length: 9 },
|
||||
{
|
||||
type: "text",
|
||||
text: " and everyone enjoyed it.",
|
||||
raw: " and everyone enjoyed it.",
|
||||
length: 25,
|
||||
},
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles malformed markdown without crashing", () => {
|
||||
expect(() => parseMarkdownToSegments("This is *malformed")).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This is _also malformed")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has **unclosed bold")
|
||||
).not.toThrow();
|
||||
expect(() =>
|
||||
parseMarkdownToSegments("This has __unclosed bold")
|
||||
).not.toThrow();
|
||||
});
|
||||
|
||||
test("handles empty or null input", () => {
|
||||
expect(parseMarkdownToSegments("")).toEqual([]);
|
||||
expect(parseMarkdownToSegments(" ")).toEqual([
|
||||
{ type: "text", text: " ", raw: " ", length: 1 },
|
||||
]);
|
||||
expect(parseMarkdownToSegments("\n")).toEqual([
|
||||
{ type: "text", text: "\n", raw: "\n", length: 1 },
|
||||
]);
|
||||
});
|
||||
|
||||
test("handles extremely long input without crashing", () => {
|
||||
const longText = "This is *italic* ".repeat(1000);
|
||||
expect(() => parseMarkdownToSegments(longText)).not.toThrow();
|
||||
});
|
||||
});
|
||||
@@ -83,35 +83,6 @@ export const preprocessLaTeX = (content: string) => {
|
||||
return inlineProcessedContent;
|
||||
};
|
||||
|
||||
export const markdownToHtml = (content: string): string => {
|
||||
if (!content || !content.trim()) {
|
||||
return "";
|
||||
}
|
||||
|
||||
// Basic markdown to HTML conversion for common patterns
|
||||
const processedContent = content
|
||||
.replace(/(\*\*|__)((?:(?!\1).)*?)\1/g, "<strong>$2</strong>") // Bold with ** or __, non-greedy and no nesting
|
||||
.replace(/(\*|_)([^*_\n]+?)\1(?!\*|_)/g, "<em>$2</em>"); // Italic with * or _
|
||||
|
||||
// Handle code blocks and links
|
||||
const withCodeAndLinks = processedContent
|
||||
.replace(/`([^`]+)`/g, "<code>$1</code>") // Inline code
|
||||
.replace(
|
||||
/```(\w*)\n([\s\S]*?)```/g,
|
||||
(_, lang, code) =>
|
||||
`<pre><code class="language-${lang}">${code.trim()}</code></pre>`
|
||||
) // Code blocks
|
||||
.replace(/\[([^\]]+)\]\(([^)]+)\)/g, '<a href="$2">$1</a>'); // Links
|
||||
|
||||
// Handle paragraphs
|
||||
return withCodeAndLinks
|
||||
.split(/\n\n+/)
|
||||
.map((para) => para.trim())
|
||||
.filter((para) => para.length > 0)
|
||||
.map((para) => `<p>${para}</p>`)
|
||||
.join("\n");
|
||||
};
|
||||
|
||||
interface MarkdownSegment {
|
||||
type: "text" | "link" | "code" | "bold" | "italic" | "codeblock";
|
||||
text: string; // The visible/plain text
|
||||
|
||||
71
web/src/app/chat/message/copyingUtils.tsx
Normal file
71
web/src/app/chat/message/copyingUtils.tsx
Normal file
@@ -0,0 +1,71 @@
|
||||
"use client";
|
||||
import { unified } from "unified";
|
||||
import remarkParse from "remark-parse";
|
||||
import remarkGfm from "remark-gfm";
|
||||
import remarkMath from "remark-math";
|
||||
import remarkRehype from "remark-rehype";
|
||||
import rehypePrism from "rehype-prism-plus";
|
||||
import rehypeKatex from "rehype-katex";
|
||||
import rehypeSanitize from "rehype-sanitize";
|
||||
import rehypeStringify from "rehype-stringify";
|
||||
|
||||
export const handleCopy = (
|
||||
e: React.ClipboardEvent,
|
||||
markdownRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
// Check if we have a selection
|
||||
const selection = window.getSelection();
|
||||
if (!selection?.rangeCount) return;
|
||||
|
||||
const range = selection.getRangeAt(0);
|
||||
|
||||
// If selection is within our markdown container
|
||||
if (
|
||||
markdownRef.current &&
|
||||
markdownRef.current.contains(range.commonAncestorContainer)
|
||||
) {
|
||||
e.preventDefault();
|
||||
|
||||
// Clone selection to get the HTML
|
||||
const fragment = range.cloneContents();
|
||||
const tempDiv = document.createElement("div");
|
||||
tempDiv.appendChild(fragment);
|
||||
|
||||
// Create clipboard data with both HTML and plain text
|
||||
e.clipboardData.setData("text/html", tempDiv.innerHTML);
|
||||
e.clipboardData.setData("text/plain", selection.toString());
|
||||
}
|
||||
};
|
||||
|
||||
// For copying the entire content
|
||||
export const copyAll = (
|
||||
content: string,
|
||||
markdownRef: React.RefObject<HTMLDivElement>
|
||||
) => {
|
||||
if (!markdownRef.current || typeof content !== "string") {
|
||||
return;
|
||||
}
|
||||
|
||||
// Convert markdown to HTML using unified ecosystem
|
||||
unified()
|
||||
.use(remarkParse)
|
||||
.use(remarkGfm)
|
||||
.use(remarkMath)
|
||||
.use(remarkRehype)
|
||||
.use(rehypePrism, { ignoreMissing: true })
|
||||
.use(rehypeKatex)
|
||||
.use(rehypeSanitize)
|
||||
.use(rehypeStringify)
|
||||
.process(content)
|
||||
.then((file: any) => {
|
||||
const htmlContent = String(file);
|
||||
|
||||
// Create clipboard data
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob([htmlContent], { type: "text/html" }),
|
||||
"text/plain": new Blob([content], { type: "text/plain" }),
|
||||
});
|
||||
|
||||
navigator.clipboard.write([clipboardItem]);
|
||||
});
|
||||
};
|
||||
@@ -40,8 +40,13 @@ export function UserSettingsModal({
|
||||
onClose: () => void;
|
||||
defaultModel: string | null;
|
||||
}) {
|
||||
const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } =
|
||||
useUser();
|
||||
const {
|
||||
refreshUser,
|
||||
user,
|
||||
updateUserAutoScroll,
|
||||
updateUserShortcuts,
|
||||
updateUserTemperatureOverrideEnabled,
|
||||
} = useUser();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
const { theme, setTheme } = useTheme();
|
||||
@@ -156,11 +161,6 @@ export function UserSettingsModal({
|
||||
const settings = useContext(SettingsContext);
|
||||
const autoScroll = settings?.settings?.auto_scroll;
|
||||
|
||||
const checked =
|
||||
user?.preferences?.auto_scroll === null
|
||||
? autoScroll
|
||||
: user?.preferences?.auto_scroll;
|
||||
|
||||
const handleChangePassword = async (e: React.FormEvent) => {
|
||||
e.preventDefault();
|
||||
if (newPassword !== confirmPassword) {
|
||||
@@ -288,12 +288,26 @@ export function UserSettingsModal({
|
||||
<SubLabel>Automatically scroll to new content</SubLabel>
|
||||
</div>
|
||||
<Switch
|
||||
checked={checked}
|
||||
checked={user?.preferences.auto_scroll}
|
||||
onCheckedChange={(checked) => {
|
||||
updateUserAutoScroll(checked);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h3 className="text-lg font-medium">
|
||||
Temperature override
|
||||
</h3>
|
||||
<SubLabel>Set the temperature for the LLM</SubLabel>
|
||||
</div>
|
||||
<Switch
|
||||
checked={user?.preferences.temperature_override_enabled}
|
||||
onCheckedChange={(checked) => {
|
||||
updateUserTemperatureOverrideEnabled(checked);
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-between">
|
||||
<div>
|
||||
<h3 className="text-lg font-medium">Prompt Shortcuts</h3>
|
||||
|
||||
@@ -4,33 +4,36 @@ import { CheckmarkIcon, CopyMessageIcon } from "./icons/icons";
|
||||
|
||||
export function CopyButton({
|
||||
content,
|
||||
copyAllFn,
|
||||
onClick,
|
||||
}: {
|
||||
content?: string | { html: string; plainText: string };
|
||||
content?: string;
|
||||
copyAllFn?: () => void;
|
||||
onClick?: () => void;
|
||||
}) {
|
||||
const [copyClicked, setCopyClicked] = useState(false);
|
||||
|
||||
const copyToClipboard = async (
|
||||
content: string | { html: string; plainText: string }
|
||||
) => {
|
||||
const copyToClipboard = async () => {
|
||||
try {
|
||||
// If copyAllFn is provided, use it instead of the default behavior
|
||||
if (copyAllFn) {
|
||||
await copyAllFn();
|
||||
return;
|
||||
}
|
||||
|
||||
// Fall back to original behavior if no copyAllFn is provided
|
||||
if (!content) return;
|
||||
|
||||
const clipboardItem = new ClipboardItem({
|
||||
"text/html": new Blob(
|
||||
[typeof content === "string" ? content : content.html],
|
||||
{ type: "text/html" }
|
||||
),
|
||||
"text/plain": new Blob(
|
||||
[typeof content === "string" ? content : content.plainText],
|
||||
{ type: "text/plain" }
|
||||
),
|
||||
"text/html": new Blob([content], { type: "text/html" }),
|
||||
"text/plain": new Blob([content], { type: "text/plain" }),
|
||||
});
|
||||
await navigator.clipboard.write([clipboardItem]);
|
||||
} catch (err) {
|
||||
// Fallback to basic text copy if HTML copy fails
|
||||
await navigator.clipboard.writeText(
|
||||
typeof content === "string" ? content : content.plainText
|
||||
);
|
||||
if (content) {
|
||||
await navigator.clipboard.writeText(content);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -38,9 +41,7 @@ export function CopyButton({
|
||||
<HoverableIcon
|
||||
icon={copyClicked ? <CheckmarkIcon /> : <CopyMessageIcon />}
|
||||
onClick={() => {
|
||||
if (content) {
|
||||
copyToClipboard(content);
|
||||
}
|
||||
copyToClipboard();
|
||||
onClick && onClick();
|
||||
|
||||
setCopyClicked(true);
|
||||
|
||||
@@ -359,18 +359,25 @@ export function ClientLayout({
|
||||
),
|
||||
link: "/admin/performance/usage",
|
||||
},
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<DatabaseIconSkeleton
|
||||
className="text-text-700"
|
||||
size={18}
|
||||
/>
|
||||
<div className="ml-1">Query History</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/performance/query-history",
|
||||
},
|
||||
...(settings?.settings.query_history_type !==
|
||||
"disabled"
|
||||
? [
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
<DatabaseIconSkeleton
|
||||
className="text-text-700"
|
||||
size={18}
|
||||
/>
|
||||
<div className="ml-1">
|
||||
Query History
|
||||
</div>
|
||||
</div>
|
||||
),
|
||||
link: "/admin/performance/query-history",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: (
|
||||
<div className="flex">
|
||||
|
||||
@@ -3,6 +3,7 @@ import {
|
||||
EnterpriseSettings,
|
||||
ApplicationStatus,
|
||||
Settings,
|
||||
QueryHistoryType,
|
||||
} from "@/app/admin/settings/interfaces";
|
||||
import {
|
||||
CUSTOM_ANALYTICS_ENABLED,
|
||||
@@ -53,6 +54,7 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
anonymous_user_enabled: false,
|
||||
pro_search_enabled: true,
|
||||
temperature_override_enabled: true,
|
||||
query_history_type: QueryHistoryType.NORMAL,
|
||||
};
|
||||
} else {
|
||||
throw new Error(
|
||||
|
||||
@@ -13,7 +13,7 @@ interface UserContextType {
|
||||
isCurator: boolean;
|
||||
refreshUser: () => Promise<void>;
|
||||
isCloudSuperuser: boolean;
|
||||
updateUserAutoScroll: (autoScroll: boolean | null) => Promise<void>;
|
||||
updateUserAutoScroll: (autoScroll: boolean) => Promise<void>;
|
||||
updateUserShortcuts: (enabled: boolean) => Promise<void>;
|
||||
toggleAssistantPinnedStatus: (
|
||||
currentPinnedAssistantIDs: number[],
|
||||
@@ -163,7 +163,7 @@ export function UserProvider({
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserAutoScroll = async (autoScroll: boolean | null) => {
|
||||
const updateUserAutoScroll = async (autoScroll: boolean) => {
|
||||
try {
|
||||
const response = await fetch("/api/auto-scroll", {
|
||||
method: "PATCH",
|
||||
|
||||
@@ -10,7 +10,7 @@ interface UserPreferences {
|
||||
pinned_assistants?: number[];
|
||||
default_model: string | null;
|
||||
recent_assistants: number[];
|
||||
auto_scroll: boolean | null;
|
||||
auto_scroll: boolean;
|
||||
shortcut_enabled: boolean;
|
||||
temperature_override_enabled: boolean;
|
||||
}
|
||||
|
||||
@@ -24,6 +24,8 @@ async function verifyAdminPageNavigation(
|
||||
console.error(
|
||||
`Failed to find h1 with text "${pageTitle}" for path "${path}"`
|
||||
);
|
||||
// NOTE: This is a temporary measure for debugging the issue
|
||||
console.error(await page.content());
|
||||
throw error;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user