Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-20 09:15:47 +00:00)

Compare commits: foreign_ke ... error_supp (22 Commits)
Commits (SHA1):

- bc4a5b6496
- e1956dc42f
- 53225d0a43
- 715359c120
- e061ba2b93
- 87bccc13cc
- 569639eb90
- 68cb1f3409
- 11da0d9889
- 6a7e2a8036
- 035f83c464
- 3c34ddcc4f
- a82cac5361
- 83e5cb2d2f
- a5d2f0d9ac
- 7bc8554e01
- 261150e81a
- 3e0d24a3f6
- ffe8ac168f
- 17b280e59e
- 5edba4a7f3
- 6b31e2f622
.github/workflows/pr-chromatic-tests.yml (vendored, 2 changes)

@@ -8,6 +8,8 @@ on: push
 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
   SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
+  GEN_AI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
+  MOCK_LLM_RESPONSE: true

 jobs:
   playwright-tests:
.github/workflows/pr-helm-chart-testing.yml (vendored, 22 changes)

@@ -21,10 +21,10 @@ jobs:
       - name: Set up Helm
         uses: azure/setup-helm@v4.2.0
         with:
-          version: v3.14.4
+          version: v3.17.0

       - name: Set up chart-testing
-        uses: helm/chart-testing-action@v2.6.1
+        uses: helm/chart-testing-action@v2.7.0

       # even though we specify chart-dirs in ct.yaml, it isn't used by ct for the list-changed command...
       - name: Run chart-testing (list-changed)
@@ -37,22 +37,6 @@ jobs:
             echo "changed=true" >> "$GITHUB_OUTPUT"
           fi

-      # rkuo: I don't think we need python?
-      # - name: Set up Python
-      #   uses: actions/setup-python@v5
-      #   with:
-      #     python-version: '3.11'
-      #     cache: 'pip'
-      #     cache-dependency-path: |
-      #       backend/requirements/default.txt
-      #       backend/requirements/dev.txt
-      #       backend/requirements/model_server.txt
-      # - run: |
-      #     python -m pip install --upgrade pip
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/default.txt
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
-      #     pip install --retries 5 --timeout 30 -r backend/requirements/model_server.txt

       # lint all charts if any changes were detected
       - name: Run chart-testing (lint)
         if: steps.list-changed.outputs.changed == 'true'
@@ -62,7 +46,7 @@ jobs:

       - name: Create kind cluster
         if: steps.list-changed.outputs.changed == 'true'
-        uses: helm/kind-action@v1.10.0
+        uses: helm/kind-action@v1.12.0

       - name: Run chart-testing (install)
         if: steps.list-changed.outputs.changed == 'true'
@@ -0,0 +1,36 @@
"""add chat session specific temperature override

Revision ID: 2f80c6a2550f
Revises: 33ea50e88f24
Create Date: 2025-01-31 10:30:27.289646

"""
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "2f80c6a2550f"
down_revision = "33ea50e88f24"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "chat_session", sa.Column("temperature_override", sa.Float(), nullable=True)
    )
    op.add_column(
        "user",
        sa.Column(
            "temperature_override_enabled",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    )


def downgrade() -> None:
    op.drop_column("chat_session", "temperature_override")
    op.drop_column("user", "temperature_override_enabled")
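One plausible reading of the two new columns is that the per-user flag gates whether the per-session value is honored. A minimal sketch under that assumption (the function name and the default value are illustrative, not taken from the codebase):

```python
DEFAULT_TEMPERATURE = 0.0  # assumed default; the real fallback is not shown in this diff


def resolve_temperature(
    user_temperature_override_enabled: bool,
    session_temperature_override: float | None,
) -> float:
    """Honor the chat session's override only when the user has overrides enabled."""
    if user_temperature_override_enabled and session_temperature_override is not None:
        return session_temperature_override
    return DEFAULT_TEMPERATURE
```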
@@ -16,16 +16,18 @@ depends_on = None


 def upgrade() -> None:
-    # First drop the existing FK constraints
-    op.drop_constraint(
-        "inputprompt__user_input_prompt_id_fkey",
-        "inputprompt__user",
-        type_="foreignkey",
-    )
+    # Safely drop constraints if exists
+    op.execute(
+        """
+        ALTER TABLE inputprompt__user
+        DROP CONSTRAINT IF EXISTS inputprompt__user_input_prompt_id_fkey
+        """
+    )
-    op.drop_constraint(
-        "inputprompt__user_user_id_fkey",
-        "inputprompt__user",
-        type_="foreignkey",
-    )
+    op.execute(
+        """
+        ALTER TABLE inputprompt__user
+        DROP CONSTRAINT IF EXISTS inputprompt__user_user_id_fkey
+        """
+    )

     # Recreate with ON DELETE CASCADE
@@ -37,10 +39,11 @@ def upgrade() -> None:
         ["id"],
         ondelete="CASCADE",
     )

     op.create_foreign_key(
         "inputprompt__user_user_id_fkey",
         "inputprompt__user",
-        '"user"',
+        "user",
         ["user_id"],
         ["id"],
         ondelete="CASCADE",
@@ -71,7 +74,7 @@ def downgrade() -> None:
     op.create_foreign_key(
         "inputprompt__user_user_id_fkey",
         "inputprompt__user",
-        '"user"',
+        "user",
         ["user_id"],
         ["id"],
     )
@@ -13,6 +13,7 @@ from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
 from onyx.connectors.confluence.utils import get_user_email_from_username__server
 from onyx.connectors.models import SlimDocument
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -257,6 +258,7 @@ def _fetch_all_page_restrictions(
     slim_docs: list[SlimDocument],
     space_permissions_by_space_key: dict[str, ExternalAccess],
     is_cloud: bool,
+    callback: IndexingHeartbeatInterface | None,
 ) -> list[DocExternalAccess]:
     """
     For all pages, if a page has restrictions, then use those restrictions.
@@ -265,6 +267,12 @@ def _fetch_all_page_restrictions(
     document_restrictions: list[DocExternalAccess] = []

     for slim_doc in slim_docs:
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("confluence_doc_sync: Stop signal detected")
+
+            callback.progress("confluence_doc_sync:fetch_all_page_restrictions", 1)
+
         if slim_doc.perm_sync_data is None:
             raise ValueError(
                 f"No permission sync data found for document {slim_doc.id}"
@@ -334,7 +342,7 @@ def _fetch_all_page_restrictions(
 def confluence_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -359,6 +367,12 @@ def confluence_doc_sync(
     logger.debug("Fetching all slim documents from confluence")
     for doc_batch in confluence_connector.retrieve_all_slim_documents():
         logger.debug(f"Got {len(doc_batch)} slim documents from confluence")
+        if callback:
+            if callback.should_stop():
+                raise RuntimeError("confluence_doc_sync: Stop signal detected")
+
+            callback.progress("confluence_doc_sync", 1)
+
         slim_docs.extend(doc_batch)

     logger.debug("Fetching all page restrictions for space")
@@ -367,4 +381,5 @@ def confluence_doc_sync(
         slim_docs=slim_docs,
         space_permissions_by_space_key=space_permissions_by_space_key,
         is_cloud=is_cloud,
+        callback=callback,
     )
@@ -6,6 +6,7 @@ from onyx.access.models import ExternalAccess
 from onyx.connectors.gmail.connector import GmailConnector
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -28,7 +29,7 @@ def _get_slim_doc_generator(
 def gmail_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -44,6 +45,12 @@ def gmail_doc_sync(
     document_external_access: list[DocExternalAccess] = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("gmail_doc_sync: Stop signal detected")
+
+                callback.progress("gmail_doc_sync", 1)
+
             if slim_doc.perm_sync_data is None:
                 logger.warning(f"No permissions found for document {slim_doc.id}")
                 continue
@@ -10,6 +10,7 @@ from onyx.connectors.google_utils.resources import get_drive_service
 from onyx.connectors.interfaces import GenerateSlimDocumentOutput
 from onyx.connectors.models import SlimDocument
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -128,7 +129,7 @@ def _get_permissions_from_slim_doc(
 def gdrive_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -146,6 +147,12 @@ def gdrive_doc_sync(
     document_external_accesses = []
     for slim_doc_batch in slim_doc_generator:
         for slim_doc in slim_doc_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError("gdrive_doc_sync: Stop signal detected")
+
+                callback.progress("gdrive_doc_sync", 1)
+
             ext_access = _get_permissions_from_slim_doc(
                 google_drive_connector=google_drive_connector,
                 slim_doc=slim_doc,
@@ -7,6 +7,7 @@ from onyx.connectors.slack.connector import get_channels
 from onyx.connectors.slack.connector import make_paginated_slack_api_call_w_retries
 from onyx.connectors.slack.connector import SlackPollConnector
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.utils.logger import setup_logger

@@ -14,7 +15,7 @@ logger = setup_logger()
 def _get_slack_document_ids_and_channels(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> dict[str, list[str]]:
     slack_connector = SlackPollConnector(**cc_pair.connector.connector_specific_config)
     slack_connector.load_credentials(cc_pair.credential.credential_json)
@@ -24,6 +25,14 @@ def _get_slack_document_ids_and_channels(
     channel_doc_map: dict[str, list[str]] = {}
     for doc_metadata_batch in slim_doc_generator:
         for doc_metadata in doc_metadata_batch:
+            if callback:
+                if callback.should_stop():
+                    raise RuntimeError(
+                        "_get_slack_document_ids_and_channels: Stop signal detected"
+                    )
+
+                callback.progress("_get_slack_document_ids_and_channels", 1)
+
             if doc_metadata.perm_sync_data is None:
                 continue
             channel_id = doc_metadata.perm_sync_data["channel_id"]
@@ -114,7 +123,7 @@ def _fetch_channel_permissions(
 def slack_doc_sync(
-    cc_pair: ConnectorCredentialPair,
+    cc_pair: ConnectorCredentialPair, callback: IndexingHeartbeatInterface | None
 ) -> list[DocExternalAccess]:
     """
     Adds the external permissions to the documents in postgres
@@ -127,7 +136,7 @@ def slack_doc_sync(
     )
     user_id_to_email_map = fetch_user_id_to_email_map(slack_client)
     channel_doc_map = _get_slack_document_ids_and_channels(
-        cc_pair=cc_pair,
+        cc_pair=cc_pair, callback=callback
     )
     workspace_permissions = _fetch_workspace_permissions(
         user_id_to_email_map=user_id_to_email_map,
@@ -15,11 +15,13 @@ from ee.onyx.external_permissions.slack.doc_sync import slack_doc_sync
 from onyx.access.models import DocExternalAccess
 from onyx.configs.constants import DocumentSource
 from onyx.db.models import ConnectorCredentialPair
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface

 # Defining the input/output types for the sync functions
 DocSyncFuncType = Callable[
     [
         ConnectorCredentialPair,
         IndexingHeartbeatInterface | None,
     ],
     list[DocExternalAccess],
 ]
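A function registered under this type now receives the heartbeat callback as its second argument and is expected to poll it between batches. A minimal sketch under those assumptions (example_doc_sync and fetch_batches are hypothetical; the real implementations are the connector-specific *_doc_sync functions above):

```python
from onyx.access.models import DocExternalAccess
from onyx.db.models import ConnectorCredentialPair
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface


def example_doc_sync(  # hypothetical connector sync, shaped like DocSyncFuncType
    cc_pair: ConnectorCredentialPair,
    callback: IndexingHeartbeatInterface | None,
) -> list[DocExternalAccess]:
    accesses: list[DocExternalAccess] = []
    for batch in fetch_batches(cc_pair):  # hypothetical helper, not part of this PR
        if callback:
            if callback.should_stop():
                raise RuntimeError("example_doc_sync: Stop signal detected")
            callback.progress("example_doc_sync", 1)
        accesses.extend(batch)
    return accesses
```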
@@ -198,7 +198,8 @@ def on_celeryd_init(sender: str, conf: Any = None, **kwargs: Any) -> None:
 def wait_for_redis(sender: Any, **kwargs: Any) -> None:
     """Waits for redis to become ready subject to a hardcoded timeout.
-    Will raise WorkerShutdown to kill the celery worker if the timeout is reached."""
+    Will raise WorkerShutdown to kill the celery worker if the timeout
+    is reached."""

     r = get_redis_client(tenant_id=None)
@@ -91,6 +91,28 @@ def celery_find_task(task_id: str, queue: str, r: Redis) -> int:
     return False


+def celery_get_queued_task_ids(queue: str, r: Redis) -> set[str]:
+    """This is a redis specific way to build a list of tasks in a queue.
+
+    This helps us read the queue once and then efficiently look for missing tasks
+    in the queue.
+    """
+
+    task_set: set[str] = set()
+
+    for priority in range(len(OnyxCeleryPriority)):
+        queue_name = f"{queue}{CELERY_SEPARATOR}{priority}" if priority > 0 else queue
+
+        tasks = cast(list[bytes], r.lrange(queue_name, 0, -1))
+        for task in tasks:
+            task_dict: dict[str, Any] = json.loads(task.decode("utf-8"))
+            task_id = task_dict.get("headers", {}).get("id")
+            if task_id:
+                task_set.add(task_id)
+
+    return task_set
+
+
 def celery_inspect_get_workers(name_filter: str | None, app: Celery) -> list[str]:
     """Returns a list of current workers containing name_filter, or all workers if
     name_filter is None.
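A usage sketch for the new helper: read the queue once, then do cheap membership checks against the resulting set (the Redis connection here is illustrative; in the permission sync tasks below the client comes from Celery's broker connection):

```python
from redis import Redis

from onyx.background.celery.celery_redis import celery_get_queued_task_ids
from onyx.configs.constants import OnyxCeleryQueues

r_celery = Redis()  # illustrative; real callers use self.app.broker_connection().channel().client

# one pass over the queue (including its priority sub-queues) ...
queued = celery_get_queued_task_ids(OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery)

# ... then O(1) membership checks per taskset entry
if "some-task-id" not in queued:
    print("task is not waiting in the queue")
```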
@@ -3,13 +3,16 @@ from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
 from time import sleep
 from typing import cast
 from uuid import uuid4

 from celery import Celery
 from celery import shared_task
 from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from pydantic import ValidationError
 from redis import Redis
 from redis.exceptions import LockError
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

@@ -22,6 +25,10 @@ from ee.onyx.external_permissions.sync_params import (
 )
 from onyx.access.models import DocExternalAccess
 from onyx.background.celery.apps.app_base import task_logger
 from onyx.background.celery.celery_redis import celery_find_task
 from onyx.background.celery.celery_redis import celery_get_queue_length
+from onyx.background.celery.celery_redis import celery_get_queued_task_ids
 from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_PERMISSIONS_SYNC_LOCK_TIMEOUT
@@ -32,6 +39,7 @@ from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
 from onyx.configs.constants import OnyxCeleryTask
 from onyx.configs.constants import OnyxRedisLocks
+from onyx.configs.constants import OnyxRedisSignals
 from onyx.db.connector import mark_cc_pair_as_permissions_synced
 from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
 from onyx.db.document import upsert_document_by_connector_credential_pair
@@ -44,14 +52,19 @@ from onyx.db.models import ConnectorCredentialPair
 from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.db.users import batch_add_ext_perm_user_if_not_exists
+from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
 from onyx.redis.redis_connector import RedisConnector
-from onyx.redis.redis_connector_doc_perm_sync import (
-    RedisConnectorPermissionSyncPayload,
-)
 from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSync
+from onyx.redis.redis_connector_doc_perm_sync import RedisConnectorPermissionSyncPayload
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import redis_lock_dump
+from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
+from onyx.server.utils import make_short_id
 from onyx.utils.logger import doc_permission_sync_ctx
 from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import setup_logger


 logger = setup_logger()
@@ -105,7 +118,12 @@ def _is_external_doc_permissions_sync_due(cc_pair: ConnectorCredentialPair) -> b
     bind=True,
 )
 def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool | None:
+    # TODO(rkuo): merge into check function after lookup table for fences is added
+
+    # we need to use celery's redis client to access its redis data
+    # (which lives on a different db number)
     r = get_redis_client(tenant_id=tenant_id)
+    r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore

     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_CONNECTOR_DOC_PERMISSIONS_SYNC_BEAT_LOCK,
@@ -126,14 +144,32 @@ def check_for_doc_permissions_sync(self: Task, *, tenant_id: str | None) -> bool
             if _is_external_doc_permissions_sync_due(cc_pair):
                 cc_pair_ids_to_sync.append(cc_pair.id)

+        lock_beat.reacquire()
         for cc_pair_id in cc_pair_ids_to_sync:
-            tasks_created = try_creating_permissions_sync_task(
+            payload_id = try_creating_permissions_sync_task(
                 self.app, cc_pair_id, r, tenant_id
             )
-            if not tasks_created:
+            if not payload_id:
                 continue

-            task_logger.info(f"Doc permissions sync queued: cc_pair={cc_pair_id}")
+            task_logger.info(
+                f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}"
+            )
+
+        # we want to run this less frequently than the overall task
+        lock_beat.reacquire()
+        if not r.exists(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES):
+            # clear any permission fences that don't have associated celery tasks in progress
+            # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+            # or be currently executing
+            try:
+                validate_permission_sync_fences(tenant_id, r, r_celery, lock_beat)
+            except Exception:
+                task_logger.exception(
+                    "Exception while validating permission sync fences"
+                )
+
+            r.set(OnyxRedisSignals.VALIDATE_PERMISSION_SYNC_FENCES, 1, ex=60)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
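The fence validation above is throttled with a short-lived Redis key rather than a separate beat schedule: the expensive pass runs only when the signal key has expired and is then re-armed for 60 seconds. A standalone sketch of that pattern (the key name mirrors the constant added in this PR; the client and the validation body are stand-ins):

```python
from redis import Redis

r = Redis()  # illustrative connection; the real code uses get_redis_client()
SIGNAL_KEY = "signal:validate_permission_sync_fences"


def run_validation() -> None:
    # stand-in for validate_permission_sync_fences(...)
    print("validating permission sync fences")


def maybe_validate() -> None:
    # run the expensive validation only when the signal key has expired
    if not r.exists(SIGNAL_KEY):
        run_validation()
        # re-arm the signal so beats within the next 60 seconds skip validation
        r.set(SIGNAL_KEY, 1, ex=60)
```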
@@ -152,13 +188,15 @@ def try_creating_permissions_sync_task(
     cc_pair_id: int,
     r: Redis,
     tenant_id: str | None,
-) -> int | None:
-    """Returns an int if syncing is needed. The int represents the number of sync tasks generated.
+) -> str | None:
+    """Returns a randomized payload id on success.
     Returns None if no syncing is required."""
-    redis_connector = RedisConnector(tenant_id, cc_pair_id)

     LOCK_TIMEOUT = 30

+    payload_id: str | None = None
+
+    redis_connector = RedisConnector(tenant_id, cc_pair_id)

     lock: RedisLock = r.lock(
         DANSWER_REDIS_FUNCTION_LOCK_PREFIX + "try_generate_permissions_sync_tasks",
         timeout=LOCK_TIMEOUT,
@@ -193,7 +231,13 @@ def try_creating_permissions_sync_task(
         )

         # set a basic fence to start
-        payload = RedisConnectorPermissionSyncPayload(started=None, celery_task_id=None)
-        redis_connector.permissions.set_active()
+        payload = RedisConnectorPermissionSyncPayload(
+            id=make_short_id(),
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
         redis_connector.permissions.set_fence(payload)

         result = app.send_task(
@@ -208,8 +252,11 @@ def try_creating_permissions_sync_task(
         )

         # fill in the celery task id
+        redis_connector.permissions.set_active()
         payload.celery_task_id = result.id
         redis_connector.permissions.set_fence(payload)

+        payload_id = payload.celery_task_id
     except Exception:
         task_logger.exception(f"Unexpected exception: cc_pair={cc_pair_id}")
         return None
@@ -217,7 +264,7 @@ def try_creating_permissions_sync_task(
     if lock.owned():
         lock.release()

-    return 1
+    return payload_id


 @shared_task(
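For reference, the fence payload now carries a short id and a submitted timestamp alongside the existing fields. A hedged sketch of the shape implied by the calls above (this is not the actual RedisConnectorPermissionSyncPayload definition, which lives in onyx.redis.redis_connector_doc_perm_sync):

```python
from datetime import datetime

from pydantic import BaseModel


class PermissionSyncPayloadSketch(BaseModel):
    # mirrors the fields used in try_creating_permissions_sync_task
    id: str  # short random id from make_short_id()
    submitted: datetime  # when the beat task queued the sync
    started: datetime | None  # filled in when the generator task begins
    celery_task_id: str | None  # filled in after app.send_task() returns
```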
@@ -238,6 +285,8 @@ def connector_permission_sync_generator_task(
     This task assumes that the task has already been properly fenced
     """

+    LoggerContextVars.reset()
+
     doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
     doc_permission_sync_ctx_dict["cc_pair_id"] = cc_pair_id
     doc_permission_sync_ctx_dict["request_id"] = self.request.id
@@ -325,12 +374,17 @@ def connector_permission_sync_generator_task(
         raise ValueError(f"No fence payload found: cc_pair={cc_pair_id}")

     new_payload = RedisConnectorPermissionSyncPayload(
+        id=payload.id,
+        submitted=payload.submitted,
         started=datetime.now(timezone.utc),
         celery_task_id=payload.celery_task_id,
     )
     redis_connector.permissions.set_fence(new_payload)

-    document_external_accesses: list[DocExternalAccess] = doc_sync_func(cc_pair)
+    callback = PermissionSyncCallback(redis_connector, lock, r)
+    document_external_accesses: list[DocExternalAccess] = doc_sync_func(
+        cc_pair, callback
+    )

     task_logger.info(
         f"RedisConnector.permissions.generate_tasks starting. cc_pair={cc_pair_id}"
@@ -380,6 +434,8 @@ def update_external_document_permissions_task(
     connector_id: int,
     credential_id: int,
 ) -> bool:
+    start = time.monotonic()
+
     document_external_access = DocExternalAccess.from_dict(
         serialized_doc_external_access
     )
@@ -409,16 +465,268 @@ def update_external_document_permissions_task(
                 document_ids=[doc_id],
             )

-            logger.debug(
-                f"Successfully synced postgres document permissions for {doc_id}"
-            )
-        return True
+            elapsed = time.monotonic() - start
+            task_logger.info(
+                f"connector_id={connector_id} "
+                f"doc={doc_id} "
+                f"action=update_permissions "
+                f"elapsed={elapsed:.2f}"
+            )
     except Exception:
-        logger.exception(
-            f"Error Syncing Document Permissions: connector_id={connector_id} doc_id={doc_id}"
-        )
+        task_logger.exception(
+            f"Exception in update_external_document_permissions_task: "
+            f"connector_id={connector_id} "
+            f"doc_id={doc_id}"
+        )
         return False

+    return True
def validate_permission_sync_fences(
    tenant_id: str | None,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    # building lookup table can be expensive, so we won't bother
    # validating until the queue is small
    PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN = 1024

    queue_len = celery_get_queue_length(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )
    if queue_len > PERMISSION_SYNC_VALIDATION_MAX_QUEUE_LEN:
        return

    queued_upsert_tasks = celery_get_queued_task_ids(
        OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT, r_celery
    )
    reserved_generator_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC, r_celery
    )

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorPermissionSync.FENCE_PREFIX + "*",
        count=SCAN_ITER_COUNT_DEFAULT,
    ):
        lock_beat.reacquire()
        validate_permission_sync_fence(
            tenant_id,
            key_bytes,
            queued_upsert_tasks,
            reserved_generator_tasks,
            r,
            r_celery,
        )
    return


def validate_permission_sync_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    queued_tasks: set[str],
    reserved_tasks: set[str],
    r: Redis,
    r_celery: Redis,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.

    queued_tasks: the celery queue of lightweight permission sync tasks
    reserved_tasks: prefetched tasks for sync task generator
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        task_logger.warning(
            f"validate_permission_sync_fence - could not parse id from {fence_key}"
        )
        return

    cc_pair_id = int(cc_pair_id_str)
    # parse out metadata and initialize the helper class with it
    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

    # check to see if the fence/payload exists
    if not redis_connector.permissions.fenced:
        return

    # in the cloud, the payload format may have changed ...
    # it's a little sloppy, but just reset the fence for now if that happens
    # TODO: add intentional cleanup/abort logic
    try:
        payload = redis_connector.permissions.payload
    except ValidationError:
        task_logger.exception(
            "validate_permission_sync_fence - "
            "Resetting fence because fence schema is out of date: "
            f"cc_pair={cc_pair_id} "
            f"fence={fence_key}"
        )

        redis_connector.permissions.reset()
        return

    if not payload:
        return

    if not payload.celery_task_id:
        return

    # OK, there's actually something for us to validate

    # either the generator task must be in flight or its subtasks must be
    found = celery_find_task(
        payload.celery_task_id,
        OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC,
        r_celery,
    )
    if found:
        # the celery task exists in the redis queue
        redis_connector.permissions.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within a worker
        redis_connector.permissions.set_active()
        return

    # look up every task in the current taskset in the celery queue
    # every entry in the taskset should have an associated entry in the celery task queue
    # because we get the celery tasks first, the entries in our own permissions taskset
    # should be roughly a subset of the tasks in celery

    # this check isn't very exact, but should be sufficient over a period of time
    # A single successful check over some number of attempts is sufficient.

    # TODO: if the number of tasks in celery is much lower than than the taskset length
    # we might be able to shortcut the lookup since by definition some of the tasks
    # must not exist in celery.

    tasks_scanned = 0
    tasks_not_in_celery = 0  # a non-zero number after completing our check is bad

    for member in r.sscan_iter(redis_connector.permissions.taskset_key):
        tasks_scanned += 1

        member_bytes = cast(bytes, member)
        member_str = member_bytes.decode("utf-8")
        if member_str in queued_tasks:
            continue

        if member_str in reserved_tasks:
            continue

        tasks_not_in_celery += 1

    task_logger.info(
        "validate_permission_sync_fence task check: "
        f"tasks_scanned={tasks_scanned} tasks_not_in_celery={tasks_not_in_celery}"
    )

    if tasks_not_in_celery == 0:
        redis_connector.permissions.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    if redis_connector.permissions.active():
        return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    task_logger.warning(
        "validate_permission_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
    )

    redis_connector.permissions.reset()
    return
class PermissionSyncCallback(IndexingHeartbeatInterface):
    PARENT_CHECK_INTERVAL = 60

    def __init__(
        self,
        redis_connector: RedisConnector,
        redis_lock: RedisLock,
        redis_client: Redis,
    ):
        super().__init__()
        self.redis_connector: RedisConnector = redis_connector
        self.redis_lock: RedisLock = redis_lock
        self.redis_client = redis_client

        self.started: datetime = datetime.now(timezone.utc)
        self.redis_lock.reacquire()

        self.last_tag: str = "PermissionSyncCallback.__init__"
        self.last_lock_reacquire: datetime = datetime.now(timezone.utc)
        self.last_lock_monotonic = time.monotonic()

    def should_stop(self) -> bool:
        if self.redis_connector.stop.fenced:
            return True

        return False

    def progress(self, tag: str, amount: int) -> None:
        try:
            self.redis_connector.permissions.set_active()

            current_time = time.monotonic()
            if current_time - self.last_lock_monotonic >= (
                CELERY_GENERIC_BEAT_LOCK_TIMEOUT / 4
            ):
                self.redis_lock.reacquire()
                self.last_lock_reacquire = datetime.now(timezone.utc)
                self.last_lock_monotonic = time.monotonic()

            self.last_tag = tag
        except LockError:
            logger.exception(
                f"PermissionSyncCallback - lock.reacquire exceptioned: "
                f"lock_timeout={self.redis_lock.timeout} "
                f"start={self.started} "
                f"last_tag={self.last_tag} "
                f"last_reacquired={self.last_lock_reacquire} "
                f"now={datetime.now(timezone.utc)}"
            )

            redis_lock_dump(self.redis_lock, self.redis_client)
            raise


"""Monitoring CCPair permissions utils, called in monitor_vespa_sync"""
@@ -444,20 +752,36 @@ def monitor_ccpair_permissions_taskset(
     if initial is None:
         return

+    try:
+        payload = redis_connector.permissions.payload
+    except ValidationError:
+        task_logger.exception(
+            "Permissions sync payload failed to validate. "
+            "Schema may have been updated."
+        )
+        return
+
+    if not payload:
+        return
+
     remaining = redis_connector.permissions.get_remaining()
     task_logger.info(
-        f"Permissions sync progress: cc_pair={cc_pair_id} remaining={remaining} initial={initial}"
+        f"Permissions sync progress: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"remaining={remaining} "
+        f"initial={initial}"
     )
     if remaining > 0:
         return

-    payload: RedisConnectorPermissionSyncPayload | None = (
-        redis_connector.permissions.payload
-    )
-    start_time: datetime | None = payload.started if payload else None
-
-    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), start_time)
-    task_logger.info(f"Successfully synced permissions for cc_pair={cc_pair_id}")
+    mark_cc_pair_as_permissions_synced(db_session, int(cc_pair_id), payload.started)
+    task_logger.info(
+        f"Permissions sync finished: "
+        f"cc_pair={cc_pair_id} "
+        f"id={payload.id} "
+        f"num_synced={initial}"
+    )

     update_sync_record_status(
         db_session=db_session,
@@ -1,3 +1,4 @@
+import time
 from datetime import datetime
 from datetime import timedelta
 from datetime import timezone
@@ -9,6 +10,7 @@ from celery import Task
 from celery.exceptions import SoftTimeLimitExceeded
 from redis import Redis
 from redis.lock import Lock as RedisLock
 from sqlalchemy.orm import Session

 from ee.onyx.db.connector_credential_pair import get_all_auto_sync_cc_pairs
 from ee.onyx.db.connector_credential_pair import get_cc_pairs_by_source
@@ -20,9 +22,12 @@ from ee.onyx.external_permissions.sync_params import (
     GROUP_PERMISSIONS_IS_CC_PAIR_AGNOSTIC,
 )
 from onyx.background.celery.apps.app_base import task_logger
 from onyx.background.celery.celery_redis import celery_find_task
 from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
 from onyx.configs.app_configs import JOB_TIMEOUT
 from onyx.configs.constants import CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_GENERIC_BEAT_LOCK_TIMEOUT
 from onyx.configs.constants import CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT
 from onyx.configs.constants import DANSWER_REDIS_FUNCTION_LOCK_PREFIX
 from onyx.configs.constants import OnyxCeleryPriority
 from onyx.configs.constants import OnyxCeleryQueues
@@ -39,10 +44,12 @@ from onyx.db.models import ConnectorCredentialPair
 from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.redis.redis_connector import RedisConnector
 from onyx.redis.redis_connector_ext_group_sync import RedisConnectorExternalGroupSync
 from onyx.redis.redis_connector_ext_group_sync import (
     RedisConnectorExternalGroupSyncPayload,
 )
 from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -102,6 +109,10 @@ def _is_external_group_sync_due(cc_pair: ConnectorCredentialPair) -> bool:
 def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool | None:
     r = get_redis_client(tenant_id=tenant_id)

+    # we need to use celery's redis client to access its redis data
+    # (which lives on a different db number)
+    # r_celery: Redis = self.app.broker_connection().channel().client  # type: ignore
+
     lock_beat: RedisLock = r.lock(
         OnyxRedisLocks.CHECK_CONNECTOR_EXTERNAL_GROUP_SYNC_BEAT_LOCK,
         timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
@@ -136,6 +147,7 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
             if _is_external_group_sync_due(cc_pair):
                 cc_pair_ids_to_sync.append(cc_pair.id)

+        lock_beat.reacquire()
         for cc_pair_id in cc_pair_ids_to_sync:
             tasks_created = try_creating_external_group_sync_task(
                 self.app, cc_pair_id, r, tenant_id
@@ -144,6 +156,23 @@ def check_for_external_group_sync(self: Task, *, tenant_id: str | None) -> bool
                 continue

             task_logger.info(f"External group sync queued: cc_pair={cc_pair_id}")

+        # we want to run this less frequently than the overall task
+        # lock_beat.reacquire()
+        # if not r.exists(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES):
+        #     # clear any indexing fences that don't have associated celery tasks in progress
+        #     # tasks can be in the queue in redis, in reserved tasks (prefetched by the worker),
+        #     # or be currently executing
+        #     try:
+        #         validate_external_group_sync_fences(
+        #             tenant_id, self.app, r, r_celery, lock_beat
+        #         )
+        #     except Exception:
+        #         task_logger.exception(
+        #             "Exception while validating external group sync fences"
+        #         )
+
+        #     r.set(OnyxRedisSignals.VALIDATE_EXTERNAL_GROUP_SYNC_FENCES, 1, ex=60)
     except SoftTimeLimitExceeded:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
@@ -186,6 +215,12 @@ def try_creating_external_group_sync_task(
         redis_connector.external_group_sync.generator_clear()
         redis_connector.external_group_sync.taskset_clear()

+        payload = RedisConnectorExternalGroupSyncPayload(
+            submitted=datetime.now(timezone.utc),
+            started=None,
+            celery_task_id=None,
+        )
+
         custom_task_id = f"{redis_connector.external_group_sync.taskset_key}_{uuid4()}"

         result = app.send_task(
@@ -199,11 +234,6 @@ def try_creating_external_group_sync_task(
             priority=OnyxCeleryPriority.HIGH,
         )

-        payload = RedisConnectorExternalGroupSyncPayload(
-            started=datetime.now(timezone.utc),
-            celery_task_id=result.id,
-        )
-
         # create before setting fence to avoid race condition where the monitoring
         # task updates the sync record before it is created
         with get_session_with_tenant(tenant_id) as db_session:
@@ -213,8 +243,8 @@ def try_creating_external_group_sync_task(
                 sync_type=SyncType.EXTERNAL_GROUP,
             )

+        payload.celery_task_id = result.id
         redis_connector.external_group_sync.set_fence(payload)

     except Exception:
         task_logger.exception(
             f"Unexpected exception while trying to create external group sync task: cc_pair={cc_pair_id}"
@@ -241,7 +271,7 @@ def connector_external_group_sync_generator_task(
     tenant_id: str | None,
 ) -> None:
     """
-    Permission sync task that handles external group syncing for a given connector credential pair
+    External group sync task for a given connector credential pair
     This task assumes that the task has already been properly fenced
     """

@@ -249,19 +279,59 @@ def connector_external_group_sync_generator_task(

     r = get_redis_client(tenant_id=tenant_id)

+    # this wait is needed to avoid a race condition where
+    # the primary worker sends the task and it is immediately executed
+    # before the primary worker can finalize the fence
+    start = time.monotonic()
+    while True:
+        if time.monotonic() - start > CELERY_TASK_WAIT_FOR_FENCE_TIMEOUT:
+            raise ValueError(
+                f"connector_external_group_sync_generator_task - timed out waiting for fence to be ready: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+
+        if not redis_connector.external_group_sync.fenced:  # The fence must exist
+            raise ValueError(
+                f"connector_external_group_sync_generator_task - fence not found: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+
+        payload = redis_connector.external_group_sync.payload  # The payload must exist
+        if not payload:
+            raise ValueError(
+                "connector_external_group_sync_generator_task: payload invalid or not found"
+            )
+
+        if payload.celery_task_id is None:
+            logger.info(
+                f"connector_external_group_sync_generator_task - Waiting for fence: "
+                f"fence={redis_connector.external_group_sync.fence_key}"
+            )
+            time.sleep(1)
+            continue
+
+        logger.info(
+            f"connector_external_group_sync_generator_task - Fence found, continuing...: "
+            f"fence={redis_connector.external_group_sync.fence_key}"
+        )
+        break
+
     lock: RedisLock = r.lock(
         OnyxRedisLocks.CONNECTOR_EXTERNAL_GROUP_SYNC_LOCK_PREFIX
         + f"_{redis_connector.id}",
         timeout=CELERY_EXTERNAL_GROUP_SYNC_LOCK_TIMEOUT,
     )

-    acquired = lock.acquire(blocking=False)
-    if not acquired:
-        task_logger.warning(
-            f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
-        )
-        return None
-
     try:
+        acquired = lock.acquire(blocking=False)
+        if not acquired:
+            task_logger.warning(
+                f"External group sync task already running, exiting...: cc_pair={cc_pair_id}"
+            )
+            return None
+
+        payload.started = datetime.now(timezone.utc)
+        redis_connector.external_group_sync.set_fence(payload)
+
         with get_session_with_tenant(tenant_id) as db_session:
             cc_pair = get_connector_credential_pair_from_id(
@@ -330,3 +400,135 @@ def connector_external_group_sync_generator_task(
         redis_connector.external_group_sync.set_fence(None)
         if lock.owned():
             lock.release()
def validate_external_group_sync_fences(
    tenant_id: str | None,
    celery_app: Celery,
    r: Redis,
    r_celery: Redis,
    lock_beat: RedisLock,
) -> None:
    reserved_sync_tasks = celery_get_unacked_task_ids(
        OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )

    # validate all existing indexing jobs
    for key_bytes in r.scan_iter(
        RedisConnectorExternalGroupSync.FENCE_PREFIX + "*",
        count=SCAN_ITER_COUNT_DEFAULT,
    ):
        lock_beat.reacquire()
        with get_session_with_tenant(tenant_id) as db_session:
            validate_external_group_sync_fence(
                tenant_id,
                key_bytes,
                reserved_sync_tasks,
                r_celery,
                db_session,
            )
    return


def validate_external_group_sync_fence(
    tenant_id: str | None,
    key_bytes: bytes,
    reserved_tasks: set[str],
    r_celery: Redis,
    db_session: Session,
) -> None:
    """Checks for the error condition where an indexing fence is set but the associated celery tasks don't exist.
    This can happen if the indexing worker hard crashes or is terminated.
    Being in this bad state means the fence will never clear without help, so this function
    gives the help.

    How this works:
    1. This function renews the active signal with a 5 minute TTL under the following conditions
    1.2. When the task is seen in the redis queue
    1.3. When the task is seen in the reserved / prefetched list

    2. Externally, the active signal is renewed when:
    2.1. The fence is created
    2.2. The indexing watchdog checks the spawned task.

    3. The TTL allows us to get through the transitions on fence startup
    and when the task starts executing.

    More TTL clarification: it is seemingly impossible to exactly query Celery for
    whether a task is in the queue or currently executing.
    1. An unknown task id is always returned as state PENDING.
    2. Redis can be inspected for the task id, but the task id is gone between the time a worker receives the task
    and the time it actually starts on the worker.
    """
    # if the fence doesn't exist, there's nothing to do
    fence_key = key_bytes.decode("utf-8")
    cc_pair_id_str = RedisConnector.get_id_from_fence_key(fence_key)
    if cc_pair_id_str is None:
        task_logger.warning(
            f"validate_external_group_sync_fence - could not parse id from {fence_key}"
        )
        return

    cc_pair_id = int(cc_pair_id_str)

    # parse out metadata and initialize the helper class with it
    redis_connector = RedisConnector(tenant_id, int(cc_pair_id))

    # check to see if the fence/payload exists
    if not redis_connector.external_group_sync.fenced:
        return

    payload = redis_connector.external_group_sync.payload
    if not payload:
        return

    # OK, there's actually something for us to validate

    if payload.celery_task_id is None:
        # the fence is just barely set up.
        # if redis_connector_index.active():
        #     return

        # it would be odd to get here as there isn't that much that can go wrong during
        # initial fence setup, but it's still worth making sure we can recover
        logger.info(
            "validate_external_group_sync_fence - "
            f"Resetting fence in basic state without any activity: fence={fence_key}"
        )
        redis_connector.external_group_sync.reset()
        return

    found = celery_find_task(
        payload.celery_task_id, OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC, r_celery
    )
    if found:
        # the celery task exists in the redis queue
        # redis_connector_index.set_active()
        return

    if payload.celery_task_id in reserved_tasks:
        # the celery task was prefetched and is reserved within the indexing worker
        # redis_connector_index.set_active()
        return

    # we may want to enable this check if using the active task list somehow isn't good enough
    # if redis_connector_index.generator_locked():
    #     logger.info(f"{payload.celery_task_id} is currently executing.")

    # if we get here, we didn't find any direct indication that the associated celery tasks exist,
    # but they still might be there due to gaps in our ability to check states during transitions
    # Checking the active signal safeguards us against these transition periods
    # (which has a duration that allows us to bridge those gaps)
    # if redis_connector_index.active():
    #     return

    # celery tasks don't exist and the active signal has expired, possibly due to a crash. Clean it up.
    logger.warning(
        "validate_external_group_sync_fence - "
        "Resetting fence because no associated celery tasks were found: "
        f"cc_pair={cc_pair_id} "
        f"fence={fence_key}"
    )

    redis_connector.external_group_sync.reset()
    return
@@ -39,6 +39,7 @@ from onyx.redis.redis_pool import get_redis_client
 from onyx.redis.redis_pool import redis_lock_dump
 from onyx.utils.telemetry import optional_telemetry
 from onyx.utils.telemetry import RecordType
+from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR


 _MONITORING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
@@ -657,6 +658,9 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
     - Syncing speed metrics
     - Worker status and task counts
     """
+    if tenant_id is not None:
+        CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
+
     task_logger.info("Starting background monitoring")
     r = get_redis_client(tenant_id=tenant_id)

@@ -688,11 +692,13 @@ def monitor_background_processes(self: Task, *, tenant_id: str | None) -> None:
             metrics = metric_fn()
             for metric in metrics:
                 # double check to make sure we aren't double-emitting metrics
-                if metric.key is not None and not _has_metric_been_emitted(
+                if metric.key is None or not _has_metric_been_emitted(
                     redis_std, metric.key
                 ):
                     metric.log()
                     metric.emit(tenant_id)

                     if metric.key is not None:
                         _mark_metric_as_emitted(redis_std, metric.key)

         task_logger.info("Successfully collected background metrics")
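The rewritten condition changes when a metric is emitted: keyless metrics are now always emitted, while keyed metrics are emitted once and then marked in Redis. A small self-contained sketch of the new behavior (should_emit is a stand-in for the inline condition, not a real function in the codebase):

```python
def should_emit(metric_key: str | None, already_emitted: set[str]) -> bool:
    # keyless metrics are always emitted; keyed metrics at most once
    return metric_key is None or metric_key not in already_emitted


emitted: set[str] = set()
assert should_emit(None, emitted)  # no key: always emit
assert should_emit("sync_speed", emitted)  # first time for this key: emit
emitted.add("sync_speed")
assert not should_emit("sync_speed", emitted)  # already emitted: skip
```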
@@ -39,6 +39,7 @@ from onyx.db.sync_record import insert_sync_record
 from onyx.db.sync_record import update_sync_record_status
 from onyx.redis.redis_connector import RedisConnector
 from onyx.redis.redis_pool import get_redis_client
+from onyx.utils.logger import LoggerContextVars
 from onyx.utils.logger import pruning_ctx
 from onyx.utils.logger import setup_logger

@@ -251,6 +252,8 @@ def connector_pruning_generator_task(
     and compares those IDs to locally stored documents and deletes all locally stored IDs missing
     from the most recently pulled document ID list"""

+    LoggerContextVars.reset()
+
     pruning_ctx_dict = pruning_ctx.get()
     pruning_ctx_dict["cc_pair_id"] = cc_pair_id
     pruning_ctx_dict["request_id"] = self.request.id
@@ -399,7 +402,7 @@ def monitor_ccpair_pruning_taskset(

     mark_ccpair_as_pruned(int(cc_pair_id), db_session)
     task_logger.info(
-        f"Successfully pruned connector credential pair. cc_pair={cc_pair_id}"
+        f"Connector pruning finished: cc_pair={cc_pair_id} num_pruned={initial}"
     )

     update_sync_record_status(
@@ -75,6 +75,8 @@ def document_by_cc_pair_cleanup_task(
     """
     task_logger.debug(f"Task start: doc={document_id}")

+    start = time.monotonic()
+
     try:
         with get_session_with_tenant(tenant_id) as db_session:
             action = "skip"
@@ -154,11 +156,13 @@ def document_by_cc_pair_cleanup_task(

             db_session.commit()

+            elapsed = time.monotonic() - start
             task_logger.info(
                 f"doc={document_id} "
                 f"action={action} "
                 f"refcount={count} "
-                f"chunks={chunks_affected}"
+                f"chunks={chunks_affected} "
+                f"elapsed={elapsed:.2f}"
             )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
@@ -989,6 +989,10 @@ def monitor_vespa_sync(self: Task, tenant_id: str | None) -> bool | None:
         task_logger.info(
             "Soft time limit exceeded, task is being terminated gracefully."
         )
+        return False
+    except Exception:
+        task_logger.exception("monitor_vespa_sync exceptioned.")
+        return False
     finally:
         if lock_beat.owned():
             lock_beat.release()
@@ -1078,6 +1082,7 @@ def vespa_metadata_sync_task(
         )
     except SoftTimeLimitExceeded:
         task_logger.info(f"SoftTimeLimitExceeded exception. doc={document_id}")
+        return False
     except Exception as ex:
         if isinstance(ex, RetryError):
             task_logger.warning(
@@ -617,3 +617,8 @@ POD_NAMESPACE = os.environ.get("POD_NAMESPACE")
 DEV_MODE = os.environ.get("DEV_MODE", "").lower() == "true"

 TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
+
+# Set to true to mock LLM responses for testing purposes
+MOCK_LLM_RESPONSE = (
+    os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
+)
@@ -300,6 +300,8 @@ class OnyxRedisLocks:

 class OnyxRedisSignals:
     VALIDATE_INDEXING_FENCES = "signal:validate_indexing_fences"
+    VALIDATE_EXTERNAL_GROUP_SYNC_FENCES = "signal:validate_external_group_sync_fences"
+    VALIDATE_PERMISSION_SYNC_FENCES = "signal:validate_permission_sync_fences"


 class OnyxCeleryPriority(int, Enum):
@@ -1,4 +1,6 @@
+import contextvars
 from concurrent.futures import as_completed
+from concurrent.futures import Future
 from concurrent.futures import ThreadPoolExecutor
 from io import BytesIO
 from typing import Any
@@ -68,18 +70,25 @@ class AirtableConnector(LoadConnector):
         self.base_id = base_id
         self.table_name_or_id = table_name_or_id
         self.batch_size = batch_size
-        self.airtable_client: AirtableApi | None = None
+        self._airtable_client: AirtableApi | None = None
         self.treat_all_non_attachment_fields_as_metadata = (
            treat_all_non_attachment_fields_as_metadata
        )

    def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
-        self.airtable_client = AirtableApi(credentials["airtable_access_token"])
+        self._airtable_client = AirtableApi(credentials["airtable_access_token"])
        return None

-    @staticmethod
+    @property
+    def airtable_client(self) -> AirtableApi:
+        if not self._airtable_client:
+            raise AirtableClientNotSetUpError()
+        return self._airtable_client
+
    def _extract_field_values(
        self,
        field_id: str,
        field_name: str,
        field_info: Any,
        field_type: str,
        base_id: str,
@@ -118,13 +127,33 @@ class AirtableConnector(LoadConnector):
            backoff=2,
            max_delay=10,
        )
-        def get_attachment_with_retry(url: str) -> bytes | None:
-            attachment_response = requests.get(url)
-            if attachment_response.status_code == 200:
+        def get_attachment_with_retry(url: str, record_id: str) -> bytes | None:
+            try:
+                attachment_response = requests.get(url)
+                attachment_response.raise_for_status()
                 return attachment_response.content
-            return None
+            except requests.exceptions.HTTPError as e:
+                if e.response.status_code == 410:
+                    logger.info(f"Refreshing attachment for {filename}")
+                    # Re-fetch the record to get a fresh URL
+                    refreshed_record = self.airtable_client.table(
+                        base_id, table_id
+                    ).get(record_id)
+                    for refreshed_attachment in refreshed_record["fields"][
+                        field_name
+                    ]:
+                        if refreshed_attachment.get("filename") == filename:
+                            new_url = refreshed_attachment.get("url")
+                            if new_url:
+                                attachment_response = requests.get(new_url)
+                                attachment_response.raise_for_status()
+                                return attachment_response.content
+
+                    logger.error(f"Failed to refresh attachment for {filename}")
+
+                raise

-        attachment_content = get_attachment_with_retry(url)
+        attachment_content = get_attachment_with_retry(url, record_id)
        if attachment_content:
            try:
                file_ext = get_file_ext(filename)
@@ -208,6 +237,7 @@ class AirtableConnector(LoadConnector):
# Get the value(s) for the field
field_value_and_links = self._extract_field_values(
field_id=field_id,
field_name=field_name,
field_info=field_info,
field_type=field_type,
base_id=self.base_id,
@@ -337,7 +367,7 @@ class AirtableConnector(LoadConnector):
logger.info(f"Starting to process Airtable records for {table.name}.")

# Process records in parallel batches using ThreadPoolExecutor
PARALLEL_BATCH_SIZE = 16
PARALLEL_BATCH_SIZE = 8
max_workers = min(PARALLEL_BATCH_SIZE, len(records))

# Process records in batches
@@ -347,15 +377,19 @@ class AirtableConnector(LoadConnector):

with ThreadPoolExecutor(max_workers=max_workers) as executor:
# Submit batch tasks
future_to_record = {
executor.submit(
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
): record
for record in batch_records
}
future_to_record: dict[Future, RecordDict] = {}
for record in batch_records:
# Capture the current context so that the thread gets the current tenant ID
current_context = contextvars.copy_context()
future_to_record[
executor.submit(
current_context.run,
self._process_record,
record=record,
table_schema=table_schema,
primary_field_name=primary_field_name,
)
] = record

# Wait for all tasks in this batch to complete
for future in as_completed(future_to_record):
@@ -368,10 +402,8 @@ class AirtableConnector(LoadConnector):
logger.exception(f"Failed to process record {record['id']}")
raise e

# After batch is complete, yield if we've hit the batch size
if len(record_documents) >= self.batch_size:
yield record_documents
record_documents = []
yield record_documents
record_documents = []

# Yield any remaining records
if record_documents:
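The executor change above copies the caller's context into each submitted task so worker threads see the current tenant ID. The pattern in isolation, with an illustrative TENANT_ID variable standing in for the real context var:

import contextvars
from concurrent.futures import ThreadPoolExecutor, as_completed

TENANT_ID: contextvars.ContextVar[str] = contextvars.ContextVar("tenant_id", default="public")

def work(item: int) -> str:
    # Runs on a worker thread, but still sees the value set by the caller.
    return f"{TENANT_ID.get()}:{item}"

TENANT_ID.set("tenant_abc")
with ThreadPoolExecutor(max_workers=4) as executor:
    futures = []
    for i in range(3):
        ctx = contextvars.copy_context()  # fresh copy per task, as in the connector
        futures.append(executor.submit(ctx.run, work, i))
    for future in as_completed(futures):
        print(future.result())  # each prints "tenant_abc:<i>"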
@@ -150,6 +150,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):

# if specified, controls the assistants that are shown to the user + their order
# if not specified, all assistants are shown
temperature_override_enabled: Mapped[bool] = mapped_column(Boolean, default=False)
auto_scroll: Mapped[bool] = mapped_column(Boolean, default=True)
shortcut_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
@@ -1115,6 +1116,10 @@ class ChatSession(Base):
llm_override: Mapped[LLMOverride | None] = mapped_column(
PydanticType(LLMOverride), nullable=True
)

# The latest temperature override specified by the user
temperature_override: Mapped[float | None] = mapped_column(Float, nullable=True)

prompt_override: Mapped[PromptOverride | None] = mapped_column(
PydanticType(PromptOverride), nullable=True
)
@@ -2,7 +2,7 @@ from onyx.key_value_store.interface import KeyValueStore
from onyx.key_value_store.store import PgRedisKVStore


def get_kv_store(tenant_id: str | None = None) -> KeyValueStore:
def get_kv_store() -> KeyValueStore:
# In the Multi Tenant case, the tenant context is picked up automatically, it does not need to be passed in
# It's read from the global thread level variable
return PgRedisKVStore(tenant_id=tenant_id)
return PgRedisKVStore()

@@ -18,7 +18,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.special_types import JSON_ro
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()

@@ -28,10 +28,8 @@ KV_REDIS_KEY_EXPIRATION = 60 * 60 * 24 # 1 Day


class PgRedisKVStore(KeyValueStore):
def __init__(
self, redis_client: Redis | None = None, tenant_id: str | None = None
) -> None:
self.tenant_id = tenant_id or CURRENT_TENANT_ID_CONTEXTVAR.get()
def __init__(self, redis_client: Redis | None = None) -> None:
self.tenant_id = get_current_tenant_id()

# If no redis_client is provided, fall back to the context var
if redis_client is not None:
@@ -26,6 +26,7 @@ from langchain_core.messages.tool import ToolMessage
from langchain_core.prompt_values import PromptValue

from onyx.configs.app_configs import LOG_DANSWER_MODEL_INTERACTIONS
from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.configs.model_configs import (
DISABLE_LITELLM_STREAMING,
)
@@ -387,6 +388,7 @@ class DefaultMultiLLM(LLM):

try:
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice
model=f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
# NOTE: have to pass in None instead of empty string for these
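MOCK_LLM_RESPONSE is passed straight through to litellm.completion as mock_response, so tests can exercise chat flows without hitting a provider. A small sketch of litellm's mock behavior, assuming the litellm package is installed (no real API key should be needed when a mock is supplied):

import litellm

# With a string mock_response, litellm returns a canned ModelResponse and never
# contacts the provider; with None, the call goes out normally.
response = litellm.completion(
    model="openai/gpt-4o-mini",
    messages=[{"role": "user", "content": "ping"}],
    mock_response="pong",
)
print(response.choices[0].message.content)  # "pong"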
@@ -109,7 +109,9 @@ from onyx.utils.variable_functionality import global_version
from onyx.utils.variable_functionality import set_is_ee_based_on_env_variable
from shared_configs.configs import CORS_ALLOWED_ORIGIN
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import SENTRY_DSN
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

logger = setup_logger()

@@ -212,7 +214,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:

if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
get_or_generate_uuid(tenant_id=None)
CURRENT_TENANT_ID_CONTEXTVAR.set(POSTGRES_DEFAULT_SCHEMA)
get_or_generate_uuid()

# If we are multi-tenant, we need to only set up initial public tables
with Session(engine) as db_session:
@@ -175,7 +175,6 @@ class EmbeddingModel:
if self.callback.should_stop():
raise RuntimeError("_batch_encode_texts detected stop signal")

logger.debug(f"Encoding batch {batch_idx} of {len(text_batches)}")
embed_request = EmbedRequest(
model_name=self.model_name,
texts=text_batch,
@@ -191,7 +190,15 @@ class EmbeddingModel:
api_url=self.api_url,
)

start_time = time.time()
response = self._make_model_server_request(embed_request)
end_time = time.time()

processing_time = end_time - start_time
logger.info(
f"Batch {batch_idx} processing time: {processing_time:.2f} seconds"
)

return batch_idx, response.embeddings

# only multi thread if:
@@ -17,6 +17,8 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorPermissionSyncPayload(BaseModel):
id: str
submitted: datetime
started: datetime | None
celery_task_id: str | None

@@ -41,6 +43,12 @@ class RedisConnectorPermissionSync:
TASKSET_PREFIX = f"{PREFIX}_taskset" # connectorpermissions_taskset
SUBTASK_PREFIX = f"{PREFIX}+sub" # connectorpermissions+sub

# used to signal the overall workflow is still active
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

def __init__(self, tenant_id: str | None, id: int, redis: redis.Redis) -> None:
self.tenant_id: str | None = tenant_id
self.id = id
@@ -54,6 +62,7 @@ class RedisConnectorPermissionSync:
self.taskset_key = f"{self.TASKSET_PREFIX}_{id}"

self.subtask_prefix: str = f"{self.SUBTASK_PREFIX}_{id}"
self.active_key = f"{self.ACTIVE_PREFIX}_{id}"

def taskset_clear(self) -> None:
self.redis.delete(self.taskset_key)
@@ -107,6 +116,20 @@ class RedisConnectorPermissionSync:

self.redis.set(self.fence_key, payload.model_dump_json())

def set_active(self) -> None:
"""This sets a signal to keep the permissioning flow from getting cleaned up within
the expiration time.

The slack in timing is needed to avoid race conditions where simply checking
the celery queue and task status could result in race conditions."""
self.redis.set(self.active_key, 0, ex=self.ACTIVE_TTL)

def active(self) -> bool:
if self.redis.exists(self.active_key):
return True

return False

@property
def generator_complete(self) -> int | None:
"""the fence payload is an int representing the starting number of
@@ -173,6 +196,7 @@ class RedisConnectorPermissionSync:
return len(async_results)

def reset(self) -> None:
self.redis.delete(self.active_key)
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
@@ -187,6 +211,9 @@ class RedisConnectorPermissionSync:
@staticmethod
def reset_all(r: redis.Redis) -> None:
"""Deletes all redis values for all connectors"""
for key in r.scan_iter(RedisConnectorPermissionSync.ACTIVE_PREFIX + "*"):
r.delete(key)

for key in r.scan_iter(RedisConnectorPermissionSync.TASKSET_PREFIX + "*"):
r.delete(key)
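The set_active/active pair added above is a TTL-backed liveness signal: its only job is to bridge the gaps where checking the Celery queue and task state alone would race. The primitive on its own with redis-py (key name and TTL are illustrative):

import redis

r = redis.Redis(host="localhost", port=6379, db=0)

ACTIVE_KEY = "connectorpermissions_active_42"  # illustrative key
ACTIVE_TTL = 3600  # seconds; must outlive the longest gap between checks

def set_active() -> None:
    # The stored value is irrelevant; only the key's existence and TTL matter.
    r.set(ACTIVE_KEY, 0, ex=ACTIVE_TTL)

def is_active() -> bool:
    return bool(r.exists(ACTIVE_KEY))

set_active()
print(is_active())  # True until the TTL lapses or reset() deletes the key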
@@ -11,6 +11,7 @@ from onyx.redis.redis_pool import SCAN_ITER_COUNT_DEFAULT


class RedisConnectorExternalGroupSyncPayload(BaseModel):
submitted: datetime
started: datetime | None
celery_task_id: str | None

@@ -135,6 +136,12 @@ class RedisConnectorExternalGroupSync:
) -> int | None:
pass

def reset(self) -> None:
self.redis.delete(self.generator_progress_key)
self.redis.delete(self.generator_complete_key)
self.redis.delete(self.taskset_key)
self.redis.delete(self.fence_key)

@staticmethod
def remove_from_taskset(id: int, task_id: str, r: redis.Redis) -> None:
taskset_key = f"{RedisConnectorExternalGroupSync.TASKSET_PREFIX}_{id}"
@@ -33,8 +33,8 @@ class RedisConnectorIndex:
TERMINATE_TTL = 600

# used to signal the overall workflow is still active
# there are gaps in time between states where we need some slack
# to correctly transition
# it's impossible to get the exact state of the system at a single point in time
# so we need a signal with a TTL to bridge gaps in our checks
ACTIVE_PREFIX = PREFIX + "_active"
ACTIVE_TTL = 3600

@@ -122,7 +122,7 @@ class TenantRedis(redis.Redis):
"ttl",
] # Regular methods that need simple prefixing

if item == "scan_iter":
if item == "scan_iter" or item == "sscan_iter":
return self._prefix_scan_iter(original_attr)
elif item in methods_to_wrap and callable(original_attr):
return self._prefix_method(original_attr)
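Adding sscan_iter alongside scan_iter means set scans get the same tenant prefixing as key scans. A rough sketch of the prefixing idea only (the real TenantRedis also strips prefixes from results and wraps many more methods):

from collections.abc import Iterator

import redis

class PrefixedRedis:
    """Illustrative wrapper that namespaces scan-style calls by tenant."""

    def __init__(self, client: redis.Redis, tenant_id: str) -> None:
        self._client = client
        self._prefix = f"{tenant_id}:"

    def scan_iter(self, match: str = "*", **kwargs) -> Iterator[bytes]:
        # Only keys under this tenant's prefix are visible to the scan.
        return self._client.scan_iter(match=self._prefix + match, **kwargs)

    def sscan_iter(self, name: str, match: str = "*", **kwargs) -> Iterator[bytes]:
        # Same idea for members of a tenant-prefixed set key.
        return self._client.sscan_iter(self._prefix + name, match=match, **kwargs)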
@@ -422,27 +422,29 @@ def sync_cc_pair(
if redis_connector.permissions.fenced:
raise HTTPException(
status_code=HTTPStatus.CONFLICT,
detail="Doc permissions sync task already in progress.",
detail="Permissions sync task already in progress.",
)

logger.info(
f"Doc permissions sync cc_pair={cc_pair_id} "
f"Permissions sync cc_pair={cc_pair_id} "
f"connector_id={cc_pair.connector_id} "
f"credential_id={cc_pair.credential_id} "
f"{cc_pair.connector.name} connector."
)
tasks_created = try_creating_permissions_sync_task(
payload_id = try_creating_permissions_sync_task(
primary_app, cc_pair_id, r, CURRENT_TENANT_ID_CONTEXTVAR.get()
)
if not tasks_created:
if not payload_id:
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
detail="Doc permissions sync task creation failed.",
detail="Permissions sync task creation failed.",
)

logger.info(f"Permissions sync queued: cc_pair={cc_pair_id} id={payload_id}")

return StatusResponse(
success=True,
message="Successfully created the doc permissions sync task.",
message="Successfully created the permissions sync task.",
)

@@ -48,6 +48,7 @@ class UserPreferences(BaseModel):
auto_scroll: bool | None = None
pinned_assistants: list[int] | None = None
shortcut_enabled: bool | None = None
temperature_override_enabled: bool | None = None


class UserInfo(BaseModel):
@@ -91,6 +92,7 @@ class UserInfo(BaseModel):
hidden_assistants=user.hidden_assistants,
pinned_assistants=user.pinned_assistants,
visible_assistants=user.visible_assistants,
temperature_override_enabled=user.temperature_override_enabled,
)
),
organization_name=organization_name,
@@ -568,6 +568,32 @@ def verify_user_logged_in(
"""APIs to adjust user preferences"""


@router.patch("/temperature-override-enabled")
def update_user_temperature_override_enabled(
temperature_override_enabled: bool,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
if user is None:
if AUTH_TYPE == AuthType.DISABLED:
store = get_kv_store()
no_auth_user = fetch_no_auth_user(store)
no_auth_user.preferences.temperature_override_enabled = (
temperature_override_enabled
)
set_no_auth_user_preferences(store, no_auth_user.preferences)
return
else:
raise RuntimeError("This should never happen")

db_session.execute(
update(User)
.where(User.id == user.id) # type: ignore
.values(temperature_override_enabled=temperature_override_enabled)
)
db_session.commit()
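The new preference endpoint takes the flag as a query parameter, so it can be exercised with a plain HTTP call. A hedged example with requests; the base URL and router prefix depend on the deployment and on where this APIRouter is mounted, and authentication is omitted:

import requests

BASE_URL = "http://localhost:8080"  # assumption: local API server, router mounted at the root

response = requests.patch(
    f"{BASE_URL}/temperature-override-enabled",
    params={"temperature_override_enabled": True},
    # session cookie / auth headers omitted for brevity
)
response.raise_for_status()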
class ChosenDefaultModelRequest(BaseModel):
default_model: str | None = None

@@ -18,7 +18,6 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session

from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -78,6 +77,7 @@ from onyx.server.query_and_chat.models import LLMOverride
from onyx.server.query_and_chat.models import PromptOverride
from onyx.server.query_and_chat.models import RenameChatSessionResponse
from onyx.server.query_and_chat.models import SearchFeedbackRequest
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
from onyx.server.query_and_chat.token_limit import check_token_rate_limits
from onyx.utils.headers import get_custom_tool_additional_request_headers
@@ -115,12 +115,52 @@ def get_user_chat_sessions(
shared_status=chat.shared_status,
folder_id=chat.folder_id,
current_alternate_model=chat.current_alternate_model,
current_temperature_override=chat.temperature_override,
)
for chat in chat_sessions
]
)

@router.put("/update-chat-session-temperature")
|
||||
def update_chat_session_temperature(
|
||||
update_thread_req: UpdateChatSessionTemperatureRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=update_thread_req.chat_session_id,
|
||||
user_id=user.id if user is not None else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Validate temperature_override
|
||||
if update_thread_req.temperature_override is not None:
|
||||
if (
|
||||
update_thread_req.temperature_override < 0
|
||||
or update_thread_req.temperature_override > 2
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Temperature must be between 0 and 2"
|
||||
)
|
||||
|
||||
# Additional check for Anthropic models
|
||||
if (
|
||||
chat_session.current_alternate_model
|
||||
and "anthropic" in chat_session.current_alternate_model.lower()
|
||||
):
|
||||
if update_thread_req.temperature_override > 1:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Temperature for Anthropic models must be between 0 and 1",
|
||||
)
|
||||
|
||||
chat_session.temperature_override = update_thread_req.temperature_override
|
||||
|
||||
db_session.add(chat_session)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
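The endpoint above enforces two ranges: 0-2 in general and 0-1 when the session's model looks like an Anthropic one. The same rule as a tiny standalone helper (the substring check on the model name is an illustrative heuristic):

def validate_temperature(temperature: float | None, model: str | None) -> None:
    if temperature is None:
        return
    if not 0 <= temperature <= 2:
        raise ValueError("Temperature must be between 0 and 2")
    # Anthropic samplers only accept 0-1, so the bound tightens for those models.
    if model and "anthropic" in model.lower() and temperature > 1:
        raise ValueError("Temperature for Anthropic models must be between 0 and 1")

validate_temperature(1.5, "openai/gpt-4o-mini")            # ok
validate_temperature(0.7, "anthropic/claude-3-5-sonnet")   # ok
validate_temperature(1.5, "anthropic/claude-3-5-sonnet")   # raises ValueError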
@router.put("/update-chat-session-model")
|
||||
def update_chat_session_model(
|
||||
update_thread_req: UpdateChatSessionThreadRequest,
|
||||
@@ -191,6 +231,7 @@ def get_chat_session(
|
||||
],
|
||||
time_created=chat_session.time_created,
|
||||
shared_status=chat_session.shared_status,
|
||||
current_temperature_override=chat_session.temperature_override,
|
||||
)
|
||||
|
||||
|
||||
@@ -422,7 +463,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_limited_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None
@@ -42,6 +42,11 @@ class UpdateChatSessionThreadRequest(BaseModel):
new_alternate_model: str


class UpdateChatSessionTemperatureRequest(BaseModel):
chat_session_id: UUID
temperature_override: float


class ChatSessionCreationRequest(BaseModel):
# If not specified, use Onyx default persona
persona_id: int = 0
@@ -108,6 +113,10 @@ class CreateChatMessageRequest(ChunkContext):
llm_override: LLMOverride | None = None
prompt_override: PromptOverride | None = None

# Allows the caller to override the temperature for the chat session
# this does persist in the chat thread details
temperature_override: float | None = None

# allow user to specify an alternate assistant
alternate_assistant_id: int | None = None

@@ -168,6 +177,7 @@ class ChatSessionDetails(BaseModel):
shared_status: ChatSessionSharedStatus
folder_id: int | None = None
current_alternate_model: str | None = None
current_temperature_override: float | None = None


class ChatSessionsResponse(BaseModel):
@@ -231,6 +241,7 @@ class ChatSessionDetailResponse(BaseModel):
time_created: datetime
shared_status: ChatSessionSharedStatus
current_alternate_model: str | None
current_temperature_override: float | None


# This one is not used anymore
@@ -1,4 +1,6 @@
import base64
import json
import os
from datetime import datetime
from typing import Any

@@ -66,3 +68,10 @@ def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
)

return masked_creds


def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is
targeted at the stated use case."""
return base64.b32encode(os.urandom(5)).decode("utf-8")[:8] # 5 bytes → 8 chars
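make_short_id trades uniqueness for speed: 5 random bytes are Base32-encoded into exactly 8 characters (about 40 bits of entropy), which is fine for tagging a flow through the logs but not for primary keys. A quick standalone check of the construction:

import base64
import os

def make_short_id() -> str:
    # 5 bytes -> 8 Base32 characters, with no '=' padding to trim.
    return base64.b32encode(os.urandom(5)).decode("utf-8")[:8]

short_id = make_short_id()
print(short_id, len(short_id))  # e.g. "K3J7Q2ZM 8" - readable, not guaranteed unique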
@@ -37,6 +37,7 @@ from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.natural_language_processing.search_nlp_models import warm_up_cross_encoder
@@ -279,6 +280,7 @@ def setup_postgres(db_session: Session) -> None:
if GEN_AI_API_KEY and fetch_default_provider(db_session) is None:
# Only for dev flows
logger.notice("Setting up default OpenAI LLM for dev.")

llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
fast_model = FAST_GEN_AI_MODEL_VERSION or "gpt-4o-mini"
model_req = LLMProviderUpsertRequest(
@@ -292,8 +294,8 @@ def setup_postgres(db_session: Session) -> None:
fast_default_model_name=fast_model,
is_public=True,
groups=[],
display_model_names=[llm_model, fast_model],
model_names=[llm_model, fast_model],
display_model_names=OPEN_AI_MODEL_NAMES,
model_names=OPEN_AI_MODEL_NAMES,
)
new_llm_provider = upsert_llm_provider(
llm_provider=model_req, db_session=db_session
@@ -26,6 +26,13 @@ doc_permission_sync_ctx: contextvars.ContextVar[
] = contextvars.ContextVar("doc_permission_sync_ctx", default=dict())


class LoggerContextVars:
@staticmethod
def reset() -> None:
pruning_ctx.set(dict())
doc_permission_sync_ctx.set(dict())


class TaskAttemptSingleton:
"""Used to tell if this process is an indexing job, and if so what is the
unique identifier for this indexing attempt. For things like the API server,
@@ -70,27 +77,32 @@ class OnyxLoggingAdapter(logging.LoggerAdapter):
) -> tuple[str, MutableMapping[str, Any]]:
# If this is an indexing job, add the attempt ID to the log message
# This helps filter the logs for this specific indexing
index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()
while True:
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"

doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
pruning_ctx_dict = pruning_ctx.get()
if len(pruning_ctx_dict) > 0:
if "request_id" in pruning_ctx_dict:
msg = f"[Prune: {pruning_ctx_dict['request_id']}] {msg}"
if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
break

doc_permission_sync_ctx_dict = doc_permission_sync_ctx.get()
if len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
break

index_attempt_id = TaskAttemptSingleton.get_index_attempt_id()
cc_pair_id = TaskAttemptSingleton.get_connector_credential_pair_id()

if "cc_pair_id" in pruning_ctx_dict:
msg = f"[CC Pair: {pruning_ctx_dict['cc_pair_id']}] {msg}"
elif len(doc_permission_sync_ctx_dict) > 0:
if "request_id" in doc_permission_sync_ctx_dict:
msg = f"[Doc Permissions Sync: {doc_permission_sync_ctx_dict['request_id']}] {msg}"
else:
if index_attempt_id is not None:
msg = f"[Index Attempt: {index_attempt_id}] {msg}"

if cc_pair_id is not None:
msg = f"[CC Pair: {cc_pair_id}] {msg}"

break
# Add tenant information if it differs from default
# This will always be the case for authenticated API requests
if MULTI_TENANT:
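The while True / break restructuring above just makes the adapter pick the first matching context (prune, doc permission sync, or index attempt) and prepend its tags exactly once. A reduced sketch of a LoggerAdapter driven by a context dict (names are illustrative, not the Onyx adapter):

import contextvars
import logging
from typing import Any, MutableMapping

request_ctx: contextvars.ContextVar[dict] = contextvars.ContextVar("request_ctx", default={})

class TaggingAdapter(logging.LoggerAdapter):
    def process(
        self, msg: str, kwargs: MutableMapping[str, Any]
    ) -> tuple[str, MutableMapping[str, Any]]:
        ctx = request_ctx.get()
        if "request_id" in ctx:
            msg = f"[Request: {ctx['request_id']}] {msg}"
        return msg, kwargs

logging.basicConfig(level=logging.INFO)
logger = TaggingAdapter(logging.getLogger("demo"), {})
request_ctx.set({"request_id": "abc12345"})
logger.info("hello")  # logs "[Request: abc12345] hello"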
@@ -1,3 +1,4 @@
import contextvars
import threading
import uuid
from enum import Enum
@@ -41,7 +42,7 @@ def _get_or_generate_customer_id_mt(tenant_id: str) -> str:
return str(uuid.uuid5(uuid.NAMESPACE_X500, tenant_id))


def get_or_generate_uuid(tenant_id: str | None) -> str:
def get_or_generate_uuid() -> str:
# TODO: split out the whole "instance UUID" generation logic into a separate
# utility function. Telemetry should not be aware at all of how the UUID is
# generated/stored.
@@ -52,7 +53,7 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
if _CACHED_UUID is not None:
return _CACHED_UUID

kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()

try:
_CACHED_UUID = cast(str, kv_store.load(KV_CUSTOMER_UUID_KEY))
@@ -63,18 +64,18 @@ def get_or_generate_uuid(tenant_id: str | None) -> str:
return _CACHED_UUID


def _get_or_generate_instance_domain(tenant_id: str | None = None) -> str | None: #
def _get_or_generate_instance_domain() -> str | None: #
global _CACHED_INSTANCE_DOMAIN

if _CACHED_INSTANCE_DOMAIN is not None:
return _CACHED_INSTANCE_DOMAIN

kv_store = get_kv_store(tenant_id=tenant_id)
kv_store = get_kv_store()

try:
_CACHED_INSTANCE_DOMAIN = cast(str, kv_store.load(KV_INSTANCE_DOMAIN_KEY))
except KvKeyNotFoundError:
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
with get_session_with_tenant() as db_session:
first_user = db_session.query(User).first()
if first_user:
_CACHED_INSTANCE_DOMAIN = first_user.email.split("@")[-1]
@@ -103,7 +104,7 @@ def optional_telemetry(
customer_uuid = (
_get_or_generate_customer_id_mt(tenant_id)
if MULTI_TENANT
else get_or_generate_uuid(tenant_id)
else get_or_generate_uuid()
)
payload = {
"data": data,
@@ -115,9 +116,7 @@ def optional_telemetry(
"is_cloud": MULTI_TENANT,
}
if ENTERPRISE_EDITION_ENABLED:
payload["instance_domain"] = _get_or_generate_instance_domain(
tenant_id
)
payload["instance_domain"] = _get_or_generate_instance_domain()
requests.post(
_DANSWER_TELEMETRY_ENDPOINT,
headers={"Content-Type": "application/json"},
@@ -128,8 +127,12 @@ def optional_telemetry(
# This way it silences all thread level logging as well
pass

# Run in separate thread to have minimal overhead in main flows
thread = threading.Thread(target=telemetry_logic, daemon=True)
# Run in separate thread with the same context as the current thread
# This is to ensure that the thread gets the current tenant ID
current_context = contextvars.copy_context()
thread = threading.Thread(
target=lambda: current_context.run(telemetry_logic), daemon=True
)
thread.start()
except Exception:
# Should never interfere with normal functions of Onyx
@@ -6,3 +6,13 @@ from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
CURRENT_TENANT_ID_CONTEXTVAR = contextvars.ContextVar(
"current_tenant_id", default=POSTGRES_DEFAULT_SCHEMA
)


"""Utils related to contextvars"""


def get_current_tenant_id() -> str:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if tenant_id is None:
raise RuntimeError("Tenant ID is not set. This should never happen.")
return tenant_id
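get_current_tenant_id centralizes reads of the tenant context var so callers (the kv store, telemetry, and main.py above) stop passing tenant_id around explicitly. A minimal usage sketch of the same accessor pattern, with an illustrative variable name and default:

import contextvars

CURRENT_TENANT_ID: contextvars.ContextVar[str | None] = contextvars.ContextVar(
    "current_tenant_id", default="public"
)

def get_current_tenant_id() -> str:
    tenant_id = CURRENT_TENANT_ID.get()
    if tenant_id is None:
        # None means the caller forgot to seed the context, e.g. in a fresh thread.
        raise RuntimeError("Tenant ID is not set")
    return tenant_id

token = CURRENT_TENANT_ID.set("tenant_abc")
print(get_current_tenant_id())  # "tenant_abc"
CURRENT_TENANT_ID.reset(token)
print(get_current_tenant_id())  # back to "public"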
@@ -9,6 +9,7 @@ from litellm.types.utils import ChatCompletionDeltaToolCall
from litellm.types.utils import Delta
from litellm.types.utils import Function as LiteLLMFunction

from onyx.configs.app_configs import MOCK_LLM_RESPONSE
from onyx.llm.chat_llm import DefaultMultiLLM


@@ -143,6 +144,7 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
)


@@ -287,4 +289,5 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
temperature=0.0, # Default value from GEN_AI_TEMPERATURE
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
)
@@ -1,75 +0,0 @@
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-heavy-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-heavy
|
||||
minReplicas: 1
|
||||
maxReplicas: 5
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 60
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-light-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-light
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-indexing-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-indexing
|
||||
minReplicas: 1
|
||||
maxReplicas: 10
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
||||
---
|
||||
apiVersion: autoscaling/v2
|
||||
kind: HorizontalPodAutoscaler
|
||||
metadata:
|
||||
name: celery-worker-monitoring-hpa
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
name: celery-worker-indexing
|
||||
minReplicas: 1
|
||||
maxReplicas: 4
|
||||
metrics:
|
||||
- type: Resource
|
||||
resource:
|
||||
name: cpu
|
||||
target:
|
||||
type: Utilization
|
||||
averageUtilization: 70
|
||||
@@ -1,13 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: TriggerAuthentication
|
||||
metadata:
|
||||
name: celery-worker-auth
|
||||
namespace: onyx
|
||||
spec:
|
||||
secretTargetRef:
|
||||
- parameter: host
|
||||
name: keda-redis-secret
|
||||
key: host
|
||||
- parameter: password
|
||||
name: keda-redis-secret
|
||||
key: password
|
||||
@@ -1,53 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-indexing-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-indexing
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-indexing
|
||||
minReplicaCount: 1
|
||||
maxReplicaCount: 30
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
sslEnabled: "true"
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_indexing:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
|
||||
- type: memory
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
@@ -1,58 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-light-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-light
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-light
|
||||
minReplicaCount: 5
|
||||
maxReplicaCount: 20
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: vespa_metadata_sync:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: connector_deletion:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
@@ -1,70 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: celery-worker-primary-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: celery-worker-primary
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: celery-worker-primary
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 4
|
||||
maxReplicaCount: 4
|
||||
triggers:
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:1
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: celery:3
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
- type: redis
|
||||
metadata:
|
||||
port: "6379"
|
||||
enableTLS: "true"
|
||||
listName: periodic_tasks:2
|
||||
listLength: "1"
|
||||
databaseIndex: "15"
|
||||
authenticationRef:
|
||||
name: celery-worker-auth
|
||||
@@ -1,19 +0,0 @@
|
||||
apiVersion: keda.sh/v1alpha1
|
||||
kind: ScaledObject
|
||||
metadata:
|
||||
name: indexing-model-server-scaledobject
|
||||
namespace: onyx
|
||||
labels:
|
||||
app: indexing-model-server
|
||||
spec:
|
||||
scaleTargetRef:
|
||||
name: indexing-model-server-deployment
|
||||
pollingInterval: 15 # Check every 15 seconds
|
||||
cooldownPeriod: 30 # Wait 30 seconds before scaling down
|
||||
minReplicaCount: 10
|
||||
maxReplicaCount: 10
|
||||
triggers:
|
||||
- type: cpu
|
||||
metadata:
|
||||
type: Utilization
|
||||
value: "70"
|
||||
@@ -1,9 +0,0 @@
|
||||
apiVersion: v1
|
||||
kind: Secret
|
||||
metadata:
|
||||
name: keda-redis-secret
|
||||
namespace: onyx
|
||||
type: Opaque
|
||||
data:
|
||||
host: { base64 encoded host here }
|
||||
password: { base64 encoded password here }
|
||||
@@ -1,44 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-beat
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-beat
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-beat
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-beat
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.beat",
|
||||
"beat",
|
||||
"--loglevel=INFO",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
resources:
|
||||
requests:
|
||||
cpu: "250m"
|
||||
memory: "512Mi"
|
||||
limits:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
||||
@@ -1,60 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-heavy
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-heavy
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-heavy
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-heavy
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.heavy",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=heavy@%n",
|
||||
"-Q",
|
||||
"connector_pruning,connector_doc_permissions_sync,connector_external_group_sync",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
limits:
|
||||
cpu: "2000m"
|
||||
memory: "4Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
||||
@@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-indexing
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-indexing
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-indexing
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-indexing
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.indexing",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=indexing@%n",
|
||||
"-Q",
|
||||
"connector_indexing",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "4Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "8Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
||||
@@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-light
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-light
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-light
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-light
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.light",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=light@%n",
|
||||
"-Q",
|
||||
"vespa_metadata_sync,connector_deletion,doc_permissions_upsert",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
||||
@@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-monitoring
|
||||
spec:
|
||||
replicas: 2
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-monitoring
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-monitoring
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-monitoring
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.monitoring",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=monitoring@%n",
|
||||
"-Q",
|
||||
"monitoring",
|
||||
"--prefetch-multiplier=8",
|
||||
"--concurrency=8",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "1000m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "1Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
||||
@@ -1,62 +0,0 @@
|
||||
apiVersion: apps/v1
|
||||
kind: Deployment
|
||||
metadata:
|
||||
name: celery-worker-primary
|
||||
spec:
|
||||
replicas: 1
|
||||
selector:
|
||||
matchLabels:
|
||||
app: celery-worker-primary
|
||||
template:
|
||||
metadata:
|
||||
labels:
|
||||
app: celery-worker-primary
|
||||
spec:
|
||||
containers:
|
||||
- name: celery-worker-primary
|
||||
image: onyxdotapp/onyx-backend-cloud:v0.14.0-cloud.beta.21
|
||||
imagePullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"celery",
|
||||
"-A",
|
||||
"onyx.background.celery.versioned_apps.primary",
|
||||
"worker",
|
||||
"--loglevel=INFO",
|
||||
"--hostname=primary@%n",
|
||||
"-Q",
|
||||
"celery,periodic_tasks",
|
||||
"--prefetch-multiplier=1",
|
||||
"--concurrency=10",
|
||||
]
|
||||
env:
|
||||
- name: REDIS_PASSWORD
|
||||
valueFrom:
|
||||
secretKeyRef:
|
||||
name: onyx-secrets
|
||||
key: redis_password
|
||||
- name: ONYX_VERSION
|
||||
value: "v0.11.0-cloud.beta.8"
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: env-configmap
|
||||
volumeMounts:
|
||||
- name: vespa-certificates
|
||||
mountPath: "/app/certs"
|
||||
readOnly: true
|
||||
resources:
|
||||
requests:
|
||||
cpu: "500m"
|
||||
memory: "1Gi"
|
||||
limits:
|
||||
cpu: "1000m"
|
||||
memory: "2Gi"
|
||||
volumes:
|
||||
- name: vespa-certificates
|
||||
secret:
|
||||
secretName: vespa-certificates
|
||||
items:
|
||||
- key: cert.pem
|
||||
path: cert.pem
|
||||
- key: key.pem
|
||||
path: key.pem
|
||||
@@ -6,7 +6,7 @@ sources:
- "https://github.com/onyx-dot-app/onyx"
type: application
version: 0.2.1
appVersion: "latest"
appVersion: latest
annotations:
category: Productivity
licenses: MIT

@@ -45,10 +45,10 @@ spec:
|
||||
- |
|
||||
alembic upgrade head &&
|
||||
echo "Starting Onyx Api Server" &&
|
||||
uvicorn onyx.main:app --host 0.0.0.0 --port 8080
|
||||
uvicorn onyx.main:app --host 0.0.0.0 --port {{ .Values.api.containerPorts.server }}
|
||||
ports:
|
||||
- name: api-server-port
|
||||
containerPort: {{ .Values.api.service.port }}
|
||||
containerPort: {{ .Values.api.containerPorts.server }}
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.api.resources | nindent 12 }}
|
||||
|
||||
@@ -11,10 +11,10 @@ metadata:
|
||||
spec:
|
||||
type: {{ .Values.api.service.type }}
|
||||
ports:
|
||||
- port: {{ .Values.api.service.port }}
|
||||
targetPort: api-server-port
|
||||
- port: {{ .Values.api.service.servicePort }}
|
||||
targetPort: {{ .Values.api.service.targetPort }}
|
||||
protocol: TCP
|
||||
name: api-server-port
|
||||
name: {{ .Values.api.service.portName }}
|
||||
selector:
|
||||
{{- include "onyx-stack.selectorLabels" . | nindent 4 }}
|
||||
{{- if .Values.api.deploymentLabels }}
|
||||
|
||||
@@ -5,7 +5,7 @@ metadata:
|
||||
labels:
|
||||
{{- include "onyx-stack.labels" . | nindent 4 }}
|
||||
spec:
|
||||
replicas: 1
|
||||
replicas: {{ .Values.indexCapability.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- include "onyx-stack.selectorLabels" . | nindent 6 }}
|
||||
@@ -25,12 +25,14 @@ spec:
|
||||
{{- end }}
|
||||
spec:
|
||||
containers:
|
||||
- name: indexing-model-server
|
||||
image: onyxdotapp/onyx-model-server:latest
|
||||
imagePullPolicy: IfNotPresent
|
||||
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "9000", "--limit-concurrency", "10" ]
|
||||
- name: {{ .Values.indexCapability.name }}
|
||||
image: "{{ .Values.indexCapability.image.repository }}:{{ .Values.indexCapability.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: {{ .Values.indexCapability.image.pullPolicy }}
|
||||
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "{{ .Values.indexCapability.containerPorts.server }}", "--limit-concurrency", "{{ .Values.indexCapability.limitConcurrency }}" ]
|
||||
ports:
|
||||
- containerPort: 9000
|
||||
- name: model-server
|
||||
containerPort: {{ .Values.indexCapability.containerPorts.server }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
|
||||
@@ -3,8 +3,9 @@ kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: {{ .Values.indexCapability.indexingModelPVC.name }}
|
||||
spec:
|
||||
storageClassName: {{ .Values.persistent.storageClassName }}
|
||||
accessModes:
|
||||
- {{ .Values.indexCapability.indexingModelPVC.accessMode | quote }}
|
||||
resources:
|
||||
requests:
|
||||
storage: {{ .Values.indexCapability.indexingModelPVC.storage | quote }}
|
||||
storage: {{ .Values.indexCapability.indexingModelPVC.storage | quote }}
|
||||
|
||||
@@ -11,8 +11,8 @@ spec:
|
||||
{{- toYaml .Values.indexCapability.deploymentLabels | nindent 4 }}
|
||||
{{- end }}
|
||||
ports:
|
||||
- name: {{ .Values.indexCapability.service.name }}
|
||||
- name: {{ .Values.indexCapability.service.portName }}
|
||||
protocol: TCP
|
||||
port: {{ .Values.indexCapability.service.port }}
|
||||
targetPort: {{ .Values.indexCapability.service.port }}
|
||||
type: {{ .Values.indexCapability.service.type }}
|
||||
port: {{ .Values.indexCapability.service.servicePort }}
|
||||
targetPort: {{ .Values.indexCapability.service.targetPort }}
|
||||
type: {{ .Values.indexCapability.service.type }}
|
||||
|
||||
@@ -3,14 +3,14 @@ kind: Deployment
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-inference-model
|
||||
labels:
|
||||
{{- range .Values.inferenceCapability.deployment.labels }}
|
||||
{{- range .Values.inferenceCapability.labels }}
|
||||
{{ .key }}: {{ .value }}
|
||||
{{- end }}
|
||||
spec:
|
||||
replicas: {{ .Values.inferenceCapability.deployment.replicas }}
|
||||
replicas: {{ .Values.inferenceCapability.replicaCount }}
|
||||
selector:
|
||||
matchLabels:
|
||||
{{- range .Values.inferenceCapability.deployment.labels }}
|
||||
{{- range .Values.inferenceCapability.labels }}
|
||||
{{ .key }}: {{ .value }}
|
||||
{{- end }}
|
||||
template:
|
||||
@@ -21,24 +21,26 @@ spec:
|
||||
{{- end }}
|
||||
spec:
|
||||
containers:
|
||||
- name: {{ .Values.inferenceCapability.service.name }}
|
||||
image: {{ .Values.inferenceCapability.deployment.image.repository }}:{{ .Values.inferenceCapability.deployment.image.tag }}
|
||||
imagePullPolicy: {{ .Values.inferenceCapability.deployment.image.pullPolicy }}
|
||||
command: {{ toYaml .Values.inferenceCapability.deployment.command | nindent 14 }}
|
||||
- name: model-server-inference
|
||||
image: "{{ .Values.inferenceCapability.image.repository }}:{{ .Values.inferenceCapability.image.tag | default .Chart.AppVersion }}"
|
||||
imagePullPolicy: {{ .Values.inferenceCapability.image.pullPolicy }}
|
||||
command: [ "uvicorn", "model_server.main:app", "--host", "0.0.0.0", "--port", "{{ .Values.inferenceCapability.containerPorts.server }}" ]
|
||||
ports:
|
||||
- containerPort: {{ .Values.inferenceCapability.service.port }}
|
||||
- name: model-server
|
||||
containerPort: {{ .Values.inferenceCapability.containerPorts.server }}
|
||||
protocol: TCP
|
||||
envFrom:
|
||||
- configMapRef:
|
||||
name: {{ .Values.config.envConfigMapName }}
|
||||
env:
|
||||
{{- include "onyx-stack.envSecrets" . | nindent 12}}
|
||||
volumeMounts:
|
||||
{{- range .Values.inferenceCapability.deployment.volumeMounts }}
|
||||
{{- range .Values.inferenceCapability.volumeMounts }}
|
||||
- name: {{ .name }}
|
||||
mountPath: {{ .mountPath }}
|
||||
{{- end }}
|
||||
volumes:
|
||||
{{- range .Values.inferenceCapability.deployment.volumes }}
|
||||
{{- range .Values.inferenceCapability.volumes }}
|
||||
- name: {{ .name }}
|
||||
persistentVolumeClaim:
|
||||
claimName: {{ .persistentVolumeClaim.claimName }}
|
||||
|
||||
@@ -3,6 +3,7 @@ kind: PersistentVolumeClaim
|
||||
metadata:
|
||||
name: {{ .Values.inferenceCapability.pvc.name }}
|
||||
spec:
|
||||
storageClassName: {{ .Values.persistent.storageClassName }}
|
||||
accessModes:
|
||||
{{- toYaml .Values.inferenceCapability.pvc.accessModes | nindent 4 }}
|
||||
resources:
|
||||
|
||||
@@ -5,11 +5,11 @@ metadata:
|
||||
spec:
|
||||
type: {{ .Values.inferenceCapability.service.type }}
|
||||
ports:
|
||||
- port: {{ .Values.inferenceCapability.service.port }}
|
||||
targetPort: {{ .Values.inferenceCapability.service.port }}
|
||||
- port: {{ .Values.inferenceCapability.service.servicePort}}
|
||||
targetPort: {{ .Values.inferenceCapability.service.targetPort }}
|
||||
protocol: TCP
|
||||
name: {{ .Values.inferenceCapability.service.name }}
|
||||
name: {{ .Values.inferenceCapability.service.portName }}
|
||||
selector:
|
||||
{{- range .Values.inferenceCapability.deployment.labels }}
|
||||
{{- range .Values.inferenceCapability.labels }}
|
||||
{{ .key }}: {{ .value }}
|
||||
{{- end }}
|
||||
|
||||
@@ -5,11 +5,11 @@ metadata:
|
||||
data:
|
||||
nginx.conf: |
|
||||
upstream api_server {
|
||||
server {{ include "onyx-stack.fullname" . }}-api-service:{{ .Values.api.service.port }} fail_timeout=0;
|
||||
server {{ include "onyx-stack.fullname" . }}-api-service:{{ .Values.api.service.servicePort }} fail_timeout=0;
|
||||
}
|
||||
|
||||
upstream web_server {
|
||||
server {{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }} fail_timeout=0;
|
||||
server {{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.servicePort }} fail_timeout=0;
|
||||
}
|
||||
|
||||
server {
|
||||
|
||||
@@ -11,5 +11,5 @@ spec:
|
||||
- name: wget
|
||||
image: busybox
|
||||
command: ['wget']
|
||||
args: ['{{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.port }}']
|
||||
args: ['{{ include "onyx-stack.fullname" . }}-webserver:{{ .Values.webserver.service.servicePort }}']
|
||||
restartPolicy: Never
|
||||
|
||||
@@ -41,7 +41,7 @@ spec:
|
||||
imagePullPolicy: {{ .Values.webserver.image.pullPolicy }}
|
||||
ports:
|
||||
- name: http
|
||||
containerPort: {{ .Values.webserver.service.port }}
|
||||
containerPort: {{ .Values.webserver.containerPorts.server }}
|
||||
protocol: TCP
|
||||
resources:
|
||||
{{- toYaml .Values.webserver.resources | nindent 12 }}
|
||||
|
||||
@@ -10,8 +10,8 @@ metadata:
|
||||
spec:
|
||||
type: {{ .Values.webserver.service.type }}
|
||||
ports:
|
||||
- port: {{ .Values.webserver.service.port }}
|
||||
targetPort: http
|
||||
- port: {{ .Values.webserver.service.servicePort }}
|
||||
targetPort: {{ .Values.webserver.service.targetPort }}
|
||||
protocol: TCP
|
||||
name: http
|
||||
selector:
|
||||
|
||||
@@ -2,62 +2,73 @@
|
||||
# This is a YAML-formatted file.
|
||||
# Declare variables to be passed into your templates.
|
||||
|
||||
postgresql:
|
||||
primary:
|
||||
persistence:
|
||||
size: 5Gi
|
||||
enabled: true
|
||||
auth:
|
||||
existingSecret: onyx-secrets
|
||||
secretKeys:
|
||||
# overwriting as postgres typically expects 'postgres-password'
|
||||
adminPasswordKey: postgres_password
|
||||
imagePullSecrets: []
|
||||
nameOverride: ""
|
||||
fullnameOverride: ""
|
||||
|
||||
persistent:
|
||||
storageClassName: ""
|
||||
|
||||
inferenceCapability:
|
||||
service:
|
||||
name: inference-model-server-service
|
||||
portName: modelserver
|
||||
type: ClusterIP
|
||||
port: 9000
|
||||
servicePort: 9000
|
||||
targetPort: 9000
|
||||
pvc:
|
||||
name: inference-model-pvc
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
storage: 3Gi
|
||||
deployment:
|
||||
name: inference-model-server-deployment
|
||||
replicas: 1
|
||||
labels:
|
||||
- key: app
|
||||
value: inference-model-server
|
||||
image:
|
||||
repository: onyxdotapp/onyx-model-server
|
||||
tag: latest
|
||||
pullPolicy: IfNotPresent
|
||||
command:
|
||||
[
|
||||
"uvicorn",
|
||||
"model_server.main:app",
|
||||
"--host",
|
||||
"0.0.0.0",
|
||||
"--port",
|
||||
"9000",
|
||||
]
|
||||
port: 9000
|
||||
volumeMounts:
|
||||
- name: inference-model-storage
|
||||
mountPath: /root/.cache
|
||||
volumes:
|
||||
- name: inference-model-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: inference-model-pvc
|
||||
name: inference-model-server
|
||||
replicaCount: 1
|
||||
labels:
|
||||
- key: app
|
||||
value: inference-model-server
|
||||
image:
|
||||
repository: onyxdotapp/onyx-model-server
|
||||
# Overrides the image tag whose default is the chart appVersion.
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
containerPorts:
|
||||
server: 9000
|
||||
volumeMounts:
|
||||
- name: inference-model-storage
|
||||
mountPath: /root/.cache
|
||||
volumes:
|
||||
- name: inference-model-storage
|
||||
persistentVolumeClaim:
|
||||
claimName: inference-model-pvc
|
||||
podLabels:
|
||||
- key: app
|
||||
value: inference-model-server
|
||||
|
||||
indexCapability:
|
||||
service:
|
||||
portName: modelserver
|
||||
type: ClusterIP
|
||||
port: 9000
|
||||
name: indexing-model-server-port
|
||||
servicePort: 9000
|
||||
targetPort: 9000
|
||||
replicaCount: 1
|
||||
name: indexing-model-server
|
||||
deploymentLabels:
|
||||
app: indexing-model-server
|
||||
podLabels:
|
||||
app: indexing-model-server
|
||||
indexingOnly: "True"
|
||||
podAnnotations: {}
|
||||
containerPorts:
|
||||
server: 9000
|
||||
volumeMounts:
|
||||
- name: indexing-model-storage
|
||||
mountPath: /root/.cache
|
||||
@@ -69,7 +80,12 @@ indexCapability:
|
||||
name: indexing-model-storage
|
||||
accessMode: "ReadWriteOnce"
|
||||
storage: "3Gi"
|
||||
|
||||
image:
|
||||
repository: onyxdotapp/onyx-model-server
|
||||
# Overrides the image tag whose default is the chart appVersion.
|
||||
tag: ""
|
||||
pullPolicy: IfNotPresent
|
||||
limitConcurrency: 10
|
||||
config:
|
||||
envConfigMapName: env-configmap
|
||||
|
||||
@@ -84,16 +100,6 @@ serviceAccount:
|
||||
# If not set and create is true, a name is generated using the fullname template
|
||||
name: ""
|
||||
|
||||
postgresql:
|
||||
primary:
|
||||
persistence:
|
||||
size: 5Gi
|
||||
enabled: true
|
||||
auth:
|
||||
existingSecret: onyx-secrets
|
||||
secretKeys:
|
||||
adminPasswordKey: postgres_password # overwriting as postgres typically expects 'postgres-password'
|
||||
|
||||
nginx:
|
||||
containerPorts:
|
||||
http: 1024
|
||||
@@ -135,9 +141,13 @@ webserver:
|
||||
# runAsNonRoot: true
|
||||
# runAsUser: 1000
|
||||
|
||||
containerPorts:
|
||||
server: 3000
|
||||
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 3000
|
||||
servicePort: 3000
|
||||
targetPort: http
|
||||
|
||||
resources: {}
|
||||
# We usually recommend not to specify default resources and to leave this as a conscious
|
||||
@@ -156,7 +166,7 @@ webserver:
|
||||
minReplicas: 1
|
||||
maxReplicas: 100
|
||||
targetCPUUtilizationPercentage: 80
|
||||
# targetMemoryUtilizationPercentage: 80
|
||||
targetMemoryUtilizationPercentage: 80
|
||||
|
||||
# Additional volumes on the output Deployment definition.
|
||||
volumes: []
|
||||
@@ -189,6 +199,9 @@ api:
|
||||
scope: onyx-backend
|
||||
app: api-server
|
||||
|
||||
containerPorts:
|
||||
server: 8080
|
||||
|
||||
podSecurityContext:
|
||||
{}
|
||||
# fsGroup: 2000
|
||||
@@ -204,7 +217,9 @@ api:
|
||||
|
||||
service:
|
||||
type: ClusterIP
|
||||
port: 8080
|
||||
servicePort: 8080
|
||||
targetPort: api-server-port
|
||||
portName: api-server-port
|
||||
|
||||
resources: {}
|
||||
# We usually recommend not to specify default resources and to leave this as a conscious
|
||||
@@ -223,7 +238,7 @@ api:
|
||||
minReplicas: 1
|
||||
maxReplicas: 100
|
||||
targetCPUUtilizationPercentage: 80
|
||||
# targetMemoryUtilizationPercentage: 80
|
||||
targetMemoryUtilizationPercentage: 80
|
||||
|
||||
# Additional volumes on the output Deployment definition.
|
||||
volumes: []
|
||||
@@ -247,7 +262,7 @@ background:
|
||||
repository: onyxdotapp/onyx-backend
|
||||
pullPolicy: IfNotPresent
|
||||
# Overrides the image tag whose default is the chart appVersion.
|
||||
tag: latest
|
||||
tag: ""
|
||||
podAnnotations: {}
|
||||
podLabels:
|
||||
scope: onyx-backend
|
||||
@@ -284,7 +299,7 @@ background:
|
||||
minReplicas: 1
|
||||
maxReplicas: 100
|
||||
targetCPUUtilizationPercentage: 80
|
||||
# targetMemoryUtilizationPercentage: 80
|
||||
targetMemoryUtilizationPercentage: 80
|
||||
|
||||
# Additional volumes on the output Deployment definition.
|
||||
volumes: []
|
||||
@@ -303,6 +318,16 @@ background:
|
||||
tolerations: []
|
||||
|
||||
vespa:
|
||||
volumeClaimTemplates:
|
||||
- metadata:
|
||||
name: vespa-storage
|
||||
spec:
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
storageClassName: ""
|
||||
resources:
|
||||
requests:
|
||||
storage: 1Gi
|
||||
enabled: true
|
||||
replicaCount: 1
|
||||
image:
|
||||
@@ -377,19 +402,11 @@ redis:
|
||||
# # hosts:
|
||||
# # - chart-example.local
|
||||
|
||||
persistence:
|
||||
vespa:
|
||||
enabled: true
|
||||
existingClaim: ""
|
||||
storageClassName: ""
|
||||
accessModes:
|
||||
- ReadWriteOnce
|
||||
size: 5Gi
|
||||
|
||||
auth:
|
||||
# for storing smtp, oauth, slack, and other secrets
|
||||
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets
|
||||
# keys are lowercased version of env vars (e.g. SMTP_USER -> smtp_user)
|
||||
existingSecret: "" # onyx-secrets
|
||||
existingSecret: ""
|
||||
# optionally override the secret keys to reference in the secret
|
||||
# this is used to populate the env vars in individual deployments
|
||||
# the values here reference the keys in secrets below
|
||||
@@ -413,14 +430,22 @@ auth:
|
||||
redis_password: "password"
|
||||
|
||||
configMap:
|
||||
AUTH_TYPE: "disabled" # Change this for production uses unless Onyx is only accessible behind VPN
|
||||
SESSION_EXPIRE_TIME_SECONDS: "86400" # 1 Day Default
|
||||
VALID_EMAIL_DOMAINS: "" # Can be something like onyx.app, as an extra double-check
|
||||
SMTP_SERVER: "" # For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
|
||||
SMTP_PORT: "" # For sending verification emails, if unspecified then defaults to '587'
|
||||
SMTP_USER: "" # 'your-email@company.com'
|
||||
# SMTP_PASS: "" # 'your-gmail-password'
|
||||
EMAIL_FROM: "" # 'your-email@company.com' SMTP_USER missing used instead
|
||||
# Change this for production uses unless Onyx is only accessible behind VPN
|
||||
AUTH_TYPE: "disabled"
|
||||
# 1 Day Default
|
||||
SESSION_EXPIRE_TIME_SECONDS: "86400"
|
||||
# Can be something like onyx.app, as an extra double-check
|
||||
VALID_EMAIL_DOMAINS: ""
|
||||
# For sending verification emails, if unspecified then defaults to 'smtp.gmail.com'
|
||||
SMTP_SERVER: ""
|
||||
# For sending verification emails, if unspecified then defaults to '587'
|
||||
SMTP_PORT: ""
|
||||
# 'your-email@company.com'
|
||||
SMTP_USER: ""
|
||||
# 'your-gmail-password'
|
||||
# SMTP_PASS: ""
|
||||
# 'your-email@company.com'; if unset, SMTP_USER is used instead
|
||||
EMAIL_FROM: ""
|
||||
# Gen AI Settings
|
||||
GEN_AI_MAX_TOKENS: ""
|
||||
QA_TIMEOUT: "60"
|
||||
@@ -462,7 +487,7 @@ configMap:
|
||||
DANSWER_BOT_DISABLE_DOCS_ONLY_ANSWER: ""
|
||||
DANSWER_BOT_DISPLAY_ERROR_MSGS: ""
|
||||
DANSWER_BOT_RESPOND_EVERY_CHANNEL: ""
|
||||
DANSWER_BOT_DISABLE_COT: "" # Currently unused
|
||||
DANSWER_BOT_DISABLE_COT: ""
|
||||
NOTIFY_SLACKBOT_NO_ANSWER: ""
|
||||
# Logging
|
||||
# Optional Telemetry, please keep it on (nothing sensitive is collected)? <3
|
||||
@@ -473,7 +498,8 @@ configMap:
|
||||
LOG_DANSWER_MODEL_INTERACTIONS: ""
|
||||
LOG_VESPA_TIMING_INFORMATION: ""
|
||||
# Shared or Non-backend Related
|
||||
WEB_DOMAIN: "http://localhost:3000" # for web server and api server
|
||||
DOMAIN: "localhost" # for nginx
|
||||
WEB_DOMAIN: "http://localhost:3000"
|
||||
# DOMAIN used by nginx
|
||||
DOMAIN: "localhost"
|
||||
# Chat Configs
|
||||
HARD_DELETE_CHATS: ""
|
||||
|
||||
web/package-lock.json (generated, 137 changed lines)
@@ -25,6 +25,7 @@
|
||||
"@radix-ui/react-scroll-area": "^1.2.2",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-separator": "^1.1.0",
|
||||
"@radix-ui/react-slider": "^1.2.2",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-switch": "^1.1.1",
|
||||
"@radix-ui/react-tabs": "^1.1.1",
|
||||
@@ -4963,6 +4964,142 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider": {
|
||||
"version": "1.2.2",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slider/-/react-slider-1.2.2.tgz",
|
||||
"integrity": "sha512-sNlU06ii1/ZcbHf8I9En54ZPW0Vil/yPVg4vQMcFNjrIx51jsHbFl1HYHQvCIWJSr1q0ZmA+iIs/ZTv8h7HHSA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/number": "1.1.0",
|
||||
"@radix-ui/primitive": "1.1.1",
|
||||
"@radix-ui/react-collection": "1.1.1",
|
||||
"@radix-ui/react-compose-refs": "1.1.1",
|
||||
"@radix-ui/react-context": "1.1.1",
|
||||
"@radix-ui/react-direction": "1.1.0",
|
||||
"@radix-ui/react-primitive": "2.0.1",
|
||||
"@radix-ui/react-use-controllable-state": "1.1.0",
|
||||
"@radix-ui/react-use-layout-effect": "1.1.0",
|
||||
"@radix-ui/react-use-previous": "1.1.0",
|
||||
"@radix-ui/react-use-size": "1.1.0"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/primitive": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/primitive/-/primitive-1.1.1.tgz",
|
||||
"integrity": "sha512-SJ31y+Q/zAyShtXJc8x83i9TYdbAfHZ++tUZnvjJJqFjzsdUnKsxPL6IEtBlxKkU7yzer//GQtZSV4GbldL3YA==",
|
||||
"license": "MIT"
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-collection": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-collection/-/react-collection-1.1.1.tgz",
|
||||
"integrity": "sha512-LwT3pSho9Dljg+wY2KN2mrrh6y3qELfftINERIzBUO9e0N+t0oMTyn3k9iv+ZqgrwGkRnLpNJrsMv9BZlt2yuA==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.1",
|
||||
"@radix-ui/react-context": "1.1.1",
|
||||
"@radix-ui/react-primitive": "2.0.1",
|
||||
"@radix-ui/react-slot": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-compose-refs": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-compose-refs/-/react-compose-refs-1.1.1.tgz",
|
||||
"integrity": "sha512-Y9VzoRDSJtgFMUCoiZBDVo084VQ5hfpXxVE+NgkdNsjiDBByiImMZKKhxMwCbdHvhlENG6a833CbFkOQvTricw==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-context": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-context/-/react-context-1.1.1.tgz",
|
||||
"integrity": "sha512-UASk9zi+crv9WteK/NU4PLvOoL3OuE6BWVKNF6hPRBtYBDXQ2u5iu3O59zUlJiTVvkyuycnqrztsHVJwcK9K+Q==",
|
||||
"license": "MIT",
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-primitive": {
|
||||
"version": "2.0.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-primitive/-/react-primitive-2.0.1.tgz",
|
||||
"integrity": "sha512-sHCWTtxwNn3L3fH8qAfnF3WbUZycW93SM1j3NFDzXBiz8D6F5UTTy8G1+WFEaiCdvCVRJWj6N2R4Xq6HdiHmDg==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-slot": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"@types/react-dom": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc",
|
||||
"react-dom": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
},
|
||||
"@types/react-dom": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slider/node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.1.1",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.1.tgz",
|
||||
"integrity": "sha512-RApLLOcINYJA+dMVbOju7MYv1Mb2EBp2nH4HdDzXTSyaR5optlm6Otrz1euW3HbdOR8UmmFK06TD+A9frYWv+g==",
|
||||
"license": "MIT",
|
||||
"dependencies": {
|
||||
"@radix-ui/react-compose-refs": "1.1.1"
|
||||
},
|
||||
"peerDependencies": {
|
||||
"@types/react": "*",
|
||||
"react": "^16.8 || ^17.0 || ^18.0 || ^19.0 || ^19.0.0-rc"
|
||||
},
|
||||
"peerDependenciesMeta": {
|
||||
"@types/react": {
|
||||
"optional": true
|
||||
}
|
||||
}
|
||||
},
|
||||
"node_modules/@radix-ui/react-slot": {
|
||||
"version": "1.1.0",
|
||||
"resolved": "https://registry.npmjs.org/@radix-ui/react-slot/-/react-slot-1.1.0.tgz",
|
||||
|
||||
@@ -28,6 +28,7 @@
|
||||
"@radix-ui/react-scroll-area": "^1.2.2",
|
||||
"@radix-ui/react-select": "^2.1.2",
|
||||
"@radix-ui/react-separator": "^1.1.0",
|
||||
"@radix-ui/react-slider": "^1.2.2",
|
||||
"@radix-ui/react-slot": "^1.1.0",
|
||||
"@radix-ui/react-switch": "^1.1.1",
|
||||
"@radix-ui/react-tabs": "^1.1.1",
|
||||
|
||||
@@ -2,12 +2,13 @@ import { defineConfig, devices } from "@playwright/test";
|
||||
|
||||
export default defineConfig({
|
||||
globalSetup: require.resolve("./tests/e2e/global-setup"),
|
||||
|
||||
timeout: 30000, // 30 seconds timeout
|
||||
projects: [
|
||||
{
|
||||
name: "admin",
|
||||
use: {
|
||||
...devices["Desktop Chrome"],
|
||||
viewport: { width: 1280, height: 720 },
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
testIgnore: ["**/codeUtils.test.ts"],
|
||||
|
||||
@@ -720,7 +720,6 @@ export function AssistantEditor({
|
||||
name="description"
|
||||
label="Description"
|
||||
placeholder="Use this Assistant to help draft professional emails"
|
||||
data-testid="assistant-description-input"
|
||||
className="[&_input]:placeholder:text-text-muted/50"
|
||||
/>
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import { OnyxIcon } from "@/components/icons/icons";
|
||||
|
||||
export function ChatIntro({ selectedPersona }: { selectedPersona: Persona }) {
|
||||
return (
|
||||
<div className="flex flex-col items-center gap-6">
|
||||
<div data-testid="chat-intro" className="flex flex-col items-center gap-6">
|
||||
<div className="relative flex flex-col gap-y-4 w-fit mx-auto justify-center">
|
||||
<div className="absolute z-10 items-center flex -left-12 top-1/2 -translate-y-1/2">
|
||||
<AssistantIcon size={36} assistant={selectedPersona} />
|
||||
|
||||
@@ -297,6 +297,7 @@ export function ChatPage({
  // 2. Selected assistant (assistant default in this chat session)
  // 3. First pinned assistants (ordered list of pinned assistants)
  // 4. Available assistants (ordered list of available assistants)
  // Relevant test: `live_assistant.spec.ts`
  const liveAssistant: Persona | undefined = useMemo(
    () =>
      alternativeAssistant ||
@@ -403,9 +404,6 @@ export function ChatPage({
|
||||
filterManager.setSelectedTags([]);
|
||||
filterManager.setTimeRange(null);
|
||||
|
||||
// reset LLM overrides (based on chat session!)
|
||||
llmOverrideManager.updateTemperature(null);
|
||||
|
||||
// remove uploaded files
|
||||
setCurrentMessageFiles([]);
|
||||
|
||||
@@ -448,6 +446,7 @@ export function ChatPage({
|
||||
);
|
||||
|
||||
const chatSession = (await response.json()) as BackendChatSession;
|
||||
|
||||
setSelectedAssistantFromId(chatSession.persona_id);
|
||||
|
||||
const newMessageMap = processRawChatHistory(chatSession.messages);
|
||||
|
||||
@@ -478,6 +478,7 @@ export function ChatInputBar({
|
||||
onKeyDownCapture={handleKeyDown}
|
||||
onChange={handleInputChange}
|
||||
ref={textAreaRef}
|
||||
id="onyx-chat-input-textarea"
|
||||
className={`
|
||||
m-0
|
||||
w-full
|
||||
@@ -703,6 +704,7 @@ export function ChatInputBar({
|
||||
</div>
|
||||
<div className="flex my-auto">
|
||||
<button
|
||||
id="onyx-chat-input-send-button"
|
||||
className={`cursor-pointer ${
|
||||
chatState == "streaming" ||
|
||||
chatState == "toolBuilding" ||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
import React, { useState } from "react";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
@@ -26,6 +26,9 @@ import {
|
||||
} from "@/components/ui/tooltip";
|
||||
import { FiAlertTriangle } from "react-icons/fi";
|
||||
|
||||
import { Slider } from "@/components/ui/slider";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
|
||||
interface LLMPopoverProps {
|
||||
llmProviders: LLMProviderDescriptor[];
|
||||
llmOverrideManager: LlmOverrideManager;
|
||||
@@ -40,6 +43,7 @@ export default function LLMPopover({
|
||||
currentAssistant,
|
||||
}: LLMPopoverProps) {
|
||||
const [isOpen, setIsOpen] = useState(false);
|
||||
const { user } = useUser();
|
||||
const { llmOverride, updateLLMOverride } = llmOverrideManager;
|
||||
const currentLlm = llmOverride.modelName;
|
||||
|
||||
@@ -88,10 +92,29 @@ export default function LLMPopover({
|
||||
? getDisplayNameForModel(defaultModelName)
|
||||
: null;
|
||||
|
||||
const [localTemperature, setLocalTemperature] = useState(
|
||||
llmOverrideManager.temperature ?? 0.5
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
setLocalTemperature(llmOverrideManager.temperature ?? 0.5);
|
||||
}, [llmOverrideManager.temperature]);
|
||||
|
||||
const handleTemperatureChange = (value: number[]) => {
|
||||
setLocalTemperature(value[0]);
|
||||
};
|
||||
|
||||
const handleTemperatureChangeComplete = (value: number[]) => {
|
||||
llmOverrideManager.updateTemperature(value[0]);
|
||||
};
|
||||
|
||||
return (
|
||||
<Popover open={isOpen} onOpenChange={setIsOpen}>
|
||||
<PopoverTrigger asChild>
|
||||
<button className="focus:outline-none">
|
||||
<button
|
||||
className="focus:outline-none"
|
||||
data-testid="llm-popover-trigger"
|
||||
>
|
||||
<ChatInputOption
|
||||
minimize
|
||||
toggle
|
||||
@@ -115,9 +138,9 @@ export default function LLMPopover({
|
||||
</PopoverTrigger>
|
||||
<PopoverContent
|
||||
align="start"
|
||||
className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg"
|
||||
className="w-64 p-1 bg-background border border-gray-200 rounded-md shadow-lg flex flex-col"
|
||||
>
|
||||
<div className="max-h-[300px] overflow-y-auto">
|
||||
<div className="flex-grow max-h-[300px] default-scrollbar overflow-y-auto">
|
||||
{llmOptions.map(({ name, icon, value }, index) => {
|
||||
if (!requiresImageGeneration || checkLLMSupportsImageInput(name)) {
|
||||
return (
|
||||
@@ -168,6 +191,25 @@ export default function LLMPopover({
|
||||
return null;
|
||||
})}
|
||||
</div>
|
||||
{user?.preferences?.temperature_override_enabled && (
|
||||
<div className="mt-2 pt-2 border-t border-gray-200">
|
||||
<div className="w-full px-3 py-2">
|
||||
<Slider
|
||||
value={[localTemperature]}
|
||||
max={llmOverrideManager.maxTemperature}
|
||||
min={0}
|
||||
step={0.01}
|
||||
onValueChange={handleTemperatureChange}
|
||||
onValueCommit={handleTemperatureChangeComplete}
|
||||
className="w-full"
|
||||
/>
|
||||
<div className="flex justify-between text-xs text-gray-500 mt-2">
|
||||
<span>Temperature (creativity)</span>
|
||||
<span>{localTemperature.toFixed(1)}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
|
||||
@@ -68,6 +68,7 @@ export interface ChatSession {
|
||||
shared_status: ChatSessionSharedStatus;
|
||||
folder_id: number | null;
|
||||
current_alternate_model: string;
|
||||
current_temperature_override: number | null;
|
||||
}
|
||||
|
||||
export interface SearchSession {
|
||||
@@ -107,6 +108,7 @@ export interface BackendChatSession {
|
||||
messages: BackendMessage[];
|
||||
time_created: string;
|
||||
shared_status: ChatSessionSharedStatus;
|
||||
current_temperature_override: number | null;
|
||||
current_alternate_model?: string;
|
||||
}
|
||||
|
||||
|
||||
@@ -75,6 +75,23 @@ export async function updateModelOverrideForChatSession(
  return response;
}

export async function updateTemperatureOverrideForChatSession(
  chatSessionId: string,
  newTemperature: number
) {
  const response = await fetch("/api/chat/update-chat-session-temperature", {
    method: "PUT",
    headers: {
      "Content-Type": "application/json",
    },
    body: JSON.stringify({
      chat_session_id: chatSessionId,
      temperature_override: newTemperature,
    }),
  });
  return response;
}
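Not part of the commit, but for orientation: a minimal sketch of how a caller might persist a temperature picked in the UI using the helper above. The handleTemperatureCommit name and its error handling are illustrative assumptions.

// Illustrative caller (assumed name): persist a temperature chosen in the UI
// for an existing chat session, using the helper added in this commit.
async function handleTemperatureCommit(
  chatSessionId: string,
  newTemperature: number
) {
  const response = await updateTemperatureOverrideForChatSession(
    chatSessionId,
    newTemperature
  );
  if (!response.ok) {
    // Error handling here is an assumption; the hook changed below does not inspect the response.
    console.error("Failed to persist temperature override for", chatSessionId);
  }
}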

export async function createChatSession(
  personaId: number,
  description: string | null
@@ -402,7 +402,7 @@ export const AIMessage = ({
|
||||
|
||||
return (
|
||||
<div
|
||||
id="onyx-ai-message"
|
||||
id={isComplete ? "onyx-ai-message" : undefined}
|
||||
ref={trackedElementRef}
|
||||
className={`py-5 ml-4 lg:px-5 relative flex `}
|
||||
>
|
||||
|
||||
@@ -30,8 +30,13 @@ export function UserSettingsModal({
|
||||
defaultModel: string | null;
|
||||
}) {
|
||||
const { inputPrompts, refreshInputPrompts } = useChatContext();
|
||||
const { refreshUser, user, updateUserAutoScroll, updateUserShortcuts } =
|
||||
useUser();
|
||||
const {
|
||||
refreshUser,
|
||||
user,
|
||||
updateUserAutoScroll,
|
||||
updateUserShortcuts,
|
||||
updateUserTemperatureOverrideEnabled,
|
||||
} = useUser();
|
||||
const containerRef = useRef<HTMLDivElement>(null);
|
||||
const messageRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
@@ -179,6 +184,16 @@ export function UserSettingsModal({
|
||||
/>
|
||||
<Label className="text-sm">Enable Prompt Shortcuts</Label>
|
||||
</div>
|
||||
<div className="flex items-center gap-x-2">
|
||||
<Switch
|
||||
size="sm"
|
||||
checked={user?.preferences?.temperature_override_enabled}
|
||||
onCheckedChange={(checked) => {
|
||||
updateUserTemperatureOverrideEnabled(checked);
|
||||
}}
|
||||
/>
|
||||
<Label className="text-sm">Enable Temperature Override</Label>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
@@ -129,6 +129,7 @@ const SortableAssistant: React.FC<SortableAssistantProps> = ({
|
||||
className="w-3 ml-[2px] mr-[2px] group-hover:visible invisible flex-none cursor-grab"
|
||||
/>
|
||||
<button
|
||||
data-testid={`assistant-[${assistant.id}]`}
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
if (!isDragging) {
|
||||
|
||||
@@ -103,7 +103,7 @@ export function Modal({
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
<div className="flex-shrink-0">
|
||||
<div className="items-start flex-shrink-0">
|
||||
{title && (
|
||||
<>
|
||||
<div className="flex">
|
||||
|
||||
@@ -133,6 +133,7 @@ export function UserDropdown({
|
||||
onOpenChange={onOpenChange}
|
||||
content={
|
||||
<div
|
||||
id="onyx-user-dropdown"
|
||||
onClick={() => setUserInfoVisible(!userInfoVisible)}
|
||||
className="flex relative cursor-pointer"
|
||||
>
|
||||
|
||||
@@ -59,7 +59,23 @@ export const Popup: React.FC<PopupSpec> = ({ message, type }) => (
|
||||
/>
|
||||
</svg>
|
||||
)}
|
||||
<span className="font-medium">{message}</span>
|
||||
<div className="flex flex-col items-center space-x-2">
|
||||
<span className="font-medium">{message}</span>
|
||||
{type === "error" && (
|
||||
<span className="text-xs text-red-100">
|
||||
Need help?{" "}
|
||||
<a
|
||||
href="https://join.slack.com/t/onyx-dot-app/shared_invite/zt-2twesxdr6-5iQitKZQpgq~hYIZ~dv3KA"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="underline hover:text-red-900"
|
||||
>
|
||||
Join our community
|
||||
</a>{" "}
|
||||
for support.
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
|
||||
web/src/components/ui/slider.tsx (new file, 28 lines)
@@ -0,0 +1,28 @@
"use client";

import * as React from "react";
import * as SliderPrimitive from "@radix-ui/react-slider";

import { cn } from "@/lib/utils";

const Slider = React.forwardRef<
  React.ElementRef<typeof SliderPrimitive.Root>,
  React.ComponentPropsWithoutRef<typeof SliderPrimitive.Root>
>(({ className, ...props }, ref) => (
  <SliderPrimitive.Root
    ref={ref}
    className={cn(
      "relative flex w-full touch-none select-none items-center",
      className
    )}
    {...props}
  >
    <SliderPrimitive.Track className="relative h-2 w-full grow overflow-hidden rounded-full bg-neutral-100 dark:bg-neutral-800">
      <SliderPrimitive.Range className="absolute h-full bg-neutral-900 dark:bg-neutral-50" />
    </SliderPrimitive.Track>
    <SliderPrimitive.Thumb className="block h-3 w-3 rounded-full border border-neutral-900 bg-white ring-offset-white transition-colors focus-visible:outline-none focus-visible:ring-2 focus-visible:ring-neutral-950 focus-visible:ring-offset disabled:pointer-events-none disabled:opacity-50 dark:border-neutral-50 dark:bg-neutral-950 dark:ring-offset-neutral-950 dark:focus-visible:ring-neutral-300" />
  </SliderPrimitive.Root>
));
Slider.displayName = SliderPrimitive.Root.displayName;

export { Slider };
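Not part of the commit: a small usage sketch of the new Slider wrapper. The TemperatureField component is a hypothetical example; the props mirror the ones the LLMPopover change above passes (value, min, max, step, onValueChange).

import { useState } from "react";
import { Slider } from "@/components/ui/slider";

// Hypothetical component illustrating the Slider wrapper added above.
// Radix sliders take an array of thumb values, hence value={[temperature]}.
export function TemperatureField() {
  const [temperature, setTemperature] = useState(0.5);
  return (
    <Slider
      value={[temperature]}
      min={0}
      max={2}
      step={0.01}
      onValueChange={(values) => setTemperature(values[0])}
    />
  );
}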
@@ -18,6 +18,7 @@ interface UserContextType {
|
||||
assistantId: number,
|
||||
isPinned: boolean
|
||||
) => Promise<boolean>;
|
||||
updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise<void>;
|
||||
}
|
||||
|
||||
const UserContext = createContext<UserContextType | undefined>(undefined);
|
||||
@@ -57,6 +58,41 @@ export function UserProvider({
|
||||
console.error("Error fetching current user:", error);
|
||||
}
|
||||
};
|
||||
const updateUserTemperatureOverrideEnabled = async (enabled: boolean) => {
|
||||
try {
|
||||
setUpToDateUser((prevUser) => {
|
||||
if (prevUser) {
|
||||
return {
|
||||
...prevUser,
|
||||
preferences: {
|
||||
...prevUser.preferences,
|
||||
temperature_override_enabled: enabled,
|
||||
},
|
||||
};
|
||||
}
|
||||
return prevUser;
|
||||
});
|
||||
|
||||
const response = await fetch(
|
||||
`/api/temperature-override-enabled?temperature_override_enabled=${enabled}`,
|
||||
{
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
await refreshUser();
|
||||
throw new Error("Failed to update user temperature override setting");
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error updating user temperature override setting:", error);
|
||||
throw error;
|
||||
}
|
||||
};
|
||||
|
||||
const updateUserShortcuts = async (enabled: boolean) => {
|
||||
try {
|
||||
setUpToDateUser((prevUser) => {
|
||||
@@ -184,6 +220,7 @@ export function UserProvider({
|
||||
refreshUser,
|
||||
updateUserAutoScroll,
|
||||
updateUserShortcuts,
|
||||
updateUserTemperatureOverrideEnabled,
|
||||
toggleAssistantPinnedStatus,
|
||||
isAdmin: upToDateUser?.role === UserRole.ADMIN,
|
||||
// Curator status applies for either global or basic curator
|
||||
|
||||
@@ -10,7 +10,7 @@ import {
|
||||
} from "@/lib/types";
|
||||
import useSWR, { mutate, useSWRConfig } from "swr";
|
||||
import { errorHandlingFetcher } from "./fetcher";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import { useContext, useEffect, useMemo, useState } from "react";
|
||||
import { DateRangePickerValue } from "@/app/ee/admin/performance/DateRangeSelector";
|
||||
import { Filters, SourceMetadata } from "./search/interfaces";
|
||||
import {
|
||||
@@ -28,6 +28,8 @@ import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { getSourceMetadata } from "./sources";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
import { SEARCH_TOOL_ID } from "@/app/chat/tools/constants";
|
||||
import { updateTemperatureOverrideForChatSession } from "@/app/chat/lib";
|
||||
|
||||
const CREDENTIAL_URL = "/api/manage/admin/credential";
|
||||
|
||||
@@ -360,14 +362,22 @@ export interface LlmOverride {
|
||||
export interface LlmOverrideManager {
|
||||
llmOverride: LlmOverride;
|
||||
updateLLMOverride: (newOverride: LlmOverride) => void;
|
||||
temperature: number | null;
|
||||
updateTemperature: (temperature: number | null) => void;
|
||||
temperature: number;
|
||||
updateTemperature: (temperature: number) => void;
|
||||
updateModelOverrideForChatSession: (chatSession?: ChatSession) => void;
|
||||
imageFilesPresent: boolean;
|
||||
updateImageFilesPresent: (present: boolean) => void;
|
||||
liveAssistant: Persona | null;
|
||||
maxTemperature: number;
|
||||
}
|
||||
|
||||
// Things to test
// 1. User override
// 2. User preference (defaults to system wide default if no preference set)
// 3. Current assistant
// 4. Current chat session
// 5. Live assistant

/*
LLM Override is as follows (i.e. this order)
- User override (explicitly set in the chat input bar)
@@ -386,6 +396,20 @@ Changes take place as
- (uploadLLMOverride) User explicitly setting a model override (and we explicitly override and set the userSpecifiedOverride which we'll use in place of the user preferences unless overridden by an assistant)

If we have a live assistant, we should use that model override

Relevant test: `llm_ordering.spec.ts`.

Temperature override is set as follows:
- For existing chat sessions:
  - If the user has previously overridden the temperature for a specific chat session,
    that value is persisted and used when the user returns to that chat.
  - This persistence applies even if the temperature was set before sending the first message in the chat.
- For new chat sessions:
  - If the search tool is available, the default temperature is set to 0.
  - If the search tool is not available, the default temperature is set to 0.5.

This approach ensures that user preferences are maintained for existing chats while
providing appropriate defaults for new conversations based on the available tools.
*/
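To make the rule above concrete, here is a standalone sketch (not part of the diff) of the defaulting logic. pickInitialTemperature and its options object are assumed names; the constants follow the comment above and the useLlmOverride changes below.

// Illustrative helper only; mirrors the defaulting described in the comment above.
function pickInitialTemperature(opts: {
  sessionOverride: number | null; // persisted per-chat override, if any
  hasSearchTool: boolean; // whether the live assistant exposes the search tool
  maxTemperature: number; // 1.0 for Anthropic models, 2.0 otherwise
}): number {
  if (opts.sessionOverride != null) {
    // Existing chats keep their stored override, clamped to the model's max.
    return Math.min(opts.sessionOverride, opts.maxTemperature);
  }
  // New chats default to 0 when the search tool is available, 0.5 otherwise.
  return opts.hasSearchTool ? 0 : 0.5;
}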
|
||||
|
||||
export function useLlmOverride(
|
||||
@@ -398,11 +422,6 @@ export function useLlmOverride(
|
||||
const [chatSession, setChatSession] = useState<ChatSession | null>(null);
|
||||
|
||||
const llmOverrideUpdate = () => {
|
||||
if (!chatSession && currentChatSession) {
|
||||
setChatSession(currentChatSession || null);
|
||||
return;
|
||||
}
|
||||
|
||||
if (liveAssistant?.llm_model_version_override) {
|
||||
setLlmOverride(
|
||||
getValidLlmOverride(liveAssistant.llm_model_version_override)
|
||||
@@ -490,24 +509,68 @@ export function useLlmOverride(
|
||||
}
|
||||
};
|
||||
|
||||
const [temperature, setTemperature] = useState<number | null>(0);
|
||||
|
||||
useEffect(() => {
|
||||
const [temperature, setTemperature] = useState<number>(() => {
|
||||
llmOverrideUpdate();
|
||||
}, [liveAssistant, currentChatSession]);
|
||||
|
||||
if (currentChatSession?.current_temperature_override != null) {
|
||||
return Math.min(
|
||||
currentChatSession.current_temperature_override,
|
||||
isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0
|
||||
);
|
||||
} else if (
|
||||
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
|
||||
) {
|
||||
return 0;
|
||||
}
|
||||
return 0.5;
|
||||
});
|
||||
|
||||
const maxTemperature = useMemo(() => {
|
||||
return isAnthropic(llmOverride.provider, llmOverride.modelName) ? 1.0 : 2.0;
|
||||
}, [llmOverride]);
|
||||
|
||||
useEffect(() => {
|
||||
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
|
||||
setTemperature((prevTemp) => Math.min(prevTemp ?? 0, 1.0));
|
||||
const newTemperature = Math.min(temperature, 1.0);
|
||||
setTemperature(newTemperature);
|
||||
if (chatSession?.id) {
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, newTemperature);
|
||||
}
|
||||
}
|
||||
}, [llmOverride]);
|
||||
|
||||
const updateTemperature = (temperature: number | null) => {
|
||||
useEffect(() => {
|
||||
if (!chatSession && currentChatSession) {
|
||||
setChatSession(currentChatSession || null);
|
||||
if (temperature) {
|
||||
updateTemperatureOverrideForChatSession(
|
||||
currentChatSession.id,
|
||||
temperature
|
||||
);
|
||||
}
|
||||
return;
|
||||
}
|
||||
|
||||
if (currentChatSession?.current_temperature_override) {
|
||||
setTemperature(currentChatSession.current_temperature_override);
|
||||
} else if (
|
||||
liveAssistant?.tools.some((tool) => tool.name === SEARCH_TOOL_ID)
|
||||
) {
|
||||
setTemperature(0);
|
||||
} else {
|
||||
setTemperature(0.5);
|
||||
}
|
||||
}, [liveAssistant, currentChatSession]);
|
||||
|
||||
const updateTemperature = (temperature: number) => {
|
||||
if (isAnthropic(llmOverride.provider, llmOverride.modelName)) {
|
||||
setTemperature((prevTemp) => Math.min(temperature ?? 0, 1.0));
|
||||
setTemperature((prevTemp) => Math.min(temperature, 1.0));
|
||||
} else {
|
||||
setTemperature(temperature);
|
||||
}
|
||||
if (chatSession) {
|
||||
updateTemperatureOverrideForChatSession(chatSession.id, temperature);
|
||||
}
|
||||
};
|
||||
|
||||
return {
|
||||
@@ -519,6 +582,7 @@ export function useLlmOverride(
|
||||
imageFilesPresent,
|
||||
updateImageFilesPresent,
|
||||
liveAssistant: liveAssistant ?? null,
|
||||
maxTemperature,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ interface UserPreferences {
|
||||
recent_assistants: number[];
|
||||
auto_scroll: boolean | null;
|
||||
shortcut_enabled: boolean;
|
||||
temperature_override_enabled: boolean;
|
||||
}
|
||||
|
||||
export enum UserRole {
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
|
||||
// Use pre-signed in "admin" storage state
|
||||
test.use({
|
||||
storageState: "admin_auth.json",
|
||||
});
|
||||
|
||||
test("Chat workflow", async ({ page }) => {
|
||||
// Initial setup
|
||||
await page.goto("http://localhost:3000/chat", { timeout: 3000 });
|
||||
|
||||
// Interact with Art assistant
|
||||
await page.locator("button").filter({ hasText: "Art" }).click();
|
||||
await page.getByPlaceholder("Message Art assistant...").fill("Hi");
|
||||
await page.keyboard.press("Enter");
|
||||
await page.waitForTimeout(3000);
|
||||
|
||||
// Start a new chat
|
||||
await page.getByRole("link", { name: "Start New Chat" }).click();
|
||||
await page.waitForNavigation({ waitUntil: "networkidle" });
|
||||
|
||||
// Check for expected text
|
||||
await expect(page.getByText("Assistant for generating")).toBeVisible();
|
||||
|
||||
// Interact with General assistant
|
||||
await page.locator("button").filter({ hasText: "General" }).click();
|
||||
|
||||
// Check URL after clicking General assistant
|
||||
await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1", {
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Create a new assistant
|
||||
await page.getByRole("button", { name: "Explore Assistants" }).click();
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
await page.getByTestId("name").click();
|
||||
await page.getByTestId("name").fill("Test Assistant");
|
||||
await page.getByTestId("description").click();
|
||||
await page.getByTestId("description").fill("Test Assistant Description");
|
||||
await page.getByTestId("system_prompt").click();
|
||||
await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Verify new assistant creation
|
||||
await expect(page.getByText("Test Assistant Description")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Start another new chat
|
||||
await page.getByRole("link", { name: "Start New Chat" }).click();
|
||||
await expect(page.getByText("Assistant with access to")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
});
|
||||
web/tests/e2e/chat/current_assistant.spec.ts (new file, 56 lines)
@@ -0,0 +1,56 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { dragElementAbove, dragElementBelow } from "../utils/dragUtils";
|
||||
import { loginAsRandomUser } from "../utils/auth";
|
||||
|
||||
test("Assistant Drag and Drop", async ({ page }) => {
|
||||
test.fail();
|
||||
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
// Navigate to the chat page
|
||||
await page.goto("http://localhost:3000/chat");
|
||||
|
||||
// Helper function to get the current order of assistants
|
||||
const getAssistantOrder = async () => {
|
||||
const assistants = await page.$$('[data-testid^="assistant-["]');
|
||||
return Promise.all(
|
||||
assistants.map(async (assistant) => {
|
||||
const nameElement = await assistant.$("p");
|
||||
return nameElement ? nameElement.textContent() : "";
|
||||
})
|
||||
);
|
||||
};
|
||||
|
||||
// Get the initial order
|
||||
const initialOrder = await getAssistantOrder();
|
||||
|
||||
// Drag second assistant above first
|
||||
const secondAssistant = page.locator('[data-testid^="assistant-["]').nth(1);
|
||||
const firstAssistant = page.locator('[data-testid^="assistant-["]').nth(0);
|
||||
|
||||
await dragElementAbove(secondAssistant, firstAssistant, page);
|
||||
|
||||
// Check new order
|
||||
const orderAfterDragUp = await getAssistantOrder();
|
||||
expect(orderAfterDragUp[0]).toBe(initialOrder[1]);
|
||||
expect(orderAfterDragUp[1]).toBe(initialOrder[0]);
|
||||
|
||||
// Drag last assistant to second position
|
||||
const assistants = page.locator('[data-testid^="assistant-["]');
|
||||
const lastIndex = (await assistants.count()) - 1;
|
||||
const lastAssistant = assistants.nth(lastIndex);
|
||||
const secondPosition = assistants.nth(1);
|
||||
|
||||
await page.waitForTimeout(3000);
|
||||
await dragElementBelow(lastAssistant, secondPosition, page);
|
||||
|
||||
// Check new order
|
||||
const orderAfterDragDown = await getAssistantOrder();
|
||||
expect(orderAfterDragDown[1]).toBe(initialOrder[lastIndex]);
|
||||
|
||||
// Refresh and verify order
|
||||
await page.reload();
|
||||
const orderAfterRefresh = await getAssistantOrder();
|
||||
expect(orderAfterRefresh).toEqual(orderAfterDragDown);
|
||||
});
|
||||
web/tests/e2e/chat/live_assistant.spec.ts (new file, 72 lines)
@@ -0,0 +1,72 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAsRandomUser } from "../utils/auth";
|
||||
import {
|
||||
navigateToAssistantInHistorySidebar,
|
||||
sendMessage,
|
||||
startNewChat,
|
||||
switchModel,
|
||||
} from "../utils/chatActions";
|
||||
|
||||
test("Chat workflow", async ({ page }) => {
|
||||
test.fail();
|
||||
|
||||
// Clear cookies and log in as a random user
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
// Navigate to the chat page
|
||||
await page.goto("http://localhost:3000/chat");
|
||||
|
||||
// Test interaction with the Art assistant
|
||||
await navigateToAssistantInHistorySidebar(
|
||||
page,
|
||||
"[-3]",
|
||||
"Assistant for generating"
|
||||
);
|
||||
await sendMessage(page, "Hi");
|
||||
|
||||
// Start a new chat session
|
||||
await startNewChat(page);
|
||||
|
||||
// Verify the presence of the expected text
|
||||
await expect(page.getByText("Assistant for generating")).toBeVisible();
|
||||
|
||||
// Test interaction with the General assistant
|
||||
await navigateToAssistantInHistorySidebar(
|
||||
page,
|
||||
"[-1]",
|
||||
"Assistant with no search"
|
||||
);
|
||||
|
||||
// Verify the URL after selecting the General assistant
|
||||
await expect(page).toHaveURL("http://localhost:3000/chat?assistantId=-1");
|
||||
|
||||
// Test creation of a new assistant
|
||||
await page.getByRole("button", { name: "Explore Assistants" }).click();
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
await page.getByTestId("name").click();
|
||||
await page.getByTestId("name").fill("Test Assistant");
|
||||
await page.getByTestId("description").click();
|
||||
await page.getByTestId("description").fill("Test Assistant Description");
|
||||
await page.getByTestId("system_prompt").click();
|
||||
await page.getByTestId("system_prompt").fill("Test Assistant Instructions");
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Verify the successful creation of the new assistant
|
||||
await expect(page.getByText("Test Assistant Description")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
|
||||
// Start another new chat session
|
||||
await startNewChat(page);
|
||||
|
||||
// Verify the presence of the default assistant text
|
||||
try {
|
||||
await expect(page.getByText("Assistant with access to")).toBeVisible({
|
||||
timeout: 5000,
|
||||
});
|
||||
} catch (error) {
|
||||
console.error("Live Assistant final page content:");
|
||||
console.error(await page.content());
|
||||
}
|
||||
});
|
||||
web/tests/e2e/chat/llm_ordering.spec.ts (new file, 86 lines)
@@ -0,0 +1,86 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAsRandomUser } from "../utils/auth";
|
||||
import {
|
||||
navigateToAssistantInHistorySidebar,
|
||||
sendMessage,
|
||||
verifyCurrentModel,
|
||||
switchModel,
|
||||
startNewChat,
|
||||
} from "../utils/chatActions";
|
||||
|
||||
test("LLM Ordering and Model Switching", async ({ page }) => {
|
||||
test.fail();
|
||||
|
||||
// Setup: Clear cookies and log in as a random user
|
||||
await page.context().clearCookies();
|
||||
await loginAsRandomUser(page);
|
||||
|
||||
// Navigate to the chat page and verify URL
|
||||
await page.goto("http://localhost:3000/chat");
|
||||
await page.waitForSelector("#onyx-chat-input-textarea");
|
||||
await expect(page.url()).toBe("http://localhost:3000/chat");
|
||||
|
||||
// Configure user settings: Set default model to GPT 4 Turbo
|
||||
await page.locator("#onyx-user-dropdown").click();
|
||||
await page.getByText("User Settings").click();
|
||||
await page.getByRole("combobox").click();
|
||||
await page.getByLabel("GPT 4 Turbo", { exact: true }).click();
|
||||
await page.getByLabel("Close modal").click();
|
||||
await verifyCurrentModel(page, "GPT 4 Turbo");
|
||||
|
||||
// Test Art Assistant: Should use its own model (GPT 4o)
|
||||
await navigateToAssistantInHistorySidebar(
|
||||
page,
|
||||
"[-3]",
|
||||
"Assistant for generating"
|
||||
);
|
||||
await sendMessage(page, "Sample message");
|
||||
await verifyCurrentModel(page, "GPT 4o");
|
||||
|
||||
// Verify model persistence for Art Assistant
|
||||
await sendMessage(page, "Sample message");
|
||||
|
||||
// Test new chat: Should use Art Assistant's model initially
|
||||
await startNewChat(page);
|
||||
await expect(page.getByText("Assistant for generating")).toBeVisible();
|
||||
await verifyCurrentModel(page, "GPT 4o");
|
||||
|
||||
// Test another new chat: Should use user's default model (GPT 4 Turbo)
|
||||
await startNewChat(page);
|
||||
await verifyCurrentModel(page, "GPT 4 Turbo");
|
||||
|
||||
// Test model switching within a chat
|
||||
await switchModel(page, "O1 Mini");
|
||||
await sendMessage(page, "Sample message");
|
||||
await verifyCurrentModel(page, "O1 Mini");
|
||||
|
||||
// Create a custom assistant with a specific model
|
||||
await page.getByRole("button", { name: "Explore Assistants" }).click();
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
await page.waitForTimeout(2000);
|
||||
await page.getByTestId("name").fill("Sample Name");
|
||||
await page.getByTestId("description").fill("Sample Description");
|
||||
await page.getByTestId("system_prompt").fill("Sample Instructions");
|
||||
await page.getByRole("combobox").click();
|
||||
await page
|
||||
.getByLabel("GPT 4 Turbo (Preview)")
|
||||
.getByText("GPT 4 Turbo (Preview)")
|
||||
.click();
|
||||
await page.getByRole("button", { name: "Create" }).click();
|
||||
|
||||
// Verify custom assistant uses its specified model
|
||||
await page.locator("#onyx-chat-input-textarea").fill("");
|
||||
await verifyCurrentModel(page, "GPT 4 Turbo (Preview)");
|
||||
|
||||
// Ensure model persistence for custom assistant
|
||||
await sendMessage(page, "Sample message");
|
||||
await verifyCurrentModel(page, "GPT 4 Turbo (Preview)");
|
||||
|
||||
// Switch back to Art Assistant and verify its model
|
||||
await navigateToAssistantInHistorySidebar(
|
||||
page,
|
||||
"[-3]",
|
||||
"Assistant for generating"
|
||||
);
|
||||
await verifyCurrentModel(page, "GPT 4o");
|
||||
});
|
||||
@@ -35,3 +35,40 @@ export async function loginAs(page: Page, userType: "admin" | "user") {
|
||||
}
|
||||
}
|
||||
}
|
||||
// Function to generate a random email and password
|
||||
const generateRandomCredentials = () => {
|
||||
const randomString = Math.random().toString(36).substring(2, 10);
|
||||
const specialChars = "!@#$%^&*()_+{}[]|:;<>,.?~";
|
||||
const randomSpecialChar =
|
||||
specialChars[Math.floor(Math.random() * specialChars.length)];
|
||||
const randomUpperCase = String.fromCharCode(
|
||||
65 + Math.floor(Math.random() * 26)
|
||||
);
|
||||
const randomNumber = Math.floor(Math.random() * 10);
|
||||
|
||||
return {
|
||||
email: `test_${randomString}@example.com`,
|
||||
password: `P@ssw0rd_${randomUpperCase}${randomSpecialChar}${randomNumber}${randomString}`,
|
||||
};
|
||||
};
|
||||
|
||||
// Function to sign up a new random user
|
||||
export async function loginAsRandomUser(page: Page) {
|
||||
const { email, password } = generateRandomCredentials();
|
||||
|
||||
await page.goto("http://localhost:3000/auth/signup");
|
||||
|
||||
await page.fill("#email", email);
|
||||
await page.fill("#password", password);
|
||||
|
||||
// Click the signup button
|
||||
await page.click('button[type="submit"]');
|
||||
try {
|
||||
await page.waitForURL("http://localhost:3000/chat");
|
||||
} catch (error) {
|
||||
console.log(`Timeout occurred. Current URL: ${page.url()}`);
|
||||
throw new Error("Failed to sign up and redirect to chat page");
|
||||
}
|
||||
|
||||
return { email, password };
|
||||
}
|
||||
|
||||
web/tests/e2e/utils/chatActions.ts (new file, 48 lines)
@@ -0,0 +1,48 @@
import { Page } from "@playwright/test";
import { expect } from "@playwright/test";

export async function navigateToAssistantInHistorySidebar(
  page: Page,
  testId: string,
  description: string
) {
  await page.getByTestId(`assistant-${testId}`).click();
  try {
    await expect(page.getByText(description)).toBeVisible();
  } catch (error) {
    console.error("Error in navigateToAssistantInHistorySidebar:", error);
    const pageText = await page.textContent("body");
    console.log("Page text:", pageText);
    throw error;
  }
}

export async function sendMessage(page: Page, message: string) {
  await page.locator("#onyx-chat-input-textarea").click();
  await page.locator("#onyx-chat-input-textarea").fill(message);
  await page.locator("#onyx-chat-input-send-button").click();
  await page.waitForSelector("#onyx-ai-message");
  await page.waitForTimeout(2000);
}

export async function verifyCurrentModel(page: Page, modelName: string) {
  await page.waitForTimeout(1000);
  const chatInput = page.locator("#onyx-chat-input");
  const text = await chatInput.textContent();
  expect(text).toContain(modelName);
  await page.waitForTimeout(1000);
}

// Start of Selection
export async function switchModel(page: Page, modelName: string) {
  await page.getByTestId("llm-popover-trigger").click();
  await page
    .getByRole("button", { name: `Logo ${modelName}`, exact: true })
    .click();
  await page.waitForTimeout(1000);
}

export async function startNewChat(page: Page) {
  await page.getByRole("link", { name: "Start New Chat" }).click();
  await expect(page.locator('div[data-testid="chat-intro"]')).toBeVisible();
}
web/tests/e2e/utils/dragUtils.ts (new file, 74 lines)
@@ -0,0 +1,74 @@
|
||||
import { Locator, Page } from "@playwright/test";
|
||||
|
||||
/**
|
||||
* Drag "source" above (higher Y) "target" by using mouse events.
|
||||
* Positions the cursor on the lower half of source, then moves to the top half of the target.
|
||||
*/
|
||||
export async function dragElementAbove(
|
||||
sourceLocator: Locator,
|
||||
targetLocator: Locator,
|
||||
page: Page
|
||||
) {
|
||||
// Get bounding boxes
|
||||
const sourceBB = await sourceLocator.boundingBox();
|
||||
const targetBB = await targetLocator.boundingBox();
|
||||
if (!sourceBB || !targetBB) {
|
||||
throw new Error("Source/target bounding boxes not found.");
|
||||
}
|
||||
|
||||
// Move over source, press mouse down
|
||||
await page.mouse.move(
|
||||
sourceBB.x + sourceBB.width / 2,
|
||||
sourceBB.y + sourceBB.height * 0.75 // Move to 3/4 down the source element
|
||||
);
|
||||
await page.mouse.down();
|
||||
|
||||
// Move to a point slightly above the target's center
|
||||
await page.mouse.move(
|
||||
targetBB.x + targetBB.width / 2,
|
||||
targetBB.y + targetBB.height * 0.1, // Move to 1/10 down the target element
|
||||
{ steps: 20 } // Increase steps for smoother drag
|
||||
);
|
||||
await page.mouse.up();
|
||||
|
||||
// Increase wait time for DnD transitions
|
||||
await page.waitForTimeout(200);
|
||||
}
|
||||
|
||||
/**
|
||||
* Drag "source" below (higher Y → lower Y) "target" using mouse events.
|
||||
*/
|
||||
export async function dragElementBelow(
|
||||
sourceLocator: Locator,
|
||||
targetLocator: Locator,
|
||||
page: Page
|
||||
) {
|
||||
// Get bounding boxes
|
||||
const sourceBB = await targetLocator.boundingBox();
|
||||
const targetBB = await sourceLocator.boundingBox();
|
||||
if (!sourceBB || !targetBB) {
|
||||
throw new Error("Source/target bounding boxes not found.");
|
||||
}
|
||||
|
||||
// Move over source, press mouse down
|
||||
await page.mouse.move(
|
||||
sourceBB.x + sourceBB.width / 2,
|
||||
sourceBB.y + sourceBB.height * 0.25 // Move to 1/4 down the source element
|
||||
);
|
||||
await page.mouse.down();
|
||||
|
||||
// Move to a point well below the target's bottom edge
|
||||
await page.mouse.move(
|
||||
targetBB.x + targetBB.width / 2,
|
||||
targetBB.y + targetBB.height + 50, // Move 50 pixels below the target element
|
||||
{ steps: 50 } // Keep the same number of steps for smooth drag
|
||||
);
|
||||
|
||||
// Hold for a moment to ensure the drag is registered
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
await page.mouse.up();
|
||||
|
||||
// Wait for DnD transitions and potential animations
|
||||
await page.waitForTimeout(1000);
|
||||
}
|
||||
@@ -1,15 +0,0 @@
{
  "cookies": [
    {
      "name": "fastapiusersauth",
      "value": "n_EMYYKHn4tQbuPTEbtN1gJ6dQTGek9omJPhO2GhHoA",
      "domain": "localhost",
      "path": "/",
      "expires": 1738801376.508558,
      "httpOnly": true,
      "secure": false,
      "sameSite": "Lax"
    }
  ],
  "origins": []
}