Compare commits

..

18 Commits

Author SHA1 Message Date
pablodanswer
bc4a5b6496 Revert "various improvements"
This reverts commit 53225d0a43.
2025-02-03 14:42:18 -08:00
pablodanswer
e1956dc42f pop 2025-02-03 14:41:50 -08:00
pablodanswer
53225d0a43 various improvements 2025-02-03 13:31:43 -08:00
trial-danswer
715359c120 Helm chart refactoring (#3797)
* initial commit for helm chart refactoring

* Continue refactoring helm. I was able to use helm to deploy all of the apps to a cluster in aws. The bottleneck was setting up PVC dynamic provisioning.

* use default storage class

* Fix linter errors

* Fix broken helm test

---------

Co-authored-by: jpb80 <jordan.buttkevitz@gmail.com>
2025-02-03 10:56:07 -08:00
Weves
e061ba2b93 another airtable fix 2025-02-02 20:58:24 -08:00
Weves
87bccc13cc Handle expiring attachments 2025-02-02 12:02:44 -08:00
Weves
569639eb90 Improved attachment handling 2025-02-01 23:07:01 -08:00
pablodanswer
68cb1f3409 ensure tests don't run temporarily 2025-02-01 17:31:44 -08:00
pablonyx
11da0d9889 Add user specific chat session temperature (#3867)
* add user specific chat session temperature

* kbetter typing

* update
2025-02-01 17:29:58 -08:00
pablodanswer
6a7e2a8036 temporarily disable chat tests 2025-02-01 14:15:16 -08:00
pablodanswer
035f83c464 ensure tests pass (temporary dragging disabled) 2025-02-01 12:58:03 -08:00
pablonyx
3c34ddcc4f E2e assistant tests (#3869)
* adding llm override logic

* update

* general cleanup

* fix various tests

* rm

* update

* update

* better comments

* k

* k

* update to pass tests

* clarify content

* improve timeout
2025-02-01 20:05:53 +00:00
pablonyx
a82cac5361 Ensure anonymous users can give feedback
Ensure anonymous users can give feedback
2025-02-01 10:36:14 -08:00
pablodanswer
83e5cb2d2f tested 2025-01-31 16:40:37 -08:00
Chris Weaver
a5d2f0d9ac Fix airtable connector w/ mt cloud + move telem logic to match new st… (#3868)
* Fix airtable connector w/ mt cloud + move telem logic to match new standard

* Address Greptile comment

* Small fixes/improvements

* Revert back monitoring frequency

* Small monitoring fix
2025-01-31 16:29:04 -08:00
Weves
7bc8554e01 Airtable fix 2025-01-31 10:42:27 -08:00
rkuo-danswer
261150e81a Validate permission locks (#3799)
* WIP for external group sync lock fixes

* prototyping permissions validation

* validate permission sync tasks in celery

* mypy

* cleanup and wire off external group sync checks for now

* add active key to reset

* improve logging

* reset on payload format change

* return False on exception

* missed a return

* add count of tasks scanned

* add comment

* better logging

* add return

* more return

* catch payload exceptions

* code review fixes

* push to restart test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-01-31 17:33:07 +00:00
pablonyx
3e0d24a3f6 Update foreign key migration
Update foreign key migration
2025-01-31 08:45:19 -08:00
82 changed files with 1877 additions and 340 deletions

View File

@@ -8,6 +8,8 @@ on: push
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
MOCK_LLM_RESPONSE: true
jobs:
playwright-tests:

View File

@@ -21,10 +21,10 @@ jobs:
- name: Set up Helm
uses: azure/setup-helm@v4.2.0
with:
version: v3.14.4
version: v3.17.0
- name: Set up chart-testing
uses: helm/chart-testing-action@v2.6.1
uses: helm/chart-testing-action@v2.7.0
# even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
- name: Run chart-testing (list-changed)
@@ -37,22 +37,6 @@ jobs:
echo "changed=true" >> "$GITHUB_OUTPUT"
fi
# rkuo: I don't think we need python?
# - name: Set up Python
# uses: actions/setup-python@v5
# with:
# python-version: '3.11'
# cache: 'pip'
# cache-dependency-path: |
# backend/requirements/default.txt
# backend/requirements/dev.txt
# backend/requirements/model_server.txt
# - run: |
# python -m pip install --upgrade pip
# pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
# pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt
# lint all charts if any changes were detected
- name: Run chart-testing (lint)
if: steps.list-changed.outputs.changed == 'true'
@@ -62,7 +46,7 @@ jobs:
- name: Create kind cluster
if: steps.list-changed.outputs.changed == 'true'
uses: helm/kind-action@v1.10.0
uses: helm/kind-action@v1.12.0
- name: Run chart-testing (install)
if: steps.list-changed.outputs.changed == 'true'

View File

@@ -0,0 +1,36 @@
"""add chat session specific temperature override
Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
)
op.add_column(
"user",
sa.Column(
"temperature_override_enabled",
sa.Boolean(),
nullable=False,
server_default=sa.false(),
),
)
def downgrade() -> None:
op.drop_column("chat_session", "temperature_override")
op.drop_column("user", "temperature_override_enabled")

View File

@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import get_user_email_from_username__server
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
slim_docs: list[SlimDocument],
space_permissions_by_space_key: dict[str, ExternalAccess],
is_cloud: bool,
callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
"""
For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
document_restrictions: list[DocExternalAccess] = []
for slim_doc in slim_docs:
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
if slim_doc.perm_sync_data is None:
raise ValueError(
f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(
def confluence_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
logger.debug("Fetching all slim documents from confluence")
for doc_batch in confluence_connector.retrieve_all_slim_documents():
logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
if callback:
if callback.should_stop():
raise RuntimeError("confluence_doc_sync: Stop signal detected")
callback.progress("confluence_doc_sync", 1)
slim_docs.extend(doc_batch)
logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
slim_docs=slim_docs,
space_permissions_by_space_key=space_permissions_by_space_key,
is_cloud=is_cloud,
callback=callback,
)

View File

@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
from onyx.connectors.gmail.connector import GmailConnector
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(
def gmail_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
document_external_access: list[DocExternalAccess] = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gmail_doc_sync: Stop signal detected")
callback.progress("gmail_doc_sync", 1)
if slim_doc.perm_sync_data is None:
logger.warning(f"No permissions found for document {slim_doc.id}")
continue

View File

@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
from onyx.connectors.models import SlimDocument
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc(
def gdrive_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -146,6 +147,12 @@ def gdrive_doc_sync(
document_external_accesses = []
for slim_doc_batch in slim_doc_generator:
for slim_doc in slim_doc_batch:
if callback:
if callback.should_stop():
raise RuntimeError("gdrive_doc_sync: Stop signal detected")
callback.progress("gdrive_doc_sync", 1)
ext_access = _get_permissions_from_slim_doc(
google_drive_connector=google_drive_connector,
slim_doc=slim_doc,

View File

@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
from onyx.connectors.slack.connector import SlackPollConnector
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -14,7 +15,7 @@ logger = setup_logger()
def _get_slack_document_ids_and_channels(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> dict[str, list[str]]:
slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
channel_doc_map: dict[str, list[str]] = {}
for doc_metadata_batch in slim_doc_generator:
for doc_metadata in doc_metadata_batch:
if callback:
if callback.should_stop():
raise RuntimeError(
"_get_slack_document_ids_and_channels: Stop signal detected"
)
callback.progress("_get_slack_document_ids_and_channels", 1)
if doc_metadata.perm_sync_data is None:
continue
channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(
def slack_doc_sync(
cc_pair: ConnectorCredentialPair,
cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
) -> list[DocExternalAccess]:
"""
Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
)
user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
channel_doc_map = _get_slack_document_ids_and_channels(
cc_pair=cc_pair,
cc_pair=cc_pair, callback=callback
)
workspace_permissions = _fetch_workspace_permissions(
user_id_to_email_map=user_id_to_email_map,

View File

@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
from onyx.access.models import DocExternalAccess
from onyx.configs.constants import DocumentSource
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
# Defining the input/output types for the sync functions
DocSyncFuncType = Callable[
[
ConnectorCredentialPair,
IndexingHeartbeatInterface | None,
],
list[DocExternalAccess],
]

View File

@@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
def wait_for_redis(sender: Any, **kwargs: Any) -> None:
"""Waits for redis to become ready subject to a hardcoded timeout.
Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
Will raise WorkerShutdown to kill the celery worker if the timeout
is reached."""
r = get_redis_client(tenant_id=None)

View File

@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
return False
def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
"""This is a redis specific way to build a list of tasks in a queue.
This helps us read the queue once and then efficiently look for missing tasks
in the queue.
"""
task_set: set[str] = set()
for priority in range(len(OnyxCeleryPriority)):
queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
for task in tasks:
task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
task_id = task_dict.get("headers", {}).get("id")
if task_id:
task_set.add(task_id)
return task_set
def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
"""Returns a list of current workers containing name_filter, or all workers if
name_filter is None.

View File

@@ -3,13 +3,16 @@ from datetime import datetime
from datetime import timedelta
from datetime import timezone
from time import sleep
from typing import cast
from uuid import uuid4
from celery import Celery
from celery import shared_task
from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from pydantic import ValidationError
from redis import Redis
from redis.exceptions import LockError
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
@@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import (
)
from onyx.access.models import DocExternalAccess
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.configs.constants import OnyxRedisSignals
from onyx.db.connector import mark_cc_pair_as_permissions_synced
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.document import upsert_document_by_connector_credential_pair
@@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_doc_perm_sync import (
RedisConnectorPermissionSyncPayload,
)
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.server.utils import make_short_id
from onyx.utils.logger import doc_permission_sync_ctx
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
bind=True,
)
def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
# TODO(rkuo): merge into check function after lookup table for fences is added
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
r = get_redis_client(tenant_id=tenant_id)
r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_doc_permissions_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_permissions_sync_task(
payload_id = try_creating_permissions_sync_task(
self.app, cc_pair_id, r, tenant_id
)
if not tasks_created:
if not payload_id:
continue
task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
task_logger.info(
f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
)
# we want to run this less frequently than the overall task
lock_beat.reacquire()
if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
# clear any permission fences that don't have associated celery tasks in progress
# tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# or be currently executing
try:
validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
except Exception:
task_logger.exception(
"Exception while validating permission sync fences"
)
r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -152,13 +188,15 @@ def try_creating_permissions_sync_task(
cc_pair_id: int,
r: Redis,
tenant_id: str | None,
) -> int | None:
"""Returns an int if syncing is needed. The int represents the number of sync tasks generated.
) -> str | None:
"""Returns a randomized payload id on success.
Returns None if no syncing is required."""
redis_connector = RedisConnector(tenant_id, cc_pair_id)
LOCK_TIMEOUT = 30
payload_id: str | None = None
redis_connector = RedisConnector(tenant_id, cc_pair_id)
lock: RedisLock = r.lock(
DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
timeout=LOCK_TIMEOUT,
@@ -193,7 +231,13 @@ def try_creating_permissions_sync_task(
)
# set a basic fence to start
payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
redis_connector.permissions.set_active()
payload = RedisConnectorPermissionSyncPayload(
id=make_short_id(),
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
redis_connector.permissions.set_fence(payload)
result = app.send_task(
@@ -208,8 +252,11 @@ def try_creating_permissions_sync_task(
)
# fill in the celery task id
redis_connector.permissions.set_active()
payload.celery_task_id = result.id
redis_connector.permissions.set_fence(payload)
payload_id = payload.celery_task_id
except Exception:
task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
return None
@@ -217,7 +264,7 @@ def try_creating_permissions_sync_task(
if lock.owned():
lock.release()
return 1
return payload_id
@shared_task(
@@ -238,6 +285,8 @@ def connector_permission_sync_generator_task(
This task assumes that the task has already been properly fenced
"""
LoggerContextVars.reset()
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -325,12 +374,17 @@ def connector_permission_sync_generator_task(
raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")
new_payload = RedisConnectorPermissionSyncPayload(
id=payload.id,
submitted=payload.submitted,
started=datetime.now(timezone.utc),
celery_task_id=payload.celery_task_id,
)
redis_connector.permissions.set_fence(new_payload)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
callback = PermissionSyncCallback(redis_connector, lock, r)
document_external_accesses: list[DocExternalAccess] = doc_sync_func(
cc_pair, callback
)
task_logger.info(
f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -380,6 +434,8 @@ def update_external_document_permissions_task(
connector_id: int,
credential_id: int,
) -> bool:
start = time.monotonic()
document_external_access = DocExternalAccess.from_dict(
serialized_doc_external_access
)
@@ -409,16 +465,268 @@ def update_external_document_permissions_task(
document_ids=[doc_id],
)
logger.debug(
f"Successfully synced postgres document permissions for {doc_id}"
elapsed = time.monotonic() - start
task_logger.info(
f"connector_id={connector_id} "
f"doc={doc_id} "
f"action=update_permissions "
f"elapsed={elapsed:.2f}"
)
return True
except Exception:
logger.exception(
f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
task_logger.exception(
f"Exception in update_external_document_permissions_task: "
f"connector_id={connector_id} "
f"doc_id={doc_id}"
)
return False
return True
def validate_permission_sync_fences(
tenant_id: str | None,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
# building lookup table can be expensive, so we won't bother
# validating until the queue is small
PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024
queue_len = celery_get_queue_length(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
return
queued_upsert_tasks = celery_get_queued_task_ids(
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
)
reserved_generator_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorPermissionSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
validate_permission_sync_fence(
tenant_id,
key_bytes,
queued_upsert_tasks,
reserved_generator_tasks,
r,
r_celery,
)
return
def validate_permission_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
queued_tasks: set[str],
reserved_tasks: set[str],
r: Redis,
r_celery: Redis,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
queued_tasks: the celery queue of lightweight permission sync tasks
reserved_tasks: prefetched tasks for sync task generator
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_permission_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.permissions.fenced:
return
# in the cloud, the payload format may have changed ...
# it's a little sloppy, but just reset the fence for now if that happens
# TODO: add intentional cleanup/abort logic
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"validate_permission_sync_fence - "
"Resetting fence because fence schema is out of date: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
if not payload:
return
if not payload.celery_task_id:
return
# OK, there's actually something for us to validate
# either the generator task must be in flight or its subtasks must be
found = celery_find_task(
payload.celery_task_id,
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
r_celery,
)
if found:
# the celery task exists in the redis queue
redis_connector.permissions.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within a worker
redis_connector.permissions.set_active()
return
# look up every task in the current taskset in the celery queue
# every entry in the taskset should have an associated entry in the celery task queue
# because we get the celery tasks first, the entries in our own permissions taskset
# should be roughly a subset of the tasks in celery
# this check isn't very exact, but should be sufficient over a period of time
# A single successful check over some number of attempts is sufficient.
# TODO: if the number of tasks in celery is much lower than than the taskset length
# we might be able to shortcut the lookup since by definition some of the tasks
# must not exist in celery.
tasks_scanned = 0
tasks_not_in_celery = 0 # a non-zero number after completing our check is bad
for member in r.sscan_iter(redis_connector.permissions.taskset_key):
tasks_scanned += 1
member_bytes = cast(bytes, member)
member_str = member_bytes.decode("utf-8")
if member_str in queued_tasks:
continue
if member_str in reserved_tasks:
continue
tasks_not_in_celery += 1
task_logger.info(
"validate_permission_sync_fence task check: "
f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
)
if tasks_not_in_celery == 0:
redis_connector.permissions.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
if redis_connector.permissions.active():
return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
task_logger.warning(
"validate_permission_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.permissions.reset()
return
class PermissionSyncCallback(IndexingHeartbeatInterface):
PARENT_CHECK_INTERVAL = 60
def __init__(
self,
redis_connector: RedisConnector,
redis_lock: RedisLock,
redis_client: Redis,
):
super().__init__()
self.redis_connector: RedisConnector = redis_connector
self.redis_lock: RedisLock = redis_lock
self.redis_client = redis_client
self.started: datetime = datetime.now(timezone.utc)
self.redis_lock.reacquire()
self.last_tag: str = "PermissionSyncCallback.__init__"
self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
def should_stop(self) -> bool:
if self.redis_connector.stop.fenced:
return True
return False
def progress(self, tag: str, amount: int) -> None:
try:
self.redis_connector.permissions.set_active()
current_time = time.monotonic()
if current_time - self.last_lock_monotonic >= (
CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
):
self.redis_lock.reacquire()
self.last_lock_reacquire = datetime.now(timezone.utc)
self.last_lock_monotonic = time.monotonic()
self.last_tag = tag
except LockError:
logger.exception(
f"PermissionSyncCallback - lock.reacquire exceptioned: "
f"lock_timeout={self.redis_lock.timeout} "
f"start={self.started} "
f"last_tag={self.last_tag} "
f"last_reacquired={self.last_lock_reacquire} "
f"now={datetime.now(timezone.utc)}"
)
redis_lock_dump(self.redis_lock, self.redis_client)
raise
"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
@@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset(
if initial is None:
return
try:
payload = redis_connector.permissions.payload
except ValidationError:
task_logger.exception(
"Permissions sync payload failed to validate. "
"Schema may have been updated."
)
return
if not payload:
return
remaining = redis_connector.permissions.get_remaining()
task_logger.info(
f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
f"Permissions sync progress: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"remaining={remaining} "
f"initial={initial}"
)
if remaining > 0:
return
payload: RedisConnectorPermissionSyncPayload | None = (
redis_connector.permissions.payload
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
task_logger.info(
f"Permissions sync finished: "
f"cc_pair={cc_pair_id} "
f"id={payload.id} "
f"num_synced={initial}"
)
start_time: datetime | None = payload.started if payload else None
mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
update_sync_record_status(
db_session=db_session,

View File

@@ -1,3 +1,4 @@
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@@ -9,6 +10,7 @@ from celery import Task
from celery.exceptions import SoftTimeLimitExceeded
from redis import Redis
from redis.lock import Lock as RedisLock
from sqlalchemy.orm import Session
from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import (
GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
)
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_redis import celery_find_task
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
@@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
from onyx.redis.redis_connector_ext_group_sync import (
RedisConnectorExternalGroupSyncPayload,
)
from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
r = get_redis_client(tenant_id=tenant_id)
# we need to use celery's redis client to access its redis data
# (which lives on a different db number)
# r_celery: Redis = self.app.broker_connection().channel().client # type: ignore
lock_beat: RedisLock = r.lock(
OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
if _is_external_group_sync_due(cc_pair):
cc_pair_ids_to_sync.append(cc_pair.id)
lock_beat.reacquire()
for cc_pair_id in cc_pair_ids_to_sync:
tasks_created = try_creating_external_group_sync_task(
self.app, cc_pair_id, r, tenant_id
@@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
continue
task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")
# we want to run this less frequently than the overall task
# lock_beat.reacquire()
# if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
# # clear any indexing fences that don't have associated celery tasks in progress
# # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
# # or be currently executing
# try:
# validate_external_group_sync_fences(
# tenant_id, self.app, r, r_celery, lock_beat
# )
# except Exception:
# task_logger.exception(
# "Exception while validating external group sync fences"
# )
# r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
except SoftTimeLimitExceeded:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
@@ -186,6 +215,12 @@ def try_creating_external_group_sync_task(
redis_connector.external_group_sync.generator_clear()
redis_connector.external_group_sync.taskset_clear()
payload = RedisConnectorExternalGroupSyncPayload(
submitted=datetime.now(timezone.utc),
started=None,
celery_task_id=None,
)
custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"
result = app.send_task(
@@ -199,11 +234,6 @@ def try_creating_external_group_sync_task(
priority=OnyxCeleryPriority.HIGH,
)
payload = RedisConnectorExternalGroupSyncPayload(
started=datetime.now(timezone.utc),
celery_task_id=result.id,
)
# create before setting fence to avoid race condition where the monitoring
# task updates the sync record before it is created
with get_session_with_tenant(tenant_id) as db_session:
@@ -213,8 +243,8 @@ def try_creating_external_group_sync_task(
sync_type=SyncType.EXTERNAL_GROUP,
)
payload.celery_task_id = result.id
redis_connector.external_group_sync.set_fence(payload)
except Exception:
task_logger.exception(
f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task(
tenant_id: str | None,
) -> None:
"""
Permission sync task that handles external group syncing for a given connector credential pair
External group sync task for a given connector credential pair
This task assumes that the task has already been properly fenced
"""
@@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task(
r = get_redis_client(tenant_id=tenant_id)
# this wait is needed to avoid a race condition where
# the primary worker sends the task and it is immediately executed
# before the primary worker can finalize the fence
start = time.monotonic()
while True:
if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
raise ValueError(
f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
if not redis_connector.external_group_sync.fenced: # The fence must exist
raise ValueError(
f"connector_external_group_sync_generator_task - fence not found: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
payload = redis_connector.external_group_sync.payload # The payload must exist
if not payload:
raise ValueError(
"connector_external_group_sync_generator_task: payload invalid or not found"
)
if payload.celery_task_id is None:
logger.info(
f"connector_external_group_sync_generator_task - Waiting for fence: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
time.sleep(1)
continue
logger.info(
f"connector_external_group_sync_generator_task - Fence found, continuing...: "
f"fence={redis_connector.external_group_sync.fence_key}"
)
break
lock: RedisLock = r.lock(
OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
+ f"_{redis_connector.id}",
timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
)
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
try:
acquired = lock.acquire(blocking=False)
if not acquired:
task_logger.warning(
f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
)
return None
payload.started = datetime.now(timezone.utc)
redis_connector.external_group_sync.set_fence(payload)
with get_session_with_tenant(tenant_id) as db_session:
cc_pair = get_connector_credential_pair_from_id(
@@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task(
redis_connector.external_group_sync.set_fence(None)
if lock.owned():
lock.release()
def validate_external_group_sync_fences(
tenant_id: str | None,
celery_app: Celery,
r: Redis,
r_celery: Redis,
lock_beat: RedisLock,
) -> None:
reserved_sync_tasks = celery_get_unacked_task_ids(
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
# validate all existing indexing jobs
for key_bytes in r.scan_iter(
RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
count=SCAN_ITER_COUNT_DEFAULT,
):
lock_beat.reacquire()
with get_session_with_tenant(tenant_id) as db_session:
validate_external_group_sync_fence(
tenant_id,
key_bytes,
reserved_sync_tasks,
r_celery,
db_session,
)
return
def validate_external_group_sync_fence(
tenant_id: str | None,
key_bytes: bytes,
reserved_tasks: set[str],
r_celery: Redis,
db_session: Session,
) -> None:
"""Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
This can happen if the indexing worker hard crashes or is terminated.
Being in this bad state means the fence will never clear without help, so this function
gives the help.
How this works:
1. This function renews the active signal with a 5 minute TTL under the following conditions
1.2. When the task is seen in the redis queue
1.3. When the task is seen in the reserved / prefetched list
2. Externally, the active signal is renewed when:
2.1. The fence is created
2.2. The indexing watchdog checks the spawned task.
3. The TTL allows us to get through the transitions on fence startup
and when the task starts executing.
More TTL clarification: it is seemingly impossible to exactly query Celery for
whether a task is in the queue or currently executing.
1. An unknown task id is always returned as state PENDING.
2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
and the time it actually starts on the worker.
"""
# if the fence doesn't exist, there's nothing to do
fence_key = key_bytes.decode("utf-8")
cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
if cc_pair_id_str is None:
task_logger.warning(
f"validate_external_group_sync_fence - could not parse id from {fence_key}"
)
return
cc_pair_id = int(cc_pair_id_str)
# parse out metadata and initialize the helper class with it
redis_connector = RedisConnector(tenant_id, int(cc_pair_id))
# check to see if the fence/payload exists
if not redis_connector.external_group_sync.fenced:
return
payload = redis_connector.external_group_sync.payload
if not payload:
return
# OK, there's actually something for us to validate
if payload.celery_task_id is None:
# the fence is just barely set up.
# if redis_connector_index.active():
# return
# it would be odd to get here as there isn't that much that can go wrong during
# initial fence setup, but it's still worth making sure we can recover
logger.info(
"validate_external_group_sync_fence - "
f"Resetting fence in basic state without any activity: fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return
found = celery_find_task(
payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
)
if found:
# the celery task exists in the redis queue
# redis_connector_index.set_active()
return
if payload.celery_task_id in reserved_tasks:
# the celery task was prefetched and is reserved within the indexing worker
# redis_connector_index.set_active()
return
# we may want to enable this check if using the active task list somehow isn't good enough
# if redis_connector_index.generator_locked():
# logger.info(f"{payload.celery_task_id} is currently executing.")
# if we get here, we didn't find any direct indication that the associated celery tasks exist,
# but they still might be there due to gaps in our ability to check states during transitions
# Checking the active signal safeguards us against these transition periods
# (which has a duration that allows us to bridge those gaps)
# if redis_connector_index.active():
# return
# celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
logger.warning(
"validate_external_group_sync_fence - "
"Resetting fence because no associated celery tasks were found: "
f"cc_pair={cc_pair_id} "
f"fence={fence_key}"
)
redis_connector.external_group_sync.reset()
return

View File

@@ -39,6 +39,7 @@ from onyx.redis.redis_pool import get_redis_client
from onyx.redis.redis_pool import redis_lock_dump
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
_MONITORING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
@@ -657,6 +658,9 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
- Syncing speed metrics
- Worker status and task counts
"""
if tenant_id is not None:
CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
task_logger.info("Starting background monitoring")
r = get_redis_client(tenant_id=tenant_id)
@@ -688,11 +692,13 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
metrics = metric_fn()
for metric in metrics:
# double check to make sure we aren't double-emitting metrics
if metric.key is not None and not _has_metric_been_emitted(
if metric.key is None or not _has_metric_been_emitted(
redis_std, metric.key
):
metric.log()
metric.emit(tenant_id)
if metric.key is not None:
_mark_metric_as_emitted(redis_std, metric.key)
task_logger.info("Successfully collected background metrics")

View File

@@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record
from onyx.db.sync_record import update_sync_record_status
from onyx.redis.redis_connector import RedisConnector
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import LoggerContextVars
from onyx.utils.logger import pruning_ctx
from onyx.utils.logger import setup_logger
@@ -251,6 +252,8 @@ def connector_pruning_generator_task(
and compares those IDs to locally stored documents and deletes all locally stored IDs missing
from the most recently pulled document ID list"""
LoggerContextVars.reset()
pruning_ctx_dict = pruning_ctx.get()
pruning_ctx_dict["cc_pair_id"] = cc_pair_id
pruning_ctx_dict["request_id"] = self.request.id
@@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset(
mark_ccpair_as_pruned(int(cc_pair_id), db_session)
task_logger.info(
f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
)
update_sync_record_status(

View File

@@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task(
"""
task_logger.debug(f"Task start: doc={document_id}")
start = time.monotonic()
try:
with get_session_with_tenant(tenant_id) as db_session:
action = "skip"
@@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task(
db_session.commit()
elapsed = time.monotonic() - start
task_logger.info(
f"doc={document_id} "
f"action={action} "
f"refcount={count} "
f"chunks={chunks_affected}"
f"chunks={chunks_affected} "
f"elapsed={elapsed:.2f}"
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")

View File

@@ -989,6 +989,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
task_logger.info(
"Soft time limit exceeded, task is being terminated gracefully."
)
return False
except Exception:
task_logger.exception("monitor_vespa_sync exceptioned.")
return False
finally:
if lock_beat.owned():
lock_beat.release()
@@ -1078,6 +1082,7 @@ def vespa_metadata_sync_task(
)
except SoftTimeLimitExceeded:
task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
return False
except Exception as ex:
if isinstance(ex, RetryError):
task_logger.warning(

View File

@@ -617,3 +617,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"
TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
# Set to true to mock LLM responses for testing purposes
MOCK_LLM_RESPONSE = (
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)

View File

@@ -300,6 +300,8 @@ class OnyxRedisLocks:
class OnyxRedisSignals:
VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"
class OnyxCeleryPriority(int, Enum):

View File

@@ -1,4 +1,6 @@
import contextvars
from concurrent.futures import as_completed
from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
@@ -68,18 +70,25 @@ class AirtableConnector(LoadConnector):
self.base_id = base_id
self.table_name_or_id = table_name_or_id
self.batch_size = batch_size
self.airtable_client: AirtableApi | None = None
self._airtable_client: AirtableApi | None = None
self.treat_all_non_attachment_fields_as_metadata = (
treat_all_non_attachment_fields_as_metadata
)
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.airtable_client = AirtableApi(credentials["airtable_access_token"])
self._airtable_client = AirtableApi(credentials["airtable_access_token"])
return None
@staticmethod
@property
def airtable_client(self) -> AirtableApi:
if not self._airtable_client:
raise AirtableClientNotSetUpError()
return self._airtable_client
def _extract_field_values(
self,
field_id: str,
field_name: str,
field_info: Any,
field_type: str,
base_id: str,
@@ -118,13 +127,33 @@ class AirtableConnector(LoadConnector):
backoff=2,
max_delay=10,
)
def get_attachment_with_retry(url: str) -> bytes | None:
attachment_response = requests.get(url)
if attachment_response.status_code == 200:
def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
try:
attachment_response = requests.get(url)
attachment_response.raise_for_status()
return attachment_response.content
return None
except requests.exceptions.HTTPError as e:
if e.response.status_code == 410:
logger.info(f"Refreshing attachment for {filename}")
# Re-fetch the record to get a fresh URL
refreshed_record = self.airtable_client.table(
base_id, table_id
).get(record_id)
for refreshed_attachment in refreshed_record["fields"][
field_name
]:
if refreshed_attachment.get("filename") == filename:
new_url = refreshed_attachment.get("url")
if new_url:
attachment_response = requests.get(new_url)
attachment_response.raise_for_status()
return attachment_response.content
attachment_content = get_attachment_with_retry(url)
logger.error(f"Failed to refresh attachment for {filename}")
raise
attachment_content = get_attachment_with_retry(url, record_id)
if attachment_content:
try:
file_ext = get_file_ext(filename)
@@ -208,6 +237,7 @@ class AirtableConnector(LoadConnector):
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
@@ -337,7 +367,7 @@ class AirtableConnector(LoadConnector):
logger.info(f"Starting to process Airtable records for {table.name}.")
# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 16
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))
# Process records in batches
@@ -347,15 +377,19 @@ class AirtableConnector(LoadConnector):
with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record = {
executor.submit(
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
): record
for record in batch_records
}
future_to_record: dict[Future, RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
future_to_record[
executor.submit(
current_context.run,
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
] = record
# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
@@ -368,10 +402,8 @@ class AirtableConnector(LoadConnector):
logger.exception(f"Failed to process record {record['id']}")
raise e
# After batch is complete, yield if we've hit the batch size
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
yield record_documents
record_documents = []
# Yield any remaining records
if record_documents:

View File

@@ -150,6 +150,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
@@ -1115,6 +1116,10 @@ class ChatSession(Base):
llm_override: Mapped[LLMOverride | None] = mapped_column(
PydanticType(LLMOverride), nullable=True
)
# The latest temperature override specified by the user
temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)
prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)

View File

@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore
def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
def get_kv_store() -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore(tenant_id=tenant_id)
return PgRedisKVStore()

View File

@@ -18,7 +18,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
@@ -28,10 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day
class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
def __init__(self, redis_client: Redis | None = None) -> None:
self.tenant_id = get_current_tenant_id()
# If no redis_client is provided, fall back to the context var
if redis_client is not None:

View File

@@ -26,6 +26,7 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue
from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
@@ -387,6 +388,7 @@ class DefaultMultiLLM(LLM):
try:
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these

View File

@@ -109,7 +109,9 @@ from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@@ -212,7 +214,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid(tenant_id=None)
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
get_or_generate_uuid()
# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:

View File

@@ -175,7 +175,6 @@ class EmbeddingModel:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")
logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -191,7 +190,15 @@ class EmbeddingModel:
api_url=self.api_url,
)
start_time = time.time()
response = self._make_model_server_request(embed_request)
end_time = time.time()
processing_time = end_time - start_time
logger.info(
f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
)
return batch_idx, response.embeddings
# only multi thread if:

View File

@@ -17,6 +17,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorPermissionSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -41,6 +43,12 @@ class RedisConnectorPermissionSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub
# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600
def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -54,6 +62,7 @@ class RedisConnectorPermissionSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"
self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"
def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -107,6 +116,20 @@ class RedisConnectorPermissionSync:
self.redis.set(self.fence_key, payload.model_dump_json())
def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.
The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)
def active(self) -> bool:
if self.redis.exists(self.active_key):
return True
return False
@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@@ -173,6 +196,7 @@ class RedisConnectorPermissionSync:
return len(async_results)
def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -187,6 +211,9 @@ class RedisConnectorPermissionSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
r.delete(key)
for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
r.delete(key)

View File

@@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
class RedisConnectorExternalGroupSyncPayload(BaseModel):
submitted: datetime
started: datetime | None
celery_task_id: str | None
@@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync:
) -> int | None:
pass
def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)
@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"

View File

@@ -33,8 +33,8 @@ class RedisConnectorIndex:
TERMINATE_TTL = 600
# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

View File

@@ -122,7 +122,7 @@ class TenantRedis(redis.Redis):
"ttl",
] # Regular methods that need simple prefixing
if item == "scan_iter":
if item == "scan_iter" or item == "sscan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)

View File

@@ -422,27 +422,29 @@ def sync_cc_pair(
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Doc permissions sync task already in progress.",
detail="Permissions sync task already in progress.",
)
logger.info(
f"Doc permissions sync cc_pair={cc_pair_id} "
f"Permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_permissions_sync_task(
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Doc permissions sync task creation failed.",
detail="Permissions sync task creation failed.",
)
logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")
return StatusResponse(
success=True,
message="Successfully created the doc permissions sync task.",
message="Successfully created the permissions sync task.",
)

View File

@@ -48,6 +48,7 @@ class UserPreferences(BaseModel):
auto_scroll: bool | None = None
pinned_assistants: list[int] | None = None
shortcut_enabled: bool | None = None
temperature_override_enabled: bool | None = None
class UserInfo(BaseModel):
@@ -91,6 +92,7 @@ class UserInfo(BaseModel):
hidden_assistants=user.hidden_assistants,
pinned_assistants=user.pinned_assistants,
visible_assistants=user.visible_assistants,
temperature_override_enabled=user.temperature_override_enabled,
)
),
organization_name=organization_name,

View File

@@ -568,6 +568,32 @@ def verify_user_logged_in(
"""APIs to adjust user preferences"""
@router.patch("/temperature-override-enabled")
def update_user_temperature_override_enabled(
temperature_override_enabled: bool,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.temperature_override_enabled = (
temperature_override_enabled
)
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")
db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(temperature_override_enabled=temperature_override_enabled)
)
db_session.commit()
class ChosenDefaultModelRequest(BaseModel):
default_model: str | None = None

View File

@@ -18,7 +18,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -78,6 +77,7 @@ from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
@@ -115,12 +115,52 @@ def get_user_chat_sessions(
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]
)
@router.put("/update-chat-session-temperature")
def update_chat_session_temperature(
update_thread_req: UpdateChatSessionTemperatureRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
chat_session = get_chat_session_by_id(
chat_session_id=update_thread_req.chat_session_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
# Validate temperature_override
if update_thread_req.temperature_override is not None:
if (
update_thread_req.temperature_override < 0
or update_thread_req.temperature_override > 2
):
raise HTTPException(
status_code=400, detail="Temperature must be between 0 and 2"
)
# Additional check for Anthropic models
if (
chat_session.current_alternate_model
and "anthropic" in chat_session.current_alternate_model.lower()
):
if update_thread_req.temperature_override > 1:
raise HTTPException(
status_code=400,
detail="Temperature for Anthropic models must be between 0 and 1",
)
chat_session.temperature_override = update_thread_req.temperature_override
db_session.add(chat_session)
db_session.commit()
@router.put("/update-chat-session-model")
def update_chat_session_model(
update_thread_req: UpdateChatSessionThreadRequest,
@@ -191,6 +231,7 @@ def get_chat_session(
],
time_created=chat_session.time_created,
shared_status=chat_session.shared_status,
current_temperature_override=chat_session.temperature_override,
)
@@ -422,7 +463,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_limited_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@@ -42,6 +42,11 @@ class UpdateChatSessionThreadRequest(BaseModel):
new_alternate_model: str
class UpdateChatSessionTemperatureRequest(BaseModel):
chat_session_id: UUID
temperature_override: float
class ChatSessionCreationRequest(BaseModel):
# If not specified, use Onyx default persona
persona_id: int = 0
@@ -108,6 +113,10 @@ class CreateChatMessageRequest(ChunkContext):
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None
# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None
# allow user to specify an alternate assistnat
alternate_assistant_id: int | None = None
@@ -168,6 +177,7 @@ class ChatSessionDetails(BaseModel):
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
current_temperature_override: float | None = None
class ChatSessionsResponse(BaseModel):
@@ -231,6 +241,7 @@ class ChatSessionDetailResponse(BaseModel):
time_created: datetime
shared_status: ChatSessionSharedStatus
current_alternate_model: str | None
current_temperature_override: float | None
# This one is not used anymore

View File

@@ -1,4 +1,6 @@
import base64
import json
import os
from datetime import datetime
from typing import Any
@@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
)
return masked_creds
def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is
targeted at the stated use case."""
return base64.b32encode(os.urandom(5)).decode("utf-8")[:8] # 5 bytes → 8 chars

View File

@@ -37,6 +37,7 @@ from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -279,6 +280,7 @@ def setup_postgres(db_session: Session) -> None:
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
@@ -292,8 +294,8 @@ def setup_postgres(db_session: Session) -> None:
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session

View File

@@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())
class LoggerContextVars:
@staticmethod
def reset() -> None:
pruning_ctx.set(dict())
doc_permission_sync_ctx.set(dict())
class TaskAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
@@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
) -> tuple[str, MutableMapping[str, Any]]:
# If this is an indexing job, add the attempt ID to the log message
# This helps filter the logs for this specific indexing
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
while True:
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
break
doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
if len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
break
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
elif len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
else:
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"
if cc_pair_id is not None:
msg = f"[CC Pair: {cc_pair_id}] {msg}"
break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:

View File

@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from enum import Enum
@@ -41,7 +42,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))
def get_or_generate_uuid(tenant_id: str | None) -> str:
def get_or_generate_uuid() -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
@@ -52,7 +53,7 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID
kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()
try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
@@ -63,18 +64,18 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
return _CACHED_UUID
def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
def _get_or_generate_instance_domain() -> str | None: #
global _CACHED_INSTANCE_DOMAIN
if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN
kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()
try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant() as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -103,7 +104,7 @@ def optional_telemetry(
customer_uuid = (
_get_or_generate_customer_id_mt(tenant_id)
if MULTI_TENANT
else get_or_generate_uuid(tenant_id)
else get_or_generate_uuid()
)
payload = {
"data": data,
@@ -115,9 +116,7 @@ def optional_telemetry(
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
payload["instance_domain"] = _get_or_generate_instance_domain()
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
@@ -128,8 +127,12 @@ def optional_telemetry(
# This way it silences all thread level logging as well
pass
# Run in separate thread to have minimal overhead in main flows
thread = threading.Thread(target=telemetry_logic, daemon=True)
# Run in separate thread with the same context as the current thread
# This is to ensure that the thread gets the current tenant ID
current_context = contextvars.copy_context()
thread = threading.Thread(
target=lambda: current_context.run(telemetry_logic), daemon=True
)
thread.start()
except Exception:
# Should never interfere with normal functions of Onyx

View File

@@ -6,3 +6,13 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)
"""Utils related to contextvars"""
def get_current_tenant_id() -> str:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id is None:
raise RuntimeError("Tenant ID is not set. This should never happen.")
return tenant_id

View File

@@ -9,6 +9,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.chat_llm import DefaultMultiLLM
@@ -143,6 +144,7 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
)
@@ -287,4 +289,5 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
)

View File

@@ -6,7 +6,7 @@ sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.2.1
appVersion: "latest"
appVersion: latest
annotations:
category: Productivity
licenses: MIT

View File

@@ -45,10 +45,10 @@ spec:
- |
alembic upgrade head &&
echo "Starting Onyx Api Server" &&
uvicorn onyx.main:app --host 0.0.0.0 --port 8080
uvicorn onyx.main:app --host 0.0.0.0 --port {{ .Values.api.containerPorts.server }}
ports:
- name: api-server-port
containerPort: {{ .Values.api.service.port }}
containerPort: {{ .Values.api.containerPorts.server }}
protocol: TCP
resources:
{{- toYaml .Values.api.resources | nindent 12 }}

View File

@@ -11,10 +11,10 @@ metadata:
spec:
type: {{ .Values.api.service.type }}
ports:
- port: {{ .Values.api.service.port }}
targetPort: api-server-port
- port: {{ .Values.api.service.servicePort }}
targetPort: {{ .Values.api.service.targetPort }}
protocol: TCP
name: api-server-port
name: {{ .Values.api.service.portName }}
selector:
{{- include "onyx-stack.selectorLabels" . | nindent 4 }}
{{- if .Values.api.deploymentLabels }}

View File

@@ -5,7 +5,7 @@ metadata:
labels:
{{- include "onyx-stack.labels" . | nindent 4 }}
spec:
replicas: 1
replicas: {{ .Values.indexCapability.replicaCount }}
selector:
matchLabels:
{{- include "onyx-stack.selectorLabels" . | nindent 6 }}
@@ -25,12 +25,14 @@ spec:
{{- end }}
spec:
containers:
- name: indexing-model-server
image: onyxdotapp/onyx-model-server:latest
imagePullPolicy: IfNotPresent
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000", "--limit-concurrency", "10" ]
- name: {{ .Values.indexCapability.name }}
image: "{{ .Values.indexCapability.image.repository }}:{{ .Values.indexCapability.image.tag | default .Chart.AppVersion }}"
imagePullPolicy: {{ .Values.indexCapability.image.pullPolicy }}
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "{{ .Values.indexCapability.containerPorts.server }}", "--limit-concurrency", "{{ .Values.indexCapability.limitConcurrency }}" ]
ports:
- containerPort: 9000
- name: model-server
containerPort: {{ .Values.indexCapability.containerPorts.server }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.config.envConfigMapName }}

View File

@@ -3,8 +3,9 @@ kind: PersistentVolumeClaim
metadata:
name: {{ .Values.indexCapability.indexingModelPVC.name }}
spec:
storageClassName: {{ .Values.persistent.storageClassName }}
accessModes:
- {{ .Values.indexCapability.indexingModelPVC.accessMode | quote }}
resources:
requests:
storage: {{ .Values.indexCapability.indexingModelPVC.storage | quote }}
storage: {{ .Values.indexCapability.indexingModelPVC.storage | quote }}

View File

@@ -11,8 +11,8 @@ spec:
{{- toYaml .Values.indexCapability.deploymentLabels | nindent 4 }}
{{- end }}
ports:
- name: {{ .Values.indexCapability.service.name }}
- name: {{ .Values.indexCapability.service.portName }}
protocol: TCP
port: {{ .Values.indexCapability.service.port }}
targetPort: {{ .Values.indexCapability.service.port }}
type: {{ .Values.indexCapability.service.type }}
port: {{ .Values.indexCapability.service.servicePort }}
targetPort: {{ .Values.indexCapability.service.targetPort }}
type: {{ .Values.indexCapability.service.type }}

View File

@@ -3,14 +3,14 @@ kind: Deployment
metadata:
name: {{ include "onyx-stack.fullname" . }}-inference-model
labels:
{{- range .Values.inferenceCapability.deployment.labels }}
{{- range .Values.inferenceCapability.labels }}
{{ .key }}: {{ .value }}
{{- end }}
spec:
replicas: {{ .Values.inferenceCapability.deployment.replicas }}
replicas: {{ .Values.inferenceCapability.replicaCount }}
selector:
matchLabels:
{{- range .Values.inferenceCapability.deployment.labels }}
{{- range .Values.inferenceCapability.labels }}
{{ .key }}: {{ .value }}
{{- end }}
template:
@@ -21,24 +21,26 @@ spec:
{{- end }}
spec:
containers:
- name: {{ .Values.inferenceCapability.service.name }}
image: {{ .Values.inferenceCapability.deployment.image.repository }}:{{ .Values.inferenceCapability.deployment.image.tag }}
imagePullPolicy: {{ .Values.inferenceCapability.deployment.image.pullPolicy }}
command: {{ toYaml .Values.inferenceCapability.deployment.command | nindent 14 }}
- name: model-server-inference
image: "{{ .Values.inferenceCapability.image.repository }}:{{ .Values.inferenceCapability.image.tag | default .Chart.AppVersion }}"
imagePullPolicy: {{ .Values.inferenceCapability.image.pullPolicy }}
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "{{ .Values.inferenceCapability.containerPorts.server }}" ]
ports:
- containerPort: {{ .Values.inferenceCapability.service.port }}
- name: model-server
containerPort: {{ .Values.inferenceCapability.containerPorts.server }}
protocol: TCP
envFrom:
- configMapRef:
name: {{ .Values.config.envConfigMapName }}
env:
{{- include "onyx-stack.envSecrets" . | nindent 12}}
volumeMounts:
{{- range .Values.inferenceCapability.deployment.volumeMounts }}
{{- range .Values.inferenceCapability.volumeMounts }}
- name: {{ .name }}
mountPath: {{ .mountPath }}
{{- end }}
volumes:
{{- range .Values.inferenceCapability.deployment.volumes }}
{{- range .Values.inferenceCapability.volumes }}
- name: {{ .name }}
persistentVolumeClaim:
claimName: {{ .persistentVolumeClaim.claimName }}

View File

@@ -3,6 +3,7 @@ kind: PersistentVolumeClaim
metadata:
name: {{ .Values.inferenceCapability.pvc.name }}
spec:
storageClassName: {{ .Values.persistent.storageClassName }}
accessModes:
{{- toYaml .Values.inferenceCapability.pvc.accessModes | nindent 4 }}
resources:

View File

@@ -5,11 +5,11 @@ metadata:
spec:
type: {{ .Values.inferenceCapability.service.type }}
ports:
- port: {{ .Values.inferenceCapability.service.port }}
targetPort: {{ .Values.inferenceCapability.service.port }}
- port: {{ .Values.inferenceCapability.service.servicePort}}
targetPort: {{ .Values.inferenceCapability.service.targetPort }}
protocol: TCP
name: {{ .Values.inferenceCapability.service.name }}
name: {{ .Values.inferenceCapability.service.portName }}
selector:
{{- range .Values.inferenceCapability.deployment.labels }}
{{- range .Values.inferenceCapability.labels }}
{{ .key }}: {{ .value }}
{{- end }}

View File

@@ -5,11 +5,11 @@ metadata:
data:
nginx.conf: |
upstream api_server {
server {{ include "onyx-stack.fullname" . }}-api-service:{{ .Values.api.service.port }} fail_timeout=0;
server {{ include "onyx-stack.fullname" . }}-api-service:{{ .Values.api.service.servicePort }} fail_timeout=0;
}
upstream web_server {
server {{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }} fail_timeout=0;
server {{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.servicePort }} fail_timeout=0;
}
server {

View File

@@ -11,5 +11,5 @@ spec:
- name: wget
image: busybox
command: ['wget']
args: ['{{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }}']
args: ['{{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.servicePort }}']
restartPolicy: Never

View File

@@ -41,7 +41,7 @@ spec:
imagePullPolicy: {{ .Values.webserver.image.pullPolicy }}
ports:
- name: http
containerPort: {{ .Values.webserver.service.port }}
containerPort: {{ .Values.webserver.containerPorts.server }}
protocol: TCP
resources:
{{- toYaml .Values.webserver.resources | nindent 12 }}

View File

@@ -10,8 +10,8 @@ metadata:
spec:
type: {{ .Values.webserver.service.type }}
ports:
- port: {{ .Values.webserver.service.port }}
targetPort: http
- port: {{ .Values.webserver.service.servicePort }}
targetPort: {{ .Values.webserver.service.targetPort }}
protocol: TCP
name: http
selector:

View File

@@ -2,62 +2,73 @@
# This is a YAML-formatted file.
# Declare variables to be passed into your templates.
postgresql:
primary:
persistence:
size: 5Gi
enabled: true
auth:
existingSecret: onyx-secrets
secretKeys:
# overwriting as postgres typically expects 'postgres-password'
adminPasswordKey: postgres_password
imagePullSecrets: []
nameOverride: ""
fullnameOverride: ""
persistent:
storageClassName: ""
inferenceCapability:
service:
name: inference-model-server-service
portName: modelserver
type: ClusterIP
port: 9000
servicePort: 9000
targetPort: 9000
pvc:
name: inference-model-pvc
accessModes:
- ReadWriteOnce
storage: 3Gi
deployment:
name: inference-model-server-deployment
replicas: 1
labels:
- key: app
value: inference-model-server
image:
repository: onyxdotapp/onyx-model-server
tag: latest
pullPolicy: IfNotPresent
command:
[
"uvicorn",
"model_server.main:app",
"--host",
"0.0.0.0",
"--port",
"9000",
]
port: 9000
volumeMounts:
- name: inference-model-storage
mountPath: /root/.cache
volumes:
- name: inference-model-storage
persistentVolumeClaim:
claimName: inference-model-pvc
name: inference-model-server
replicaCount: 1
labels:
- key: app
value: inference-model-server
image:
repository: onyxdotapp/onyx-model-server
# Overrides the image tag whose default is the chart appVersion.
tag: ""
pullPolicy: IfNotPresent
containerPorts:
server: 9000
volumeMounts:
- name: inference-model-storage
mountPath: /root/.cache
volumes:
- name: inference-model-storage
persistentVolumeClaim:
claimName: inference-model-pvc
podLabels:
- key: app
value: inference-model-server
indexCapability:
service:
portName: modelserver
type: ClusterIP
port: 9000
name: indexing-model-server-port
servicePort: 9000
targetPort: 9000
replicaCount: 1
name: indexing-model-server
deploymentLabels:
app: indexing-model-server
podLabels:
app: indexing-model-server
indexingOnly: "True"
podAnnotations: {}
containerPorts:
server: 9000
volumeMounts:
- name: indexing-model-storage
mountPath: /root/.cache
@@ -69,7 +80,12 @@ indexCapability:
name: indexing-model-storage
accessMode: "ReadWriteOnce"
storage: "3Gi"
image:
repository: onyxdotapp/onyx-model-server
# Overrides the image tag whose default is the chart appVersion.
tag: ""
pullPolicy: IfNotPresent
limitConcurrency: 10
config:
envConfigMapName: env-configmap
@@ -84,16 +100,6 @@ serviceAccount:
# If not set and create is true, a name is generated using the fullname template
name: ""
postgresql:
primary:
persistence:
size: 5Gi
enabled: true
auth:
existingSecret: onyx-secrets
secretKeys:
adminPasswordKey: postgres_password # overwriting as postgres typically expects 'postgres-password'
nginx:
containerPorts:
http: 1024
@@ -135,9 +141,13 @@ webserver:
# runAsNonRoot: true
# runAsUser: 1000
containerPorts:
server: 3000
service:
type: ClusterIP
port: 3000
servicePort: 3000
targetPort: http
resources: {}
# We usually recommend not to specify default resources and to leave this as a conscious
@@ -156,7 +166,7 @@ webserver:
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
targetMemoryUtilizationPercentage: 80
# Additional volumes on the output Deployment definition.
volumes: []
@@ -189,6 +199,9 @@ api:
scope: onyx-backend
app: api-server
containerPorts:
server: 8080
podSecurityContext:
{}
# fsGroup: 2000
@@ -204,7 +217,9 @@ api:
service:
type: ClusterIP
port: 8080
servicePort: 8080
targetPort: api-server-port
portName: api-server-port
resources: {}
# We usually recommend not to specify default resources and to leave this as a conscious
@@ -223,7 +238,7 @@ api:
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
targetMemoryUtilizationPercentage: 80
# Additional volumes on the output Deployment definition.
volumes: []
@@ -247,7 +262,7 @@ background:
repository: onyxdotapp/onyx-backend
pullPolicy: IfNotPresent
# Overrides the image tag whose default is the chart appVersion.
tag: latest
tag: ""
podAnnotations: {}
podLabels:
scope: onyx-backend
@@ -284,7 +299,7 @@ background:
minReplicas: 1
maxReplicas: 100
targetCPUUtilizationPercentage: 80
# targetMemoryUtilizationPercentage: 80
targetMemoryUtilizationPercentage: 80
# Additional volumes on the output Deployment definition.
volumes: []
@@ -303,6 +318,16 @@ background:
tolerations: []
vespa:
volumeClaimTemplates:
- metadata:
name: vespa-storage
spec:
accessModes:
- ReadWriteOnce
storageClassName: ""
resources:
requests:
storage: 1Gi
enabled: true
replicaCount: 1
image:
@@ -377,19 +402,11 @@ redis:
# # hosts:
# # - chart-example.local
persistence:
vespa:
enabled: true
existingClaim: ""
storageClassName: ""
accessModes:
- ReadWriteOnce
size: 5Gi
auth:
# for storing smtp, oauth, slack, and other secrets
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets
# keys are lowercased version of env vars (e.g. SMTP_USER -> smtp_user)
existingSecret: "" # onyx-secrets
existingSecret: ""
# optionally override the secret keys to reference in the secret
# this is used to populate the env vars in individual deployments
# the values here reference the keys in secrets below
@@ -413,14 +430,22 @@ auth:
redis_password: "password"
configMap:
AUTH_TYPE: "disabled" # Change this for production uses unless Onyx is only accessible behind VPN
SESSION_EXPIRE_TIME_SECONDS: "86400" # 1 Day Default
VALID_EMAIL_DOMAINS: "" # Can be something like onyx.app, as an extra double-check
SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587'
SMTP_USER: "" # 'your-email@company.com'
# SMTP_PASS: "" # 'your-gmail-password'
EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
# Change this for production uses unless Onyx is only accessible behind VPN
AUTH_TYPE: "disabled"
# 1 Day Default
SESSION_EXPIRE_TIME_SECONDS: "86400"
# Can be something like onyx.app, as an extra double-check
VALID_EMAIL_DOMAINS: ""
# For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
SMTP_SERVER: ""
# For sending verification emails, if unspecified then defaults to '587'
SMTP_PORT: ""
# 'your-email@company.com'
SMTP_USER: ""
# 'your-gmail-password'
# SMTP_PASS: ""
# 'your-email@company.com' SMTP_USER missing used instead
EMAIL_FROM: ""
# Gen AI Settings
GEN_AI_MAX_TOKENS: ""
QA_TIMEOUT: "60"
@@ -462,7 +487,7 @@ configMap:
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
DANSWER_BOT_DISPLAY_ERROR_MSGS: ""
DANSWER_BOT_RESPOND_EVERY_CHANNEL: ""
DANSWER_BOT_DISABLE_COT: "" # Currently unused
DANSWER_BOT_DISABLE_COT: ""
NOTIFY_SLACKBOT_NO_ANSWER: ""
# Logging
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
@@ -473,7 +498,8 @@ configMap:
LOG_DANSWER_MODEL_INTERACTIONS: ""
LOG_VESPA_TIMING_INFORMATION: ""
# Shared or Non-backend Related
WEB_DOMAIN: "http://localhost:3000" # for web server and api server
DOMAIN: "localhost" # for nginx
WEB_DOMAIN: "http://localhost:3000"
# DOMAIN used by nginx
DOMAIN: "localhost"
# Chat Configs
HARD_DELETE_CHATS: ""

137
web/package-lock.json generated
View File

@@ -25,6 +25,7 @@
"@radix-ui/react-scroll-area": "^1.2.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slider": "^1.2.2",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.1",
@@ -4963,6 +4964,142 @@
}
}
},
"node_modules/@radix-ui/react-slider": {
"version": "1.2.2",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.2.2.tgz",
"integrity": "sha512-sNlU06ii1/ZcbHf8I9En54ZPW0Vil/yPVg4vQMcFNjrIx51jsHbFl1HYHQvCIWJSr1q0ZmA+iIs/ZTv8h7HHSA==",
"license": "MIT",
"dependencies": {
"@radix-ui/number": "1.1.0",
"@radix-ui/primitive": "1.1.1",
"@radix-ui/react-collection": "1.1.1",
"@radix-ui/react-compose-refs": "1.1.1",
"@radix-ui/react-context": "1.1.1",
"@radix-ui/react-direction": "1.1.0",
"@radix-ui/react-primitive": "2.0.1",
"@radix-ui/react-use-controllable-state": "1.1.0",
"@radix-ui/react-use-layout-effect": "1.1.0",
"@radix-ui/react-use-previous": "1.1.0",
"@radix-ui/react-use-size": "1.1.0"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/primitive": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.1.tgz",
"integrity": "sha512-SJ31y+Q/zAyShtXJc8x83i9TYdbAfHZ++tUZnvjJJqFjzsdUnKsxPL6IEtBlxKkU7yzer//GQtZSV4GbldL3YA==",
"license": "MIT"
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-collection": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.1.tgz",
"integrity": "sha512-LwT3pSho9Dljg+wY2KN2mrrh6y3qELfftINERIzBUO9e0N+t0oMTyn3k9iv+ZqgrwGkRnLpNJrsMv9BZlt2yuA==",
"license": "MIT",
"dependencies": {
"@radix-ui/react-compose-refs": "1.1.1",
"@radix-ui/react-context": "1.1.1",
"@radix-ui/react-primitive": "2.0.1",
"@radix-ui/react-slot": "1.1.1"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-compose-refs": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.1.tgz",
"integrity": "sha512-Y9VzoRDSJtgFMUCoiZBDVo084VQ5hfpXxVE+NgkdNsjiDBByiImMZKKhxMwCbdHvhlENG6a833CbFkOQvTricw==",
"license": "MIT",
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-context": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz",
"integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==",
"license": "MIT",
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-primitive": {
"version": "2.0.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.0.1.tgz",
"integrity": "sha512-sHCWTtxwNn3L3fH8qAfnF3WbUZycW93SM1j3NFDzXBiz8D6F5UTTy8G1+WFEaiCdvCVRJWj6N2R4Xq6HdiHmDg==",
"license": "MIT",
"dependencies": {
"@radix-ui/react-slot": "1.1.1"
},
"peerDependencies": {
"@types/react": "*",
"@types/react-dom": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
},
"@types/react-dom": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-slot": {
"version": "1.1.1",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.1.tgz",
"integrity": "sha512-RApLLOcINYJA+dMVbOju7MYv1Mb2EBp2nH4HdDzXTSyaR5optlm6Otrz1euW3HbdOR8UmmFK06TD+A9frYWv+g==",
"license": "MIT",
"dependencies": {
"@radix-ui/react-compose-refs": "1.1.1"
},
"peerDependencies": {
"@types/react": "*",
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
},
"peerDependenciesMeta": {
"@types/react": {
"optional": true
}
}
},
"node_modules/@radix-ui/react-slot": {
"version": "1.1.0",
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz",

View File

@@ -28,6 +28,7 @@
"@radix-ui/react-scroll-area": "^1.2.2",
"@radix-ui/react-select": "^2.1.2",
"@radix-ui/react-separator": "^1.1.0",
"@radix-ui/react-slider": "^1.2.2",
"@radix-ui/react-slot": "^1.1.0",
"@radix-ui/react-switch": "^1.1.1",
"@radix-ui/react-tabs": "^1.1.1",

View File

@@ -2,12 +2,13 @@ import { defineConfig, devices } from "@playwright/test";
export default defineConfig({
globalSetup: require.resolve("./tests/e2e/global-setup"),
timeout: 30000, // 30 seconds timeout
projects: [
{
name: "admin",
use: {
...devices["Desktop Chrome"],
viewport: { width: 1280, height: 720 },
storageState: "admin_auth.json",
},
testIgnore: ["**/codeUtils.test.ts"],

View File

@@ -720,7 +720,6 @@ export function AssistantEditor({
name="description"
label="Description"
placeholder="Use this Assistant to help draft professional emails"
data-testid="assistant-description-input"
className="[&_input]:placeholder:text-text-muted/50"
/>

View File

@@ -4,7 +4,7 @@ import { OnyxIcon } from "@/components/icons/icons";
export function ChatIntro({ selectedPersona }: { selectedPersona: Persona }) {
return (
<div className="flex flex-col items-center gap-6">
<div data-testid="chat-intro" className="flex flex-col items-center gap-6">
<div className="relative flex flex-col gap-y-4 w-fit mx-auto justify-center">
<div className="absolute z-10 items-center flex -left-12 top-1/2 -translate-y-1/2">
<AssistantIcon size={36} assistant={selectedPersona} />

View File

@@ -297,6 +297,7 @@ export function ChatPage({
// 2. Selected assistant (assistnat default in this chat session)
// 3. First pinned assistants (ordered list of pinned assistants)
// 4. Available assistants (ordered list of available assistants)
// Relevant test: `live_assistant.spec.ts`
const liveAssistant: Persona | undefined = useMemo(
() =>
alternativeAssistant ||
@@ -403,9 +404,6 @@ export function ChatPage({
filterManager.setSelectedTags([]);
filterManager.setTimeRange(null);
// reset LLM overrides (based on chat session!)
llmOverrideManager.updateTemperature(null);
// remove uploaded files
setCurrentMessageFiles([]);
@@ -448,6 +446,7 @@ export function ChatPage({
);
const chatSession = (await response.json()) as BackendChatSession;
setSelectedAssistantFromId(chatSession.persona_id);
const newMessageMap = processRawChatHistory(chatSession.messages);

View File

@@ -478,6 +478,7 @@ export function ChatInputBar({
onKeyDownCapture={handleKeyDown}
onChange={handleInputChange}
ref={textAreaRef}
id="onyx-chat-input-textarea"
className={`
m-0
w-full
@@ -703,6 +704,7 @@ export function ChatInputBar({
</div>
<div className="flex my-auto">
<button
id="onyx-chat-input-send-button"
className={`cursor-pointer ${
chatState == "streaming" ||
chatState == "toolBuilding" ||

View File

@@ -1,4 +1,4 @@
import React, { useState } from "react";
import React, { useState, useEffect } from "react";
import {
Popover,
PopoverContent,
@@ -26,6 +26,9 @@ import {
} from "@/components/ui/tooltip";
import { FiAlertTriangle } from "react-icons/fi";
import { Slider } from "@/components/ui/slider";
import { useUser } from "@/components/user/UserProvider";
interface LLMPopoverProps {
llmProviders: LLMProviderDescriptor[];
llmOverrideManager: LlmOverrideManager;
@@ -40,6 +43,7 @@ export default function LLMPopover({
currentAssistant,
}: LLMPopoverProps) {
const [isOpen, setIsOpen] = useState(false);
const { user } = useUser();
const { llmOverride, updateLLMOverride } = llmOverrideManager;
const currentLlm = llmOverride.modelName;
@@ -88,10 +92,29 @@ export default function LLMPopover({
? getDisplayNameForModel(defaultModelName)
: null;
const [localTemperature, setLocalTemperature] = useState(
llmOverrideManager.temperature ?? 0.5
);
useEffect(() => {
setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
}, [llmOverrideManager.temperature]);
const handleTemperatureChange = (value: number[]) => {
setLocalTemperature(value[0]);
};
const handleTemperatureChangeComplete = (value: number[]) => {
llmOverrideManager.updateTemperature(value[0]);
};
return (
<Popover open={isOpen} onOpenChange={setIsOpen}>
<PopoverTrigger asChild>
<button className="focus:outline-none">
<button
className="focus:outline-none"
data-testid="llm-popover-trigger"
>
<ChatInputOption
minimize
toggle
@@ -115,9 +138,9 @@ export default function LLMPopover({
</PopoverTrigger>
<PopoverContent
align="start"
className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg"
className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg flex flex-col"
>
<div className="max-h-[300px] overflow-y-auto">
<div className="flex-grow max-h-[300px] default-scrollbar overflow-y-auto">
{llmOptions.map(({ name, icon, value }, index) => {
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
return (
@@ -168,6 +191,25 @@ export default function LLMPopover({
return null;
})}
</div>
{user?.preferences?.temperature_override_enabled && (
<div className="mt-2 pt-2 border-t border-gray-200">
<div className="w-full px-3 py-2">
<Slider
value={[localTemperature]}
max={llmOverrideManager.maxTemperature}
min={0}
step={0.01}
onValueChange={handleTemperatureChange}
onValueCommit={handleTemperatureChangeComplete}
className="w-full"
/>
<div className="flex justify-between text-xs text-gray-500 mt-2">
<span>Temperature (creativity)</span>
<span>{localTemperature.toFixed(1)}</span>
</div>
</div>
</div>
)}
</PopoverContent>
</Popover>
);

View File

@@ -68,6 +68,7 @@ export interface ChatSession {
shared_status: ChatSessionSharedStatus;
folder_id: number | null;
current_alternate_model: string;
current_temperature_override: number | null;
}
export interface SearchSession {
@@ -107,6 +108,7 @@ export interface BackendChatSession {
messages: BackendMessage[];
time_created: string;
shared_status: ChatSessionSharedStatus;
current_temperature_override: number | null;
current_alternate_model?: string;
}

View File

@@ -75,6 +75,23 @@ export async function updateModelOverrideForChatSession(
return response;
}
export async function updateTemperatureOverrideForChatSession(
chatSessionId: string,
newTemperature: number
) {
const response = await fetch("/api/chat/update-chat-session-temperature", {
method: "PUT",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
chat_session_id: chatSessionId,
temperature_override: newTemperature,
}),
});
return response;
}
export async function createChatSession(
personaId: number,
description: string | null

View File

@@ -402,7 +402,7 @@ export const AIMessage = ({
return (
<div
id="onyx-ai-message"
id={isComplete ? "onyx-ai-message" : undefined}
ref={trackedElementRef}
className={`py-5 ml-4 lg:px-5 relative flex `}
>

View File

@@ -30,8 +30,13 @@ export function UserSettingsModal({
defaultModel: string | null;
}) {
const { inputPrompts, refreshInputPrompts } = useChatContext();
const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } =
useUser();
const {
refreshUser,
user,
updateUserAutoScroll,
updateUserShortcuts,
updateUserTemperatureOverrideEnabled,
} = useUser();
const containerRef = useRef<HTMLDivElement>(null);
const messageRef = useRef<HTMLDivElement>(null);
@@ -179,6 +184,16 @@ export function UserSettingsModal({
/>
<Label className="text-sm">Enable Prompt Shortcuts</Label>
</div>
<div className="flex items-center gap-x-2">
<Switch
size="sm"
checked={user?.preferences?.temperature_override_enabled}
onCheckedChange={(checked) => {
updateUserTemperatureOverrideEnabled(checked);
}}
/>
<Label className="text-sm">Enable Temperature Override</Label>
</div>
</div>
<Separator />

View File

@@ -129,6 +129,7 @@ const SortableAssistant: React.FC<SortableAssistantProps> = ({
className="w-3 ml-[2px] mr-[2px] group-hover:visible invisible flex-none cursor-grab"
/>
<button
data-testid={`assistant-[${assistant.id}]`}
onClick={(e) => {
e.preventDefault();
if (!isDragging) {

View File

@@ -103,7 +103,7 @@ export function Modal({
</button>
</div>
)}
<div className="flex-shrink-0">
<div className="items-start flex-shrink-0">
{title && (
<>
<div className="flex">

View File

@@ -133,6 +133,7 @@ export function UserDropdown({
onOpenChange={onOpenChange}
content={
<div
id="onyx-user-dropdown"
onClick={() => setUserInfoVisible(!userInfoVisible)}
className="flex relative cursor-pointer"
>

View File

@@ -59,7 +59,23 @@ export const Popup: React.FC<PopupSpec> = ({ message, type }) => (
/>
</svg>
)}
<span className="font-medium">{message}</span>
<div className="flex flex-col items-center space-x-2">
<span className="font-medium">{message}</span>
{type === "error" && (
<span className="text-xs text-red-100">
Need help?{" "}
<a
href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
target="_blank"
rel="noopener noreferrer"
className="underline hover:text-red-900"
>
Join our community
</a>{" "}
for support.
</span>
)}
</div>
</div>
);

View File

@@ -0,0 +1,28 @@
"use client";
import * as React from "react";
import * as SliderPrimitive from "@radix-ui/react-slider";
import { cn } from "@/lib/utils";
const Slider = React.forwardRef<
React.ElementRef<typeof SliderPrimitive.Root>,
React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
<SliderPrimitive.Root
ref={ref}
className={cn(
"relative flex w-full touch-none select-none items-center",
className
)}
{...props}
>
<SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-neutral-100 dark:bg-neutral-800">
<SliderPrimitive.Range className="absolute h-full bg-neutral-900 dark:bg-neutral-50" />
</SliderPrimitive.Track>
<SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-neutral-900 bg-white ring-offset-white transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset disabled:pointer-events-none disabled:opacity-50 dark:border-neutral-50 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:focus-visible:ring-neutral-300" />
</SliderPrimitive.Root>
));
Slider.displayName = SliderPrimitive.Root.displayName;
export { Slider };

View File

@@ -18,6 +18,7 @@ interface UserContextType {
assistantId: number,
isPinned: boolean
) => Promise<boolean>;
updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise<void>;
}
const UserContext = createContext<UserContextType | undefined>(undefined);
@@ -57,6 +58,41 @@ export function UserProvider({
console.error("Error fetching current user:", error);
}
};
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
try {
setUpToDateUser((prevUser) => {
if (prevUser) {
return {
...prevUser,
preferences: {
...prevUser.preferences,
temperature_override_enabled: enabled,
},
};
}
return prevUser;
});
const response = await fetch(
`/api/temperature-override-enabled?temperature_override_enabled=${enabled}`,
{
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
}
);
if (!response.ok) {
await refreshUser();
throw new Error("Failed to update user temperature override setting");
}
} catch (error) {
console.error("Error updating user temperature override setting:", error);
throw error;
}
};
const updateUserShortcuts = async (enabled: boolean) => {
try {
setUpToDateUser((prevUser) => {
@@ -184,6 +220,7 @@ export function UserProvider({
refreshUser,
updateUserAutoScroll,
updateUserShortcuts,
updateUserTemperatureOverrideEnabled,
toggleAssistantPinnedStatus,
isAdmin: upToDateUser?.role === UserRole.ADMIN,
// Curator status applies for either global or basic curator

View File

@@ -10,7 +10,7 @@ import {
} from "@/lib/types";
import useSWR, { mutate, useSWRConfig } from "swr";
import { errorHandlingFetcher } from "./fetcher";
import { useContext, useEffect, useState } from "react";
import { useContext, useEffect, useMemo, useState } from "react";
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
import { Filters, SourceMetadata } from "./search/interfaces";
import {
@@ -28,6 +28,8 @@ import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
import { updateTemperatureOverrideForChatSession } from "@/app/chat/lib";
const CREDENTIAL_URL = "/api/manage/admin/credential";
@@ -360,14 +362,22 @@ export interface LlmOverride {
export interface LlmOverrideManager {
llmOverride: LlmOverride;
updateLLMOverride: (newOverride: LlmOverride) => void;
temperature: number | null;
updateTemperature: (temperature: number | null) => void;
temperature: number;
updateTemperature: (temperature: number) => void;
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
imageFilesPresent: boolean;
updateImageFilesPresent: (present: boolean) => void;
liveAssistant: Persona | null;
maxTemperature: number;
}
// Things to test
// 1. User override
// 2. User preference (defaults to system wide default if no preference set)
// 3. Current assistant
// 4. Current chat session
// 5. Live assistant
/*
LLM Override is as follows (i.e. this order)
- User override (explicitly set in the chat input bar)
@@ -386,6 +396,20 @@ Changes take place as
- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)
If we have a live assistant, we should use that model override
Relevant test: `llm_ordering.spec.ts`.
Temperature override is set as follows:
- For existing chat sessions:
- If the user has previously overridden the temperature for a specific chat session,
that value is persisted and used when the user returns to that chat.
- This persistence applies even if the temperature was set before sending the first message in the chat.
- For new chat sessions:
- If the search tool is available, the default temperature is set to 0.
- If the search tool is not available, the default temperature is set to 0.5.
This approach ensures that user preferences are maintained for existing chats while
providing appropriate defaults for new conversations based on the available tools.
*/
export function useLlmOverride(
@@ -398,11 +422,6 @@ export function useLlmOverride(
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
const llmOverrideUpdate = () => {
if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
return;
}
if (liveAssistant?.llm_model_version_override) {
setLlmOverride(
getValidLlmOverride(liveAssistant.llm_model_version_override)
@@ -490,24 +509,68 @@ export function useLlmOverride(
}
};
const [temperature, setTemperature] = useState<number | null>(0);
useEffect(() => {
const [temperature, setTemperature] = useState<number>(() => {
llmOverrideUpdate();
}, [liveAssistant, currentChatSession]);
if (currentChatSession?.current_temperature_override != null) {
return Math.min(
currentChatSession.current_temperature_override,
isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
);
} else if (
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
) {
return 0;
}
return 0.5;
});
const maxTemperature = useMemo(() => {
return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
}, [llmOverride]);
useEffect(() => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
const newTemperature = Math.min(temperature, 1.0);
setTemperature(newTemperature);
if (chatSession?.id) {
updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
}
}
}, [llmOverride]);
const updateTemperature = (temperature: number | null) => {
useEffect(() => {
if (!chatSession && currentChatSession) {
setChatSession(currentChatSession || null);
if (temperature) {
updateTemperatureOverrideForChatSession(
currentChatSession.id,
temperature
);
}
return;
}
if (currentChatSession?.current_temperature_override) {
setTemperature(currentChatSession.current_temperature_override);
} else if (
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
) {
setTemperature(0);
} else {
setTemperature(0.5);
}
}, [liveAssistant, currentChatSession]);
const updateTemperature = (temperature: number) => {
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0));
setTemperature((prevTemp) => Math.min(temperature, 1.0));
} else {
setTemperature(temperature);
}
if (chatSession) {
updateTemperatureOverrideForChatSession(chatSession.id, temperature);
}
};
return {
@@ -519,6 +582,7 @@ export function useLlmOverride(
imageFilesPresent,
updateImageFilesPresent,
liveAssistant: liveAssistant ?? null,
maxTemperature,
};
}

View File

@@ -12,6 +12,7 @@ interface UserPreferences {
recent_assistants: number[];
auto_scroll: boolean | null;
shortcut_enabled: boolean;
temperature_override_enabled: boolean;
}
export enum UserRole {

View File

@@ -1,54 +0,0 @@
import { test, expect } from "@playwright/test";
// Use pre-signed in "admin" storage state
test.use({
storageState: "admin_auth.json",
});
test("Chat workflow", async ({ page }) => {
// Initial setup
await page.goto("http://localhost:3000/chat", { timeout: 3000 });
// Interact with Art assistant
await page.locator("button").filter({ hasText: "Art" }).click();
await page.getByPlaceholder("Message Art assistant...").fill("Hi");
await page.keyboard.press("Enter");
await page.waitForTimeout(3000);
// Start a new chat
await page.getByRole("link", { name: "Start New Chat" }).click();
await page.waitForNavigation({ waitUntil: "networkidle" });
// Check for expected text
await expect(page.getByText("Assistant for generating")).toBeVisible();
// Interact with General assistant
await page.locator("button").filter({ hasText: "General" }).click();
// Check URL after clicking General assistant
await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1", {
timeout: 5000,
});
// Create a new assistant
await page.getByRole("button", { name: "Explore Assistants" }).click();
await page.getByRole("button", { name: "Create" }).click();
await page.getByTestId("name").click();
await page.getByTestId("name").fill("Test Assistant");
await page.getByTestId("description").click();
await page.getByTestId("description").fill("Test Assistant Description");
await page.getByTestId("system_prompt").click();
await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Verify new assistant creation
await expect(page.getByText("Test Assistant Description")).toBeVisible({
timeout: 5000,
});
// Start another new chat
await page.getByRole("link", { name: "Start New Chat" }).click();
await expect(page.getByText("Assistant with access to")).toBeVisible({
timeout: 5000,
});
});

View File

@@ -0,0 +1,56 @@
import { test, expect } from "@playwright/test";
import { dragElementAbove, dragElementBelow } from "../utils/dragUtils";
import { loginAsRandomUser } from "../utils/auth";
test("Assistant Drag and Drop", async ({ page }) => {
test.fail();
await page.context().clearCookies();
await loginAsRandomUser(page);
// Navigate to the chat page
await page.goto("http://localhost:3000/chat");
// Helper function to get the current order of assistants
const getAssistantOrder = async () => {
const assistants = await page.$$('[data-testid^="assistant-["]');
return Promise.all(
assistants.map(async (assistant) => {
const nameElement = await assistant.$("p");
return nameElement ? nameElement.textContent() : "";
})
);
};
// Get the initial order
const initialOrder = await getAssistantOrder();
// Drag second assistant above first
const secondAssistant = page.locator('[data-testid^="assistant-["]').nth(1);
const firstAssistant = page.locator('[data-testid^="assistant-["]').nth(0);
await dragElementAbove(secondAssistant, firstAssistant, page);
// Check new order
const orderAfterDragUp = await getAssistantOrder();
expect(orderAfterDragUp[0]).toBe(initialOrder[1]);
expect(orderAfterDragUp[1]).toBe(initialOrder[0]);
// Drag last assistant to second position
const assistants = page.locator('[data-testid^="assistant-["]');
const lastIndex = (await assistants.count()) - 1;
const lastAssistant = assistants.nth(lastIndex);
const secondPosition = assistants.nth(1);
await page.waitForTimeout(3000);
await dragElementBelow(lastAssistant, secondPosition, page);
// Check new order
const orderAfterDragDown = await getAssistantOrder();
expect(orderAfterDragDown[1]).toBe(initialOrder[lastIndex]);
// Refresh and verify order
await page.reload();
const orderAfterRefresh = await getAssistantOrder();
expect(orderAfterRefresh).toEqual(orderAfterDragDown);
});

View File

@@ -0,0 +1,72 @@
import { test, expect } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";
import {
navigateToAssistantInHistorySidebar,
sendMessage,
startNewChat,
switchModel,
} from "../utils/chatActions";
test("Chat workflow", async ({ page }) => {
test.fail();
// Clear cookies and log in as a random user
await page.context().clearCookies();
await loginAsRandomUser(page);
// Navigate to the chat page
await page.goto("http://localhost:3000/chat");
// Test interaction with the Art assistant
await navigateToAssistantInHistorySidebar(
page,
"[-3]",
"Assistant for generating"
);
await sendMessage(page, "Hi");
// Start a new chat session
await startNewChat(page);
// Verify the presence of the expected text
await expect(page.getByText("Assistant for generating")).toBeVisible();
// Test interaction with the General assistant
await navigateToAssistantInHistorySidebar(
page,
"[-1]",
"Assistant with no search"
);
// Verify the URL after selecting the General assistant
await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1");
// Test creation of a new assistant
await page.getByRole("button", { name: "Explore Assistants" }).click();
await page.getByRole("button", { name: "Create" }).click();
await page.getByTestId("name").click();
await page.getByTestId("name").fill("Test Assistant");
await page.getByTestId("description").click();
await page.getByTestId("description").fill("Test Assistant Description");
await page.getByTestId("system_prompt").click();
await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
await page.getByRole("button", { name: "Create" }).click();
// Verify the successful creation of the new assistant
await expect(page.getByText("Test Assistant Description")).toBeVisible({
timeout: 5000,
});
// Start another new chat session
await startNewChat(page);
// Verify the presence of the default assistant text
try {
await expect(page.getByText("Assistant with access to")).toBeVisible({
timeout: 5000,
});
} catch (error) {
console.error("Live Assistant final page content:");
console.error(await page.content());
}
});

View File

@@ -0,0 +1,86 @@
import { test, expect } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";
import {
navigateToAssistantInHistorySidebar,
sendMessage,
verifyCurrentModel,
switchModel,
startNewChat,
} from "../utils/chatActions";
test("LLM Ordering and Model Switching", async ({ page }) => {
test.fail();
// Setup: Clear cookies and log in as a random user
await page.context().clearCookies();
await loginAsRandomUser(page);
// Navigate to the chat page and verify URL
await page.goto("http://localhost:3000/chat");
await page.waitForSelector("#onyx-chat-input-textarea");
await expect(page.url()).toBe("http://localhost:3000/chat");
// Configure user settings: Set default model to GPT 4 Turbo
await page.locator("#onyx-user-dropdown").click();
await page.getByText("User Settings").click();
await page.getByRole("combobox").click();
await page.getByLabel("GPT 4 Turbo", { exact: true }).click();
await page.getByLabel("Close modal").click();
await verifyCurrentModel(page, "GPT 4 Turbo");
// Test Art Assistant: Should use its own model (GPT 4o)
await navigateToAssistantInHistorySidebar(
page,
"[-3]",
"Assistant for generating"
);
await sendMessage(page, "Sample message");
await verifyCurrentModel(page, "GPT 4o");
// Verify model persistence for Art Assistant
await sendMessage(page, "Sample message");
// Test new chat: Should use Art Assistant's model initially
await startNewChat(page);
await expect(page.getByText("Assistant for generating")).toBeVisible();
await verifyCurrentModel(page, "GPT 4o");
// Test another new chat: Should use user's default model (GPT 4 Turbo)
await startNewChat(page);
await verifyCurrentModel(page, "GPT 4 Turbo");
// Test model switching within a chat
await switchModel(page, "O1 Mini");
await sendMessage(page, "Sample message");
await verifyCurrentModel(page, "O1 Mini");
// Create a custom assistant with a specific model
await page.getByRole("button", { name: "Explore Assistants" }).click();
await page.getByRole("button", { name: "Create" }).click();
await page.waitForTimeout(2000);
await page.getByTestId("name").fill("Sample Name");
await page.getByTestId("description").fill("Sample Description");
await page.getByTestId("system_prompt").fill("Sample Instructions");
await page.getByRole("combobox").click();
await page
.getByLabel("GPT 4 Turbo (Preview)")
.getByText("GPT 4 Turbo (Preview)")
.click();
await page.getByRole("button", { name: "Create" }).click();
// Verify custom assistant uses its specified model
await page.locator("#onyx-chat-input-textarea").fill("");
await verifyCurrentModel(page, "GPT 4 Turbo (Preview)");
// Ensure model persistence for custom assistant
await sendMessage(page, "Sample message");
await verifyCurrentModel(page, "GPT 4 Turbo (Preview)");
// Switch back to Art Assistant and verify its model
await navigateToAssistantInHistorySidebar(
page,
"[-3]",
"Assistant for generating"
);
await verifyCurrentModel(page, "GPT 4o");
});

View File

@@ -35,3 +35,40 @@ export async function loginAs(page: Page, userType: "admin" | "user") {
}
}
}
// Function to generate a random email and password
const generateRandomCredentials = () => {
const randomString = Math.random().toString(36).substring(2, 10);
const specialChars = "!@#$%^&*()_+{}[]|:;<>,.?~";
const randomSpecialChar =
specialChars[Math.floor(Math.random() * specialChars.length)];
const randomUpperCase = String.fromCharCode(
65 + Math.floor(Math.random() * 26)
);
const randomNumber = Math.floor(Math.random() * 10);
return {
email: `test_${randomString}@example.com`,
password: `P@ssw0rd_${randomUpperCase}${randomSpecialChar}${randomNumber}${randomString}`,
};
};
// Function to sign up a new random user
export async function loginAsRandomUser(page: Page) {
const { email, password } = generateRandomCredentials();
await page.goto("http://localhost:3000/auth/signup");
await page.fill("#email", email);
await page.fill("#password", password);
// Click the signup button
await page.click('button[type="submit"]');
try {
await page.waitForURL("http://localhost:3000/chat");
} catch (error) {
console.log(`Timeout occurred. Current URL: ${page.url()}`);
throw new Error("Failed to sign up and redirect to chat page");
}
return { email, password };
}

View File

@@ -0,0 +1,48 @@
import { Page } from "@playwright/test";
import { expect } from "@playwright/test";
export async function navigateToAssistantInHistorySidebar(
page: Page,
testId: string,
description: string
) {
await page.getByTestId(`assistant-${testId}`).click();
try {
await expect(page.getByText(description)).toBeVisible();
} catch (error) {
console.error("Error in navigateToAssistantInHistorySidebar:", error);
const pageText = await page.textContent("body");
console.log("Page text:", pageText);
throw error;
}
}
export async function sendMessage(page: Page, message: string) {
await page.locator("#onyx-chat-input-textarea").click();
await page.locator("#onyx-chat-input-textarea").fill(message);
await page.locator("#onyx-chat-input-send-button").click();
await page.waitForSelector("#onyx-ai-message");
await page.waitForTimeout(2000);
}
export async function verifyCurrentModel(page: Page, modelName: string) {
await page.waitForTimeout(1000);
const chatInput = page.locator("#onyx-chat-input");
const text = await chatInput.textContent();
expect(text).toContain(modelName);
await page.waitForTimeout(1000);
}
// Start of Selection
export async function switchModel(page: Page, modelName: string) {
await page.getByTestId("llm-popover-trigger").click();
await page
.getByRole("button", { name: `Logo ${modelName}`, exact: true })
.click();
await page.waitForTimeout(1000);
}
export async function startNewChat(page: Page) {
await page.getByRole("link", { name: "Start New Chat" }).click();
await expect(page.locator('div[data-testid="chat-intro"]')).toBeVisible();
}

View File

@@ -0,0 +1,74 @@
import { Locator, Page } from "@playwright/test";
/**
* Drag "source" above (higher Y) "target" by using mouse events.
* Positions the cursor on the lower half of source, then moves to the top half of the target.
*/
export async function dragElementAbove(
sourceLocator: Locator,
targetLocator: Locator,
page: Page
) {
// Get bounding boxes
const sourceBB = await sourceLocator.boundingBox();
const targetBB = await targetLocator.boundingBox();
if (!sourceBB || !targetBB) {
throw new Error("Source/target bounding boxes not found.");
}
// Move over source, press mouse down
await page.mouse.move(
sourceBB.x + sourceBB.width / 2,
sourceBB.y + sourceBB.height * 0.75 // Move to 3/4 down the source element
);
await page.mouse.down();
// Move to a point slightly above the target's center
await page.mouse.move(
targetBB.x + targetBB.width / 2,
targetBB.y + targetBB.height * 0.1, // Move to 1/10 down the target element
{ steps: 20 } // Increase steps for smoother drag
);
await page.mouse.up();
// Increase wait time for DnD transitions
await page.waitForTimeout(200);
}
/**
* Drag "source" below (higher Y → lower Y) "target" using mouse events.
*/
export async function dragElementBelow(
sourceLocator: Locator,
targetLocator: Locator,
page: Page
) {
// Get bounding boxes
const sourceBB = await targetLocator.boundingBox();
const targetBB = await sourceLocator.boundingBox();
if (!sourceBB || !targetBB) {
throw new Error("Source/target bounding boxes not found.");
}
// Move over source, press mouse down
await page.mouse.move(
sourceBB.x + sourceBB.width / 2,
sourceBB.y + sourceBB.height * 0.25 // Move to 1/4 down the source element
);
await page.mouse.down();
// Move to a point well below the target's bottom edge
await page.mouse.move(
targetBB.x + targetBB.width / 2,
targetBB.y + targetBB.height + 50, // Move 50 pixels below the target element
{ steps: 50 } // Keep the same number of steps for smooth drag
);
// Hold for a moment to ensure the drag is registered
await page.waitForTimeout(500);
await page.mouse.up();
// Wait for DnD transitions and potential animations
await page.waitForTimeout(1000);
}

View File

@@ -1,15 +0,0 @@
{
"cookies": [
{
"name": "fastapiusersauth",
"value": "n_EMYYKHn4tQbuPTEbtN1gJ6dQTGek9omJPhO2GhHoA",
"domain": "localhost",
"path": "/",
"expires": 1738801376.508558,
"httpOnly": true,
"secure": false,
"sameSite": "Lax"
}
],
"origins": []
}