Compare commits

..

8 Commits

Author SHA1 Message Date
Evan Lohn
b1d42726b1 test: file reader tool (#8856) 2026-02-28 23:09:43 +00:00
Yuhong Sun
7d922bffc1 chore: Persona cleanup (#8810) 2026-02-28 21:34:45 +00:00
Evan Lohn
de7fc36fc5 test: no vector db user file processing (#8854) 2026-02-28 04:19:59 +00:00
Evan Lohn
7f9e37450d fix: non vector db tasks (#8849) 2026-02-28 03:51:57 +00:00
Evan Lohn
c7ef85b733 chore: narrow no_vector_db supported scope (#8847) 2026-02-28 02:54:15 +00:00
Danelegend
bd9319e592 feat: LLM Provider Rework (#8761)
Co-authored-by: Raunak Bhagat <r@rabh.io>
2026-02-28 01:29:49 +00:00
Nikolas Garza
db5955d6f2 fix(ee): show Access Restricted page when seat limit exceeded (#8877) 2026-02-28 01:26:00 +00:00
Raunak Bhagat
5e447440ea refactor(Suggestions): migrate to opal Interactive + Content (#8881) 2026-02-27 23:39:20 +00:00
159 changed files with 4483 additions and 2213 deletions

View File

@@ -114,8 +114,10 @@ jobs:
- name: Mark workflow as failed if cherry-pick failed
if: steps.gate.outputs.should_cherrypick == 'true' && steps.run_cherry_pick.outputs.status == 'failure'
env:
CHERRY_PICK_REASON: ${{ steps.run_cherry_pick.outputs.reason }}
run: |
echo "::error::Automated cherry-pick failed (${{ steps.run_cherry_pick.outputs.reason }})."
echo "::error::Automated cherry-pick failed (${CHERRY_PICK_REASON})."
exit 1
notify-slack-on-cherry-pick-failure:

View File

@@ -0,0 +1,112 @@
"""persona cleanup and featured

Replaces persona.is_default_persona with persona.featured and drops
retired search-tuning columns from the persona table.

Revision ID: 6b3b4083c5aa
Revises: 57122d037335
Create Date: 2026-02-26 12:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "6b3b4083c5aa"
down_revision = "57122d037335"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Rename the persona "default" flag to "featured" and drop unused columns.

    Steps:
      1. add ``featured`` as nullable so existing rows are valid,
      2. copy values over from ``is_default_persona``,
      3. tighten ``featured`` to NOT NULL with a server-side false default,
      4. drop ``is_default_persona`` plus the retired search-tuning columns.
    """
    # Nullable first: a NOT NULL column cannot be added to a populated table
    # without a value for every existing row.
    op.add_column("persona", sa.Column("featured", sa.Boolean(), nullable=True))

    # Carry the old flag's values across before dropping it.
    op.execute("UPDATE persona SET featured = is_default_persona")

    # Now that every row has a value, enforce NOT NULL with default false.
    op.alter_column(
        "persona",
        "featured",
        existing_type=sa.Boolean(),
        nullable=False,
        server_default=sa.false(),
    )

    # Drop the superseded flag and the search-tuning columns that are no
    # longer read anywhere (same order as the original migration).
    for obsolete_column in (
        "is_default_persona",
        "num_chunks",
        "chunks_above",
        "chunks_below",
        "llm_relevance_filter",
        "llm_filter_extraction",
        "recency_bias",
    ):
        op.drop_column("persona", obsolete_column)
def downgrade() -> None:
    """Reverse :func:`upgrade`: restore the dropped columns and the old flag.

    Re-creates each removed column with a server default so the ADD COLUMN
    succeeds on populated tables, copies ``featured`` back into
    ``is_default_persona``, then drops ``featured``.
    """
    # Columns to restore, in the same order the original migration added
    # them back (reverse of the drop order in upgrade()).
    restored_columns = [
        sa.Column(
            "recency_bias",
            sa.VARCHAR(),
            nullable=False,
            server_default="base_decay",
        ),
        sa.Column(
            "llm_filter_extraction",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
        sa.Column(
            "llm_relevance_filter",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
        sa.Column("chunks_below", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("chunks_above", sa.Integer(), nullable=False, server_default="0"),
        sa.Column("num_chunks", sa.Float(), nullable=True),
        sa.Column(
            "is_default_persona",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
    ]
    for column in restored_columns:
        op.add_column("persona", column)

    # Copy the flag back before the replacement column disappears.
    op.execute("UPDATE persona SET is_default_persona = featured")

    op.drop_column("persona", "featured")

View File

@@ -18,8 +18,8 @@ from ee.onyx.server.enterprise_settings.store import (
store_settings as store_ee_settings,
)
from ee.onyx.server.enterprise_settings.store import upload_logo
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import Tool
@@ -117,15 +117,38 @@ def _seed_custom_tools(db_session: Session, tools: List[CustomToolSeed]) -> None
def _seed_llms(
db_session: Session, llm_upsert_requests: list[LLMProviderUpsertRequest]
) -> None:
if llm_upsert_requests:
logger.notice("Seeding LLMs")
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
update_default_provider(
provider_id=seeded_providers[0].id, db_session=db_session
)
if not llm_upsert_requests:
return
logger.notice("Seeding LLMs")
for request in llm_upsert_requests:
existing = fetch_existing_llm_provider(name=request.name, db_session=db_session)
if existing:
request.id = existing.id
seeded_providers = [
upsert_llm_provider(llm_upsert_request, db_session)
for llm_upsert_request in llm_upsert_requests
]
default_provider = next(
(p for p in seeded_providers if p.model_configurations), None
)
if not default_provider:
return
visible_configs = [
mc for mc in default_provider.model_configurations if mc.is_visible
]
default_config = (
visible_configs[0]
if visible_configs
else default_provider.model_configurations[0]
)
update_default_provider(
provider_id=default_provider.id,
model_name=default_config.name,
db_session=db_session,
)
def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) -> None:
@@ -137,12 +160,6 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
user=None, # Seeding is done as admin
name=persona.name,
description=persona.description,
num_chunks=(
persona.num_chunks if persona.num_chunks is not None else 0.0
),
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=RecencyBiasSetting.AUTO,
document_set_ids=persona.document_set_ids,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
@@ -154,6 +171,7 @@ def _seed_personas(db_session: Session, personas: list[PersonaUpsertRequest]) ->
system_prompt=persona.system_prompt,
task_prompt=persona.task_prompt,
datetime_aware=persona.datetime_aware,
featured=persona.featured,
commit=False,
)
db_session.commit()

View File

@@ -109,6 +109,12 @@ def apply_license_status_to_settings(settings: Settings) -> Settings:
if metadata.status == _BLOCKING_STATUS:
settings.application_status = metadata.status
settings.ee_features_enabled = False
elif metadata.used_seats > metadata.seats:
# License is valid but seat limit exceeded
settings.application_status = ApplicationStatus.SEAT_LIMIT_EXCEEDED
settings.seat_count = metadata.seats
settings.used_seats = metadata.used_seats
settings.ee_features_enabled = True
else:
# Has a valid license (GRACE_PERIOD/PAYMENT_REMINDER still allow EE features)
settings.ee_features_enabled = True

View File

@@ -33,6 +33,7 @@ from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine.sql_engine import get_session_with_shared_schema
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.image_generation import create_default_image_gen_config_from_api_key
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
@@ -302,12 +303,17 @@ def configure_default_api_keys(db_session: Session) -> None:
has_set_default_provider = False
def _upsert(request: LLMProviderUpsertRequest) -> None:
def _upsert(request: LLMProviderUpsertRequest, default_model: str) -> None:
nonlocal has_set_default_provider
try:
existing = fetch_existing_llm_provider(
name=request.name, db_session=db_session
)
if existing:
request.id = existing.id
provider = upsert_llm_provider(request, db_session)
if not has_set_default_provider:
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, default_model, db_session)
has_set_default_provider = True
except Exception as e:
logger.error(f"Failed to configure {request.provider} provider: {e}")
@@ -325,14 +331,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
OPENAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openai_provider)
_upsert(openai_provider, default_model_name)
# Create default image generation config using the OpenAI API key
try:
@@ -361,14 +366,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
ANTHROPIC_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(anthropic_provider)
_upsert(anthropic_provider, default_model_name)
else:
logger.info(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
@@ -393,14 +397,13 @@ def configure_default_api_keys(db_session: Session) -> None:
name="Google Vertex AI",
provider=VERTEXAI_PROVIDER_NAME,
custom_config=custom_config,
default_model_name=default_model_name,
model_configurations=_build_model_configuration_upsert_requests(
VERTEXAI_PROVIDER_NAME, recommendations
),
api_key_changed=True,
is_auto_mode=True,
)
_upsert(vertexai_provider)
_upsert(vertexai_provider, default_model_name)
else:
logger.info(
"VERTEXAI_DEFAULT_CREDENTIALS not set, skipping Vertex AI provider configuration"
@@ -432,12 +435,11 @@ def configure_default_api_keys(db_session: Session) -> None:
name="OpenRouter",
provider=OPENROUTER_PROVIDER_NAME,
api_key=OPENROUTER_DEFAULT_API_KEY,
default_model_name=default_model_name,
model_configurations=model_configurations,
api_key_changed=True,
is_auto_mode=True,
)
_upsert(openrouter_provider)
_upsert(openrouter_provider, default_model_name)
else:
logger.info(
"OPENROUTER_DEFAULT_API_KEY not set, skipping OpenRouter provider configuration"

View File

@@ -543,7 +543,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
result = await db_session.execute(
select(Persona.id)
.where(
Persona.is_default_persona.is_(True),
Persona.featured.is_(True),
Persona.is_public.is_(True),
Persona.is_visible.is_(True),
Persona.deleted.is_(False),

View File

@@ -241,8 +241,7 @@ _VECTOR_DB_BEAT_TASK_NAMES: set[str] = {
"check-for-index-attempt-cleanup",
"check-for-doc-permissions-sync",
"check-for-external-group-sync",
"check-for-documents-for-opensearch-migration",
"migrate-documents-from-vespa-to-opensearch",
"migrate-chunks-from-vespa-to-opensearch",
}
if DISABLE_VECTOR_DB:

View File

@@ -414,34 +414,31 @@ def _process_user_file_with_indexing(
raise RuntimeError(f"Indexing pipeline failed for user file {user_file_id}")
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def _process_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
task_logger.info(f"process_single_user_file - Starting id={user_file_id}")
"""Core implementation for processing a single user file.
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips all Redis
operations (BackgroundTask path).
"""
task_logger.info(f"_process_user_file_impl - Starting id={user_file_id}")
start = time.monotonic()
redis_client = get_redis_client(tenant_id=tenant_id)
# Clear the "queued" guard set by the beat generator so that the next beat
# cycle can re-enqueue this file if it is still in PROCESSING state after
# this task completes or fails.
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock: RedisLock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file - Lock held, skipping user_file_id={user_file_id}"
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_queued_key(user_file_id))
file_lock = redis_client.lock(
_user_file_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROCESSING_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_process_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
documents: list[Document] = []
try:
@@ -449,15 +446,15 @@ def process_single_user_file(
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if not uf:
task_logger.warning(
f"process_single_user_file - UserFile not found id={user_file_id}"
f"_process_user_file_impl - UserFile not found id={user_file_id}"
)
return None
return
if uf.status != UserFileStatus.PROCESSING:
task_logger.info(
f"process_single_user_file - Skipping id={user_file_id} status={uf.status}"
f"_process_user_file_impl - Skipping id={user_file_id} status={uf.status}"
)
return None
return
connector = LocalFileConnector(
file_locations=[uf.file_id],
@@ -471,7 +468,6 @@ def process_single_user_file(
[doc for doc in batch if not isinstance(doc, HierarchyNode)]
)
# update the document id to userfile id in the documents
for document in documents:
document.id = str(user_file_id)
document.source = DocumentSource.USER_FILE
@@ -493,9 +489,8 @@ def process_single_user_file(
except Exception as e:
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
# don't update the status if the user file is being deleted
current_user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if (
current_user_file
@@ -504,33 +499,42 @@ def process_single_user_file(
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
return None
return
elapsed = time.monotonic() - start
task_logger.info(
f"process_single_user_file - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
f"_process_user_file_impl - Finished id={user_file_id} docs={len(documents)} elapsed={elapsed:.2f}s"
)
return None
except Exception as e:
# Attempt to mark the file as failed
with get_session_with_current_tenant() as db_session:
uf = db_session.get(UserFile, _as_uuid(user_file_id))
if uf:
# don't update the status if the user file is being deleted
if uf.status != UserFileStatus.DELETING:
uf.status = UserFileStatus.FAILED
db_session.add(uf)
db_session.commit()
task_logger.exception(
f"process_single_user_file - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"_process_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
_process_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_USER_FILE_DELETE,
soft_time_limit=300,
@@ -581,36 +585,38 @@ def check_for_user_file_delete(self: Task, *, tenant_id: str) -> None:
return None
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def _delete_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
"""Process a single user file delete."""
task_logger.info(f"process_single_user_file_delete - Starting id={user_file_id}")
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock: RedisLock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_delete - Lock held, skipping user_file_id={user_file_id}"
"""Core implementation for deleting a single user file.
When redis_locking=True, acquires a per-file Redis lock (Celery path).
When redis_locking=False, skips Redis operations (BackgroundTask path).
"""
task_logger.info(f"_delete_user_file_impl - Starting id={user_file_id}")
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
file_lock = redis_client.lock(
_user_file_delete_lock_key(user_file_id),
timeout=CELERY_GENERIC_BEAT_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_delete_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
try:
with get_session_with_current_tenant() as db_session:
user_file = db_session.get(UserFile, _as_uuid(user_file_id))
if not user_file:
task_logger.info(
f"process_single_user_file_delete - User file not found id={user_file_id}"
f"_delete_user_file_impl - User file not found id={user_file_id}"
)
return None
return
# 1) Delete vector DB chunks (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -648,7 +654,6 @@ def process_single_user_file_delete(
chunk_count=chunk_count,
)
# 2) Delete the user-uploaded file content from filestore (blob + metadata)
file_store = get_default_file_store()
try:
file_store.delete_file(user_file.file_id)
@@ -656,26 +661,33 @@ def process_single_user_file_delete(
user_file_id_to_plaintext_file_name(user_file.id)
)
except Exception as e:
# This block executed only if the file is not found in the filestore
task_logger.exception(
f"process_single_user_file_delete - Error deleting file id={user_file.id} - {e.__class__.__name__}"
f"_delete_user_file_impl - Error deleting file id={user_file.id} - {e.__class__.__name__}"
)
# 3) Finally, delete the UserFile row
db_session.delete(user_file)
db_session.commit()
task_logger.info(
f"process_single_user_file_delete - Completed id={user_file_id}"
)
task_logger.info(f"_delete_user_file_impl - Completed id={user_file_id}")
except Exception as e:
task_logger.exception(
f"process_single_user_file_delete - Error processing file id={user_file_id} - {e.__class__.__name__}"
f"_delete_user_file_impl - Error processing file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
return None
@shared_task(
name=OnyxCeleryTask.DELETE_SINGLE_USER_FILE,
bind=True,
ignore_result=True,
)
def process_single_user_file_delete(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
_delete_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)
@shared_task(
@@ -747,32 +759,30 @@ def check_for_user_file_project_sync(self: Task, *, tenant_id: str) -> None:
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
def _project_sync_user_file_impl(
*, user_file_id: str, tenant_id: str, redis_locking: bool
) -> None:
"""Process a single user file project sync."""
task_logger.info(
f"process_single_user_file_project_sync - Starting id={user_file_id}"
)
"""Core implementation for syncing a user file's project/persona metadata.
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
When redis_locking=True, acquires a per-file Redis lock and clears the
queued-key guard (Celery path). When redis_locking=False, skips Redis
operations (BackgroundTask path).
"""
task_logger.info(f"_project_sync_user_file_impl - Starting id={user_file_id}")
file_lock: RedisLock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
if not file_lock.acquire(blocking=False):
task_logger.info(
f"process_single_user_file_project_sync - Lock held, skipping user_file_id={user_file_id}"
file_lock: RedisLock | None = None
if redis_locking:
redis_client = get_redis_client(tenant_id=tenant_id)
redis_client.delete(_user_file_project_sync_queued_key(user_file_id))
file_lock = redis_client.lock(
user_file_project_sync_lock_key(user_file_id),
timeout=CELERY_USER_FILE_PROJECT_SYNC_LOCK_TIMEOUT,
)
return None
if file_lock is not None and not file_lock.acquire(blocking=False):
task_logger.info(
f"_project_sync_user_file_impl - Lock held, skipping user_file_id={user_file_id}"
)
return
try:
with get_session_with_current_tenant() as db_session:
@@ -783,11 +793,10 @@ def process_single_user_file_project_sync(
).scalar_one_or_none()
if not user_file:
task_logger.info(
f"process_single_user_file_project_sync - User file not found id={user_file_id}"
f"_project_sync_user_file_impl - User file not found id={user_file_id}"
)
return None
return
# Sync project metadata to vector DB (skip when disabled)
if not DISABLE_VECTOR_DB:
if MANAGED_VESPA:
httpx_init_vespa_pool(
@@ -822,7 +831,7 @@ def process_single_user_file_project_sync(
)
task_logger.info(
f"process_single_user_file_project_sync - User file id={user_file_id}"
f"_project_sync_user_file_impl - User file id={user_file_id}"
)
user_file.needs_project_sync = False
@@ -835,11 +844,21 @@ def process_single_user_file_project_sync(
except Exception as e:
task_logger.exception(
f"process_single_user_file_project_sync - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
f"_project_sync_user_file_impl - Error syncing project for file id={user_file_id} - {e.__class__.__name__}"
)
return None
finally:
if file_lock.owned():
if file_lock is not None and file_lock.owned():
file_lock.release()
return None
@shared_task(
name=OnyxCeleryTask.PROCESS_SINGLE_USER_FILE_PROJECT_SYNC,
bind=True,
ignore_result=True,
)
def process_single_user_file_project_sync(
self: Task, *, user_file_id: str, tenant_id: str # noqa: ARG001
) -> None:
_project_sync_user_file_impl(
user_file_id=user_file_id, tenant_id=tenant_id, redis_locking=True
)

View File

@@ -202,7 +202,6 @@ def create_default_image_gen_config_from_api_key(
api_key=api_key,
api_base=None,
api_version=None,
default_model_name=model_name,
deployment_name=None,
is_public=True,
)

View File

@@ -213,11 +213,29 @@ def upsert_llm_provider(
llm_provider_upsert_request: LLMProviderUpsertRequest,
db_session: Session,
) -> LLMProviderView:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
existing_llm_provider: LLMProviderModel | None = None
if llm_provider_upsert_request.id:
existing_llm_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
if not existing_llm_provider:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} not found"
)
if not existing_llm_provider:
if existing_llm_provider.name != llm_provider_upsert_request.name:
raise ValueError(
f"LLM provider with id {llm_provider_upsert_request.id} name change not allowed"
)
else:
existing_llm_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if existing_llm_provider:
raise ValueError(
f"LLM provider with name '{llm_provider_upsert_request.name}'"
" already exists"
)
existing_llm_provider = LLMProviderModel(name=llm_provider_upsert_request.name)
db_session.add(existing_llm_provider)
@@ -238,11 +256,7 @@ def upsert_llm_provider(
existing_llm_provider.api_base = api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config
# TODO: Remove default model name on api change
# Needed due to /provider/{id}/default endpoint not disclosing the default model name
existing_llm_provider.default_model_name = (
llm_provider_upsert_request.default_model_name
)
existing_llm_provider.is_public = llm_provider_upsert_request.is_public
existing_llm_provider.is_auto_mode = llm_provider_upsert_request.is_auto_mode
existing_llm_provider.deployment_name = llm_provider_upsert_request.deployment_name
@@ -306,15 +320,6 @@ def upsert_llm_provider(
display_name=model_config.display_name,
)
default_model = fetch_default_model(db_session, LLMModelFlowType.CHAT)
if default_model and default_model.llm_provider_id == existing_llm_provider.id:
_update_default_model(
db_session=db_session,
provider_id=existing_llm_provider.id,
model=existing_llm_provider.default_model_name,
flow_type=LLMModelFlowType.CHAT,
)
# Make sure the relationship table stays up to date
update_group_llm_provider_relationships__no_commit(
llm_provider_id=existing_llm_provider.id,
@@ -488,6 +493,22 @@ def fetch_existing_llm_provider(
return provider_model
def fetch_existing_llm_provider_by_id(
id: int, db_session: Session
) -> LLMProviderModel | None:
provider_model = db_session.scalar(
select(LLMProviderModel)
.where(LLMProviderModel.id == id)
.options(
selectinload(LLMProviderModel.model_configurations),
selectinload(LLMProviderModel.groups),
selectinload(LLMProviderModel.personas),
)
)
return provider_model
def fetch_embedding_provider(
db_session: Session, provider_type: EmbeddingProvider
) -> CloudEmbeddingProviderModel | None:
@@ -604,22 +625,13 @@ def remove_llm_provider__no_commit(db_session: Session, provider_id: int) -> Non
db_session.flush()
def update_default_provider(provider_id: int, db_session: Session) -> None:
# Attempt to get the default_model_name from the provider first
# TODO: Remove default_model_name check
provider = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.id == provider_id,
)
)
if provider is None:
raise ValueError(f"LLM Provider with id={provider_id} does not exist")
def update_default_provider(
provider_id: int, model_name: str, db_session: Session
) -> None:
_update_default_model(
db_session,
provider_id,
provider.default_model_name, # type: ignore[arg-type]
model_name,
LLMModelFlowType.CHAT,
)
@@ -805,12 +817,6 @@ def sync_auto_mode_models(
)
changes += 1
# In Auto mode, default model is always set from GitHub config
default_model = llm_recommendations.get_default_model(provider.provider)
if default_model and provider.default_model_name != default_model.name:
provider.default_model_name = default_model.name
changes += 1
db_session.commit()
return changes

View File

@@ -103,7 +103,6 @@ from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.sensitive import SensitiveValue
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from onyx.context.search.enums import RecencyBiasSetting
# TODO: After anonymous user migration has been deployed, make user_id columns NOT NULL
# and update Mapped[User | None] relationships to Mapped[User] where needed.
@@ -3265,19 +3264,6 @@ class Persona(Base):
)
name: Mapped[str] = mapped_column(String)
description: Mapped[str] = mapped_column(String)
# Number of chunks to pass to the LLM for generation.
num_chunks: Mapped[float | None] = mapped_column(Float, nullable=True)
chunks_above: Mapped[int] = mapped_column(Integer)
chunks_below: Mapped[int] = mapped_column(Integer)
# Pass every chunk through LLM for evaluation, fairly expensive
# Can be turned off globally by admin, in which case, this setting is ignored
llm_relevance_filter: Mapped[bool] = mapped_column(Boolean)
# Enables using LLM to extract time and source type filters
# Can also be admin disabled globally
llm_filter_extraction: Mapped[bool] = mapped_column(Boolean)
recency_bias: Mapped[RecencyBiasSetting] = mapped_column(
Enum(RecencyBiasSetting, native_enum=False)
)
# Allows the persona to specify a specific default LLM model
# NOTE: only is applied on the actual response generation - is not used for things like
@@ -3304,11 +3290,8 @@ class Persona(Base):
# Treated specially (cannot be user edited etc.)
builtin_persona: Mapped[bool] = mapped_column(Boolean, default=False)
# Default personas are personas created by admins and are automatically added
# to all users' assistants list.
is_default_persona: Mapped[bool] = mapped_column(
Boolean, default=False, nullable=False
)
# Featured personas are highlighted in the UI
featured: Mapped[bool] = mapped_column(Boolean, default=False)
# controls whether the persona is available to be selected by users
is_visible: Mapped[bool] = mapped_column(Boolean, default=True)
# controls the ordering of personas in the UI

View File

@@ -18,11 +18,8 @@ from sqlalchemy.orm import Session
from onyx.access.hierarchy_access import get_user_external_group_ids
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import CURATORS_CANNOT_VIEW_OR_EDIT_NON_OWNED_ASSISTANTS
from onyx.configs.chat_configs import CONTEXT_CHUNKS_ABOVE
from onyx.configs.chat_configs import CONTEXT_CHUNKS_BELOW
from onyx.configs.constants import DEFAULT_PERSONA_ID
from onyx.configs.constants import NotificationType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.document_access import get_accessible_documents_by_ids
from onyx.db.models import ConnectorCredentialPair
@@ -254,13 +251,15 @@ def create_update_persona(
# Permission to actually use these is checked later
try:
# Default persona validation
if create_persona_request.is_default_persona:
# Curators can edit default personas, but not make them
# Featured persona validation
if create_persona_request.featured:
# Curators can edit featured personas, but not make them
# TODO this will be reworked soon with RBAC permissions feature
if user.role == UserRole.CURATOR or user.role == UserRole.GLOBAL_CURATOR:
pass
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
raise ValueError("Only admins can make a featured persona")
# Convert incoming string UUIDs to UUID objects for DB operations
converted_user_file_ids = None
@@ -281,7 +280,6 @@ def create_update_persona(
document_set_ids=create_persona_request.document_set_ids,
tool_ids=create_persona_request.tool_ids,
is_public=create_persona_request.is_public,
recency_bias=create_persona_request.recency_bias,
llm_model_provider_override=create_persona_request.llm_model_provider_override,
llm_model_version_override=create_persona_request.llm_model_version_override,
starter_messages=create_persona_request.starter_messages,
@@ -295,10 +293,7 @@ def create_update_persona(
remove_image=create_persona_request.remove_image,
search_start_date=create_persona_request.search_start_date,
label_ids=create_persona_request.label_ids,
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=create_persona_request.is_default_persona,
featured=create_persona_request.featured,
user_file_ids=converted_user_file_ids,
commit=False,
hierarchy_node_ids=create_persona_request.hierarchy_node_ids,
@@ -874,10 +869,6 @@ def upsert_persona(
user: User | None,
name: str,
description: str,
num_chunks: float,
llm_relevance_filter: bool,
llm_filter_extraction: bool,
recency_bias: RecencyBiasSetting,
llm_model_provider_override: str | None,
llm_model_version_override: str | None,
starter_messages: list[StarterMessage] | None,
@@ -898,13 +889,11 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool | None = None,
featured: bool | None = None,
label_ids: list[int] | None = None,
user_file_ids: list[UUID] | None = None,
hierarchy_node_ids: list[int] | None = None,
document_ids: list[str] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
replace_base_system_prompt: bool = False,
) -> Persona:
"""
@@ -1015,12 +1004,6 @@ def upsert_persona(
# `default` and `built-in` properties can only be set when creating a persona.
existing_persona.name = name
existing_persona.description = description
existing_persona.num_chunks = num_chunks
existing_persona.chunks_above = chunks_above
existing_persona.chunks_below = chunks_below
existing_persona.llm_relevance_filter = llm_relevance_filter
existing_persona.llm_filter_extraction = llm_filter_extraction
existing_persona.recency_bias = recency_bias
existing_persona.llm_model_provider_override = llm_model_provider_override
existing_persona.llm_model_version_override = llm_model_version_override
existing_persona.starter_messages = starter_messages
@@ -1034,10 +1017,8 @@ def upsert_persona(
if label_ids is not None:
existing_persona.labels.clear()
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
existing_persona.featured = (
featured if featured is not None else existing_persona.featured
)
# Update embedded prompt fields if provided
if system_prompt is not None:
@@ -1090,12 +1071,6 @@ def upsert_persona(
is_public=is_public,
name=name,
description=description,
num_chunks=num_chunks,
chunks_above=chunks_above,
chunks_below=chunks_below,
llm_relevance_filter=llm_relevance_filter,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
builtin_persona=builtin_persona,
system_prompt=system_prompt or "",
task_prompt=task_prompt or "",
@@ -1111,9 +1086,7 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=(
is_default_persona if is_default_persona is not None else False
),
featured=(featured if featured is not None else False),
user_files=user_files or [],
labels=labels or [],
hierarchy_nodes=hierarchy_nodes or [],
@@ -1158,9 +1131,9 @@ def delete_old_default_personas(
db_session.commit()
def update_persona_is_default(
def update_persona_featured(
persona_id: int,
is_default: bool,
featured: bool,
db_session: Session,
user: User,
) -> None:
@@ -1168,7 +1141,7 @@ def update_persona_is_default(
db_session=db_session, persona_id=persona_id, user=user, get_editable=True
)
persona.is_default_persona = is_default
persona.featured = featured
db_session.commit()

View File

@@ -5,8 +5,6 @@ from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.constants import DEFAULT_PERSONA_SLACK_CHANNEL_NAME
from onyx.db.constants import SLACK_BOT_PERSONA_PREFIX
from onyx.db.models import ChannelConfig
@@ -45,8 +43,6 @@ def create_slack_channel_persona(
channel_name: str | None,
document_set_ids: list[int],
existing_persona_id: int | None = None,
num_chunks: float = MAX_CHUNKS_FED_TO_CHAT,
enable_auto_filters: bool = False,
) -> Persona:
"""NOTE: does not commit changes"""
@@ -73,17 +69,13 @@ def create_slack_channel_persona(
system_prompt="",
task_prompt="",
datetime_aware=True,
num_chunks=num_chunks,
llm_relevance_filter=True,
llm_filter_extraction=enable_auto_filters,
recency_bias=RecencyBiasSetting.AUTO,
tool_ids=[search_tool.id],
document_set_ids=document_set_ids,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
is_public=True,
is_default_persona=False,
featured=False,
db_session=db_session,
commit=False,
)

View File

@@ -37,6 +37,7 @@ from onyx.configs.app_configs import APP_HOST
from onyx.configs.app_configs import APP_PORT
from onyx.configs.app_configs import AUTH_RATE_LIMITING_ENABLED
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.app_configs import LOG_ENDPOINT_LATENCY
from onyx.configs.app_configs import OAUTH_CLIENT_ID
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
@@ -254,8 +255,38 @@ def include_auth_router_with_prefix(
)
def validate_no_vector_db_settings() -> None:
"""Validate that DISABLE_VECTOR_DB is not combined with incompatible settings.
Raises RuntimeError if DISABLE_VECTOR_DB is set alongside MULTI_TENANT or ENABLE_CRAFT,
since these modes require infrastructure that is removed in no-vector-DB deployments.
"""
if not DISABLE_VECTOR_DB:
return
if MULTI_TENANT:
raise RuntimeError(
"DISABLE_VECTOR_DB cannot be used with MULTI_TENANT. "
"Multi-tenant deployments require the vector database for "
"per-tenant document indexing and search. Run in single-tenant "
"mode when disabling the vector database."
)
from onyx.server.features.build.configs import ENABLE_CRAFT
if ENABLE_CRAFT:
raise RuntimeError(
"DISABLE_VECTOR_DB cannot be used with ENABLE_CRAFT. "
"Onyx Craft requires background workers for sandbox lifecycle "
"management, which are removed in no-vector-DB deployments. "
"Disable Craft (ENABLE_CRAFT=false) when disabling the vector database."
)
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]: # noqa: ARG001
validate_no_vector_db_settings()
# Set recursion limit
if SYSTEM_RECURSION_LIMIT is not None:
sys.setrecursionlimit(SYSTEM_RECURSION_LIMIT)

View File

@@ -32,7 +32,7 @@ from onyx.db.persona import get_persona_snapshots_for_user
from onyx.db.persona import get_persona_snapshots_paginated
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_persona_is_default
from onyx.db.persona import update_persona_featured
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared
@@ -130,8 +130,8 @@ class IsPublicRequest(BaseModel):
is_public: bool
class IsDefaultRequest(BaseModel):
is_default_persona: bool
class IsFeaturedRequest(BaseModel):
featured: bool
@admin_router.patch("/{persona_id}/visible")
@@ -168,22 +168,22 @@ def patch_user_persona_public_status(
raise HTTPException(status_code=403, detail=str(e))
@admin_router.patch("/{persona_id}/default")
def patch_persona_default_status(
@admin_router.patch("/{persona_id}/featured")
def patch_persona_featured_status(
persona_id: int,
is_default_request: IsDefaultRequest,
is_featured_request: IsFeaturedRequest,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_is_default(
update_persona_featured(
persona_id=persona_id,
is_default=is_default_request.is_default_persona,
featured=is_featured_request.featured,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona default status")
logger.exception("Failed to update persona featured status")
raise HTTPException(status_code=403, detail=str(e))

View File

@@ -5,7 +5,6 @@ from pydantic import BaseModel
from pydantic import Field
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import HierarchyNodeType
from onyx.db.models import Document
from onyx.db.models import HierarchyNode
@@ -108,11 +107,7 @@ class PersonaUpsertRequest(BaseModel):
name: str
description: str
document_set_ids: list[int]
num_chunks: float
is_public: bool
recency_bias: RecencyBiasSetting
llm_filter_extraction: bool
llm_relevance_filter: bool
llm_model_provider_override: str | None = None
llm_model_version_override: str | None = None
starter_messages: list[StarterMessage] | None = None
@@ -128,7 +123,7 @@ class PersonaUpsertRequest(BaseModel):
)
search_start_date: datetime | None = None
label_ids: list[int] | None = None
is_default_persona: bool = False
featured: bool = False
display_priority: int | None = None
# Accept string UUIDs from frontend
user_file_ids: list[str] | None = None
@@ -155,9 +150,6 @@ class MinimalPersonaSnapshot(BaseModel):
tools: list[ToolSnapshot]
starter_messages: list[StarterMessage] | None
llm_relevance_filter: bool
llm_filter_extraction: bool
# only show document sets in the UI that the assistant has access to
document_sets: list[DocumentSetSummary]
# Counts for knowledge sources (used to determine if search tool should be enabled)
@@ -175,7 +167,7 @@ class MinimalPersonaSnapshot(BaseModel):
is_public: bool
is_visible: bool
display_priority: int | None
is_default_persona: bool
featured: bool
builtin_persona: bool
# Used for filtering
@@ -214,8 +206,6 @@ class MinimalPersonaSnapshot(BaseModel):
if should_expose_tool_to_fe(tool)
],
starter_messages=persona.starter_messages,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
document_sets=[
DocumentSetSummary.from_model(document_set)
for document_set in persona.document_sets
@@ -230,7 +220,7 @@ class MinimalPersonaSnapshot(BaseModel):
is_public=persona.is_public,
is_visible=persona.is_visible,
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
labels=[PersonaLabelSnapshot.from_model(label) for label in persona.labels],
owner=(
@@ -252,11 +242,9 @@ class PersonaSnapshot(BaseModel):
# Return string UUIDs to frontend for consistency
user_file_ids: list[str]
display_priority: int | None
is_default_persona: bool
featured: bool
builtin_persona: bool
starter_messages: list[StarterMessage] | None
llm_relevance_filter: bool
llm_filter_extraction: bool
tools: list[ToolSnapshot]
labels: list["PersonaLabelSnapshot"]
owner: MinimalUserSnapshot | None
@@ -265,7 +253,6 @@ class PersonaSnapshot(BaseModel):
document_sets: list[DocumentSetSummary]
llm_model_provider_override: str | None
llm_model_version_override: str | None
num_chunks: float | None
# Hierarchy nodes attached for scoped search
hierarchy_nodes: list[HierarchyNodeSnapshot] = Field(default_factory=list)
# Individual documents attached for scoped search
@@ -289,11 +276,9 @@ class PersonaSnapshot(BaseModel):
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
tools=[
ToolSnapshot.from_model(tool)
for tool in persona.tools
@@ -324,7 +309,6 @@ class PersonaSnapshot(BaseModel):
],
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
num_chunks=persona.num_chunks,
system_prompt=persona.system_prompt,
replace_base_system_prompt=persona.replace_base_system_prompt,
task_prompt=persona.task_prompt,
@@ -332,12 +316,10 @@ class PersonaSnapshot(BaseModel):
)
# Model with full context on perona's internal settings
# Model with full context on persona's internal settings
# This is used for flows which need to know all settings
class FullPersonaSnapshot(PersonaSnapshot):
search_start_date: datetime | None = None
llm_relevance_filter: bool = False
llm_filter_extraction: bool = False
@classmethod
def from_model(
@@ -360,7 +342,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
icon_name=persona.icon_name,
user_file_ids=[str(file.id) for file in persona.user_files],
display_priority=persona.display_priority,
is_default_persona=persona.is_default_persona,
featured=persona.featured,
builtin_persona=persona.builtin_persona,
starter_messages=persona.starter_messages,
users=[
@@ -391,10 +373,7 @@ class FullPersonaSnapshot(PersonaSnapshot):
DocumentSetSummary.from_model(document_set_model)
for document_set_model in persona.document_sets
],
num_chunks=persona.num_chunks,
search_start_date=persona.search_start_date,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
system_prompt=persona.system_prompt,

View File

@@ -335,7 +335,7 @@ def upsert_project_instructions(
class ProjectPayload(BaseModel):
project: UserProjectSnapshot
files: list[UserFileSnapshot] | None = None
persona_id_to_is_default: dict[int, bool] | None = None
persona_id_to_featured: dict[int, bool] | None = None
@router.get(
@@ -354,13 +354,11 @@ def get_project_details(
if session.persona_id is not None
]
personas = get_personas_by_ids(persona_ids, db_session)
persona_id_to_is_default = {
persona.id: persona.is_default_persona for persona in personas
}
persona_id_to_featured = {persona.id: persona.featured for persona in personas}
return ProjectPayload(
project=project,
files=files,
persona_id_to_is_default=persona_id_to_is_default,
persona_id_to_featured=persona_id_to_featured,
)

View File

@@ -5,7 +5,6 @@ from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.configs.constants import TMP_DRALPHA_PERSONA_NAME
from onyx.configs.kg_configs import KG_BETA_ASSISTANT_DESCRIPTION
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session
from onyx.db.entities import get_entity_stats_by_grounded_source_name
from onyx.db.entity_type import get_configured_entity_types
@@ -134,11 +133,7 @@ def enable_or_disable_kg(
system_prompt=KG_BETA_ASSISTANT_SYSTEM_PROMPT,
task_prompt=KG_BETA_ASSISTANT_TASK_PROMPT,
datetime_aware=False,
num_chunks=25,
llm_relevance_filter=False,
is_public=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
document_set_ids=[],
tool_ids=[search_tool.id, kg_tool.id],
llm_model_provider_override=None,
@@ -147,7 +142,7 @@ def enable_or_disable_kg(
users=[user.id],
groups=[],
label_ids=[],
is_default_persona=False,
featured=False,
display_priority=0,
user_file_ids=[],
)

View File

@@ -97,7 +97,6 @@ def _build_llm_provider_request(
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
deployment_name=deployment_name, # From request
is_public=True,
groups=[],
@@ -136,7 +135,6 @@ def _build_llm_provider_request(
api_key=api_key,
api_base=api_base,
api_version=api_version,
default_model_name=model_name,
deployment_name=deployment_name,
is_public=True,
groups=[],
@@ -168,7 +166,6 @@ def _create_image_gen_llm_provider__no_commit(
api_key=provider_request.api_key,
api_base=provider_request.api_base,
api_version=provider_request.api_version,
default_model_name=provider_request.default_model_name,
deployment_name=provider_request.deployment_name,
is_public=provider_request.is_public,
custom_config=provider_request.custom_config,

View File

@@ -22,7 +22,10 @@ from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_default_vision_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_provider_by_id
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_models
from onyx.db.llm import fetch_persona_with_groups
@@ -52,11 +55,12 @@ from onyx.llm.well_known_providers.llm_provider_options import (
)
from onyx.server.manage.llm.models import BedrockFinalModelResponse
from onyx.server.manage.llm.models import BedrockModelsRequest
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMCost
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderResponse
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
@@ -233,12 +237,9 @@ def test_llm_configuration(
test_api_key = test_llm_request.api_key
test_custom_config = test_llm_request.custom_config
if test_llm_request.name:
# NOTE: we are querying by name. we probably should be querying by an invariant id, but
# as it turns out the name is not editable in the UI and other code also keys off name,
# so we won't rock the boat just yet.
existing_provider = fetch_existing_llm_provider(
name=test_llm_request.name, db_session=db_session
if test_llm_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=test_llm_request.id, db_session=db_session
)
if existing_provider:
test_custom_config = _restore_masked_custom_config_values(
@@ -268,7 +269,7 @@ def test_llm_configuration(
llm = get_llm(
provider=test_llm_request.provider,
model=test_llm_request.default_model_name,
model=test_llm_request.model,
api_key=test_api_key,
api_base=test_llm_request.api_base,
api_version=test_llm_request.api_version,
@@ -303,7 +304,7 @@ def list_llm_providers(
include_image_gen: bool = Query(False),
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderView]:
) -> LLMProviderResponse[LLMProviderView]:
start_time = datetime.now(timezone.utc)
logger.debug("Starting to fetch LLM providers")
@@ -328,7 +329,15 @@ def list_llm_providers(
duration = (end_time - start_time).total_seconds()
logger.debug(f"Completed fetching LLM providers in {duration:.2f} seconds")
return llm_provider_list
return LLMProviderResponse[LLMProviderView].from_models(
providers=llm_provider_list,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
@admin_router.put("/provider")
@@ -344,18 +353,44 @@ def put_llm_provider(
# validate request (e.g. if we're intending to create but the name already exists we should throw an error)
# NOTE: may involve duplicate fetching to Postgres, but we're assuming SQLAlchemy is smart enough to cache
# the result
existing_provider = fetch_existing_llm_provider(
existing_provider = None
if llm_provider_upsert_request.id:
existing_provider = fetch_existing_llm_provider_by_id(
id=llm_provider_upsert_request.id, db_session=db_session
)
# Check name constraints
# TODO: Once port from name to id is complete, unique name will no longer be required
if existing_provider and llm_provider_upsert_request.name != existing_provider.name:
raise HTTPException(
status_code=400,
detail="Renaming providers is not currently supported",
)
found_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
)
if found_provider is not None and found_provider is not existing_provider:
raise HTTPException(
status_code=400,
detail=f"Provider with name={llm_provider_upsert_request.name} already exists",
)
if existing_provider and is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} already exists",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} already exists"
),
)
elif not existing_provider and not is_creation:
raise HTTPException(
status_code=400,
detail=f"LLM Provider with name {llm_provider_upsert_request.name} does not exist",
detail=(
f"LLM Provider with name {llm_provider_upsert_request.name} and "
f"id={llm_provider_upsert_request.id} does not exist"
),
)
# SSRF Protection: Validate api_base and custom_config match stored values
@@ -393,22 +428,6 @@ def put_llm_provider(
deduplicated_personas.append(persona_id)
llm_provider_upsert_request.personas = deduplicated_personas
default_model_found = False
for model_configuration in llm_provider_upsert_request.model_configurations:
if model_configuration.name == llm_provider_upsert_request.default_model_name:
model_configuration.is_visible = True
default_model_found = True
# TODO: Remove this logic on api change
# Believed to be a dead pathway but we want to be safe for now
if not default_model_found:
llm_provider_upsert_request.model_configurations.append(
ModelConfigurationUpsertRequest(
name=llm_provider_upsert_request.default_model_name, is_visible=True
)
)
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
@@ -438,8 +457,8 @@ def put_llm_provider(
config = fetch_llm_recommendations_from_github()
if config and llm_provider_upsert_request.provider in config.providers:
# Refetch the provider to get the updated model
updated_provider = fetch_existing_llm_provider(
name=llm_provider_upsert_request.name, db_session=db_session
updated_provider = fetch_existing_llm_provider_by_id(
id=result.id, db_session=db_session
)
if updated_provider:
sync_auto_mode_models(
@@ -469,28 +488,29 @@ def delete_llm_provider(
raise HTTPException(status_code=404, detail=str(e))
@admin_router.post("/provider/{provider_id}/default")
@admin_router.post("/default")
def set_provider_as_default(
provider_id: int,
default_model_request: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_provider(provider_id=provider_id, db_session=db_session)
update_default_provider(
provider_id=default_model_request.provider_id,
model_name=default_model_request.model_name,
db_session=db_session,
)
@admin_router.post("/provider/{provider_id}/default-vision")
@admin_router.post("/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str | None = Query(
None, description="The default vision model to use"
),
default_model: DefaultModel,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
if vision_model is None:
raise HTTPException(status_code=404, detail="Vision model not provided")
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
provider_id=default_model.provider_id,
vision_model=default_model.model_name,
db_session=db_session,
)
@@ -516,7 +536,7 @@ def get_auto_config(
def get_vision_capable_providers(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
) -> LLMProviderResponse[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
vision_models = fetch_existing_models(
db_session=db_session, flow_types=[LLMModelFlowType.VISION]
@@ -545,7 +565,13 @@ def get_vision_capable_providers(
]
logger.debug(f"Found {len(vision_provider_response)} vision-capable providers")
return vision_provider_response
return LLMProviderResponse[VisionProviderResponse].from_models(
providers=vision_provider_response,
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
"""Endpoints for all"""
@@ -555,7 +581,7 @@ def get_vision_capable_providers(
def list_llm_provider_basics(
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers accessible to the current user.
Returns:
@@ -592,7 +618,15 @@ def list_llm_provider_basics(
f"Completed fetching {len(accessible_providers)} user-accessible providers in {duration:.2f} seconds"
)
return accessible_providers
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=accessible_providers,
default_text=DefaultModel.from_model_config(
fetch_default_llm_model(db_session)
),
default_vision=DefaultModel.from_model_config(
fetch_default_vision_model(db_session)
),
)
def get_valid_model_names_for_persona(
@@ -635,7 +669,7 @@ def list_llm_providers_for_persona(
persona_id: int,
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
) -> LLMProviderResponse[LLMProviderDescriptor]:
"""Get LLM providers for a specific persona.
Returns providers that the user can access when using this persona:
@@ -682,7 +716,51 @@ def list_llm_providers_for_persona(
f"Completed fetching {len(llm_provider_list)} LLM providers for persona {persona_id} in {duration:.2f} seconds"
)
return llm_provider_list
# Get the default model and vision model for the persona
# TODO: Port persona's over to use ID
persona_default_provider = persona.llm_model_provider_override
persona_default_model = persona.llm_model_version_override
default_text_model = fetch_default_llm_model(db_session)
default_vision_model = fetch_default_vision_model(db_session)
# Build default_text and default_vision using persona overrides when available,
# falling back to the global defaults.
default_text = DefaultModel.from_model_config(default_text_model)
default_vision = DefaultModel.from_model_config(default_vision_model)
if persona_default_provider:
provider = fetch_existing_llm_provider(persona_default_provider, db_session)
if provider and can_user_access_llm_provider(
provider, user_group_ids, persona, is_admin=is_admin
):
if persona_default_model:
# Persona specifies both provider and model — use them directly
default_text = DefaultModel(
provider_id=provider.id,
model_name=persona_default_model,
)
else:
# Persona specifies only the provider — pick a visible (public) model,
# falling back to any model on this provider
visible_model = next(
(mc for mc in provider.model_configurations if mc.is_visible),
None,
)
fallback_model = visible_model or next(
iter(provider.model_configurations), None
)
if fallback_model:
default_text = DefaultModel(
provider_id=provider.id,
model_name=fallback_model.name,
)
return LLMProviderResponse[LLMProviderDescriptor].from_models(
providers=llm_provider_list,
default_text=default_text,
default_vision=default_vision,
)
@admin_router.get("/provider-contextual-cost")

View File

@@ -1,5 +1,9 @@
from __future__ import annotations
from typing import Any
from typing import Generic
from typing import TYPE_CHECKING
from typing import TypeVar
from pydantic import BaseModel
from pydantic import Field
@@ -21,50 +25,22 @@ if TYPE_CHECKING:
ModelConfiguration as ModelConfigurationModel,
)
# TODO: Clear this up on api refactor
# There is still logic that requires sending each providers default model name
# There is no logic that requires sending the providers default vision model name
# We only send for the one that is actually the default
def get_default_llm_model_name(llm_provider_model: "LLMProviderModel") -> str:
"""Find the default conversation model name for a provider.
Returns the model name if found, otherwise returns empty string.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.CHAT:
return model_config.name
return ""
def get_default_vision_model_name(llm_provider_model: "LLMProviderModel") -> str | None:
"""Find the default vision model name for a provider.
Returns the model name if found, otherwise returns None.
"""
for model_config in llm_provider_model.model_configurations:
for flow in model_config.llm_model_flows:
if flow.is_default and flow.llm_model_flow_type == LLMModelFlowType.VISION:
return model_config.name
return None
T = TypeVar("T", "LLMProviderDescriptor", "LLMProviderView", "VisionProviderResponse")
class TestLLMRequest(BaseModel):
# provider level
name: str | None = None
id: int | None = None
provider: str
model: str
api_key: str | None = None
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
# model level
default_model_name: str
deployment_name: str | None = None
model_configurations: list["ModelConfigurationUpsertRequest"]
# if try and use the existing API/custom config key
api_key_changed: bool
custom_config_changed: bool
@@ -80,13 +56,10 @@ class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
non-admin users. Used when giving a list of available LLMs."""
id: int
name: str
provider: str
provider_display_name: str # Human-friendly name like "Claude (Anthropic)"
default_model_name: str
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -99,24 +72,12 @@ class LLMProviderDescriptor(BaseModel):
)
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
provider_display_name=get_provider_display_name(provider),
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
model_configurations=filter_model_configurations(
llm_provider_model.model_configurations, provider
),
@@ -130,18 +91,17 @@ class LLMProvider(BaseModel):
api_base: str | None = None
api_version: str | None = None
custom_config: dict[str, str] | None = None
default_model_name: str
is_public: bool = True
is_auto_mode: bool = False
groups: list[int] = Field(default_factory=list)
personas: list[int] = Field(default_factory=list)
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
# should only be used for a "custom" provider
# for default providers, the built-in model names are used
id: int | None = None
api_key_changed: bool = False
custom_config_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@@ -157,8 +117,6 @@ class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_configurations: list["ModelConfigurationView"]
@classmethod
@@ -180,16 +138,6 @@ class LLMProviderView(LLMProvider):
provider = llm_provider_model.provider
default_model_name = get_default_llm_model_name(llm_provider_model)
default_vision_model = get_default_vision_model_name(llm_provider_model)
is_default_provider = bool(default_model_name)
is_default_vision_provider = default_vision_model is not None
default_model_name = (
default_model_name or llm_provider_model.default_model_name or ""
)
return cls(
id=llm_provider_model.id,
name=llm_provider_model.name,
@@ -202,10 +150,6 @@ class LLMProviderView(LLMProvider):
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,
default_model_name=default_model_name,
is_default_provider=is_default_provider,
is_default_vision_provider=is_default_vision_provider,
default_vision_model=default_vision_model,
is_public=llm_provider_model.is_public,
is_auto_mode=llm_provider_model.is_auto_mode,
groups=groups,
@@ -425,3 +369,38 @@ class OpenRouterFinalModelResponse(BaseModel):
int | None
) # From OpenRouter API context_length (may be missing for some models)
supports_image_input: bool
class DefaultModel(BaseModel):
provider_id: int
model_name: str
@classmethod
def from_model_config(
cls, model_config: ModelConfigurationModel | None
) -> DefaultModel | None:
if not model_config:
return None
return cls(
provider_id=model_config.llm_provider_id,
model_name=model_config.name,
)
class LLMProviderResponse(BaseModel, Generic[T]):
    """Response wrapper pairing a list of provider payloads with the current
    default text and vision model selections (either may be unset)."""

    providers: list[T]
    default_text: DefaultModel | None = None
    default_vision: DefaultModel | None = None

    @classmethod
    def from_models(
        cls,
        providers: list[T],
        default_text: DefaultModel | None = None,
        default_vision: DefaultModel | None = None,
    ) -> LLMProviderResponse[T]:
        """Convenience constructor mirroring the model's fields one-to-one."""
        response = cls(
            providers=providers,
            default_text=default_text,
            default_vision=default_vision,
        )
        return response

View File

@@ -198,7 +198,6 @@ def patch_slack_channel_config(
channel_name=channel_config["channel_name"],
document_set_ids=slack_channel_config_creation_request.document_sets,
existing_persona_id=existing_persona_id,
enable_auto_filters=slack_channel_config_creation_request.enable_auto_filters,
).id
slack_channel_config_model = update_slack_channel_config(

View File

@@ -19,6 +19,7 @@ class ApplicationStatus(str, Enum):
PAYMENT_REMINDER = "payment_reminder"
GRACE_PERIOD = "grace_period"
GATED_ACCESS = "gated_access"
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded"
class Notification(BaseModel):
@@ -82,6 +83,10 @@ class Settings(BaseModel):
# Default Assistant settings
disable_default_assistant: bool | None = False
# Seat usage - populated by license enforcement when seat limit is exceeded
seat_count: int | None = None
used_seats: int | None = None
# OpenSearch migration
opensearch_indexing_enabled: bool = False

View File

@@ -25,6 +25,7 @@ from onyx.db.enums import EmbeddingPrecision
from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import expire_index_attempts
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.search_settings import get_active_search_settings
@@ -254,14 +255,18 @@ def setup_postgres(db_session: Session) -> None:
logger.notice("Setting up default OpenAI LLM for dev.")
llm_model = GEN_AI_MODEL_VERSION or "gpt-4o-mini"
provider_name = "DevEnvPresetOpenAI"
existing = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
model_req = LLMProviderUpsertRequest(
name="DevEnvPresetOpenAI",
id=existing.id if existing else None,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=GEN_AI_API_KEY,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=llm_model,
is_public=True,
groups=[],
model_configurations=[
@@ -273,7 +278,9 @@ def setup_postgres(db_session: Session) -> None:
new_llm_provider = upsert_llm_provider(
llm_provider_upsert_request=model_req, db_session=db_session
)
update_default_provider(provider_id=new_llm_provider.id, db_session=db_session)
update_default_provider(
provider_id=new_llm_provider.id, model_name=llm_model, db_session=db_session
)
def update_default_multipass_indexing(db_session: Session) -> None:

View File

@@ -17,7 +17,7 @@ def test_bedrock_llm_configuration(client: TestClient) -> None:
# Prepare the test request payload
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,
@@ -44,7 +44,7 @@ def test_bedrock_llm_configuration_invalid_key(client: TestClient) -> None:
# Prepare the test request payload with invalid credentials
test_request: dict[str, Any] = {
"provider": LlmProviderNames.BEDROCK,
"default_model_name": _DEFAULT_BEDROCK_MODEL,
"model": _DEFAULT_BEDROCK_MODEL,
"api_key": None,
"api_base": None,
"api_version": None,

View File

@@ -28,7 +28,6 @@ def ensure_default_llm_provider(db_session: Session) -> None:
provider=LlmProviderNames.OPENAI,
api_key=os.environ.get("OPENAI_API_KEY", "test"),
is_public=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini",
@@ -41,7 +40,7 @@ def ensure_default_llm_provider(db_session: Session) -> None:
llm_provider_upsert_request=llm_provider_request,
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()

View File

@@ -47,7 +47,6 @@ def test_answer_with_only_anthropic_provider(
name=provider_name,
provider=LlmProviderNames.ANTHROPIC,
api_key=anthropic_api_key,
default_model_name=anthropic_model,
is_public=True,
groups=[],
model_configurations=[
@@ -59,7 +58,7 @@ def test_answer_with_only_anthropic_provider(
)
try:
update_default_provider(anthropic_provider.id, db_session)
update_default_provider(anthropic_provider.id, anthropic_model, db_session)
test_user = create_test_user(db_session, email_prefix="anthropic_only")
chat_session = create_chat_session(

View File

@@ -9,7 +9,6 @@ from onyx.chat.models import AnswerStreamPart
from onyx.chat.models import StreamingError
from onyx.chat.process_message import handle_stream_message_objects
from onyx.db.chat import create_chat_session
from onyx.db.models import RecencyBiasSetting
from onyx.db.models import User
from onyx.db.persona import upsert_persona
from onyx.server.query_and_chat.models import MessageResponseIDInfo
@@ -74,10 +73,6 @@ def test_stream_chat_message_objects_without_web_search(
user=None, # System persona
name=f"Test Persona {uuid.uuid4()}",
description="Test persona with no tools for web search test",
num_chunks=10.0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.BASE_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -36,7 +36,6 @@ from onyx.background.celery.tasks.user_file_processing.tasks import (
from onyx.background.celery.tasks.user_file_processing.tasks import (
user_file_project_sync_lock_key,
)
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
@@ -86,12 +85,6 @@ def _create_test_persona(
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a test assistant",
task_prompt="Answer the question",
tools=[],
@@ -410,10 +403,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -442,10 +431,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -461,16 +446,11 @@ class TestUpsertPersonaMarksSyncFlag:
uf_old.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
# Now update the persona to swap files
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -501,10 +481,6 @@ class TestUpsertPersonaMarksSyncFlag:
user=user,
name=f"persona-{uuid4().hex[:8]}",
description="test",
num_chunks=10.0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,
@@ -519,15 +495,10 @@ class TestUpsertPersonaMarksSyncFlag:
uf.needs_persona_sync = False
db_session.commit()
assert persona.num_chunks is not None
upsert_persona(
user=user,
name=persona.name,
description=persona.description,
num_chunks=persona.num_chunks,
llm_relevance_filter=persona.llm_relevance_filter,
llm_filter_extraction=persona.llm_filter_extraction,
recency_bias=persona.recency_bias,
llm_model_provider_override=None,
llm_model_version_override=None,
starter_messages=None,

View File

@@ -18,7 +18,6 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import UserFileStatus
from onyx.db.models import Persona
from onyx.db.models import Persona__UserFile
@@ -58,12 +57,6 @@ def _create_persona(db_session: Session, user: User) -> Persona:
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="test",
task_prompt="test",
tools=[],

View File

@@ -29,6 +29,7 @@ from onyx.server.manage.llm.api import (
test_llm_configuration as run_test_llm_configuration,
)
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest as LLMTestRequest
@@ -44,15 +45,14 @@ def _create_test_provider(
db_session: Session,
name: str,
api_key: str = "sk-test-key-00000000000000000000000000000000000",
) -> None:
) -> LLMProviderView:
"""Helper to create a test LLM provider in the database."""
upsert_llm_provider(
return upsert_llm_provider(
LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=api_key,
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -102,17 +102,11 @@ class TestLLMConfigurationEndpoint:
# This should complete without exception
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None, # New provider (not in DB)
provider=LlmProviderNames.OPENAI,
api_key="sk-new-test-key-0000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -152,17 +146,11 @@ class TestLLMConfigurationEndpoint:
with pytest.raises(HTTPException) as exc_info:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-invalid-key-00000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -194,7 +182,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -202,17 +192,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=False - should use stored key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not providing a new key
api_key_changed=False, # Using existing key
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -246,7 +231,9 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database
_create_test_provider(db_session, provider_name, api_key=original_api_key)
provider = _create_test_provider(
db_session, provider_name, api_key=original_api_key
)
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -254,17 +241,12 @@ class TestLLMConfigurationEndpoint:
# Test with api_key_changed=True - should use new key
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name, # Existing provider
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=new_api_key, # Providing a new key
api_key_changed=True, # Key is being changed
custom_config_changed=False,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -297,7 +279,7 @@ class TestLLMConfigurationEndpoint:
try:
# First, create the provider in the database with custom_config
upsert_llm_provider(
provider = upsert_llm_provider(
LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
@@ -305,7 +287,6 @@ class TestLLMConfigurationEndpoint:
api_key_changed=True,
custom_config=original_custom_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -321,18 +302,13 @@ class TestLLMConfigurationEndpoint:
# Test with custom_config_changed=False - should use stored config
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=provider_name,
id=provider.id,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
custom_config=None, # Not providing new config
custom_config_changed=False, # Using existing config
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
)
],
model="gpt-4o-mini",
),
_=_create_mock_admin(),
db_session=db_session,
@@ -368,17 +344,11 @@ class TestLLMConfigurationEndpoint:
for model_name in test_models:
run_test_llm_configuration(
test_llm_request=LLMTestRequest(
name=None,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=False,
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name, is_visible=True
)
],
model=model_name,
),
_=_create_mock_admin(),
db_session=db_session,
@@ -442,7 +412,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_initial_model,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -452,7 +421,7 @@ class TestDefaultProviderEndpoint:
)
# Set provider 1 as the default provider explicitly
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_initial_model, db_session)
# Step 2: Call run_test_default_provider - should use provider 1's default model
with patch(
@@ -472,7 +441,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key=provider_2_api_key,
api_key_changed=True,
default_model_name=provider_2_default_model,
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -499,11 +467,11 @@ class TestDefaultProviderEndpoint:
# Step 5: Update provider 1's default model
upsert_llm_provider(
LLMProviderUpsertRequest(
id=provider_1.id,
name=provider_1_name,
provider=LlmProviderNames.OPENAI,
api_key=provider_1_api_key,
api_key_changed=True,
default_model_name=provider_1_updated_model, # Changed
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4", is_visible=True),
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
@@ -512,6 +480,9 @@ class TestDefaultProviderEndpoint:
db_session=db_session,
)
# Set provider 1's default model to the updated model
update_default_provider(provider_1.id, provider_1_updated_model, db_session)
# Step 6: Call run_test_default_provider - should use new model on provider 1
with patch(
"onyx.server.manage.llm.api.test_llm", side_effect=mock_test_llm_capture
@@ -524,7 +495,7 @@ class TestDefaultProviderEndpoint:
captured_llms.clear()
# Step 7: Change the default provider to provider 2
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 8: Call run_test_default_provider - should use provider 2
with patch(
@@ -596,7 +567,6 @@ class TestDefaultProviderEndpoint:
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
@@ -605,7 +575,7 @@ class TestDefaultProviderEndpoint:
),
db_session=db_session,
)
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, "gpt-4o-mini", db_session)
# Test should fail
with patch(

View File

@@ -49,7 +49,6 @@ def _create_test_provider(
api_key_changed=True,
api_base=api_base,
custom_config=custom_config,
default_model_name="gpt-4o-mini",
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o-mini", is_visible=True)
],
@@ -91,14 +90,14 @@ class TestLLMProviderChanges:
the API key should be blocked.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://attacker.example.com",
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -125,16 +124,16 @@ class TestLLMProviderChanges:
Changing api_base IS allowed when the API key is also being changed.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom-endpoint.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -159,14 +158,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=original_api_base,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -190,14 +191,14 @@ class TestLLMProviderChanges:
changes. This allows model-only updates when provider has no custom base URL.
"""
try:
_create_test_provider(db_session, provider_name, api_base=None)
view = _create_test_provider(db_session, provider_name, api_base=None)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=view.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -223,14 +224,16 @@ class TestLLMProviderChanges:
original_api_base = "https://original.example.com/v1"
try:
_create_test_provider(db_session, provider_name, api_base=original_api_base)
provider = _create_test_provider(
db_session, provider_name, api_base=original_api_base
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base=None,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -259,14 +262,14 @@ class TestLLMProviderChanges:
users have full control over their deployment.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -297,7 +300,6 @@ class TestLLMProviderChanges:
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
api_base="https://custom.example.com/v1",
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -322,7 +324,7 @@ class TestLLMProviderChanges:
redirect LLM API requests).
"""
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"SOME_CONFIG": "original_value"},
@@ -330,11 +332,11 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -362,15 +364,15 @@ class TestLLMProviderChanges:
without changing the API key.
"""
try:
_create_test_provider(db_session, provider_name)
provider = _create_test_provider(db_session, provider_name)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config={"OPENAI_API_BASE": "https://attacker.example.com"},
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
with pytest.raises(HTTPException) as exc_info:
@@ -399,7 +401,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "us-west-2"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -407,13 +409,13 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-new-key-00000000000000000000000000000000000",
api_key_changed=True,
custom_config_changed=True,
custom_config=new_config,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -438,17 +440,17 @@ class TestLLMProviderChanges:
original_config = {"AWS_REGION_NAME": "us-east-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session, provider_name, custom_config=original_config
)
with patch("onyx.server.manage.llm.api.MULTI_TENANT", True):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=original_config,
custom_config_changed=True,
default_model_name="gpt-4o-mini",
)
result = put_llm_provider(
@@ -474,7 +476,7 @@ class TestLLMProviderChanges:
new_config = {"AWS_REGION_NAME": "eu-west-1"}
try:
_create_test_provider(
provider = _create_test_provider(
db_session,
provider_name,
custom_config={"AWS_REGION_NAME": "us-east-1"},
@@ -482,10 +484,10 @@ class TestLLMProviderChanges:
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
update_request = LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
custom_config=new_config,
default_model_name="gpt-4o-mini",
custom_config_changed=True,
)
@@ -530,14 +532,8 @@ def test_upload_with_custom_config_then_change(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config=custom_config,
@@ -546,11 +542,10 @@ def test_upload_with_custom_config_then_change(
db_session=db_session,
)
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider_name,
default_model_name=default_model_name,
custom_config=custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -569,14 +564,9 @@ def test_upload_with_custom_config_then_change(
# Turn auto mode off
run_llm_config_test(
LLMTestRequest(
name=name,
id=provider.id,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
model=default_model_name,
api_key_changed=False,
custom_config_changed=False,
),
@@ -586,9 +576,9 @@ def test_upload_with_custom_config_then_change(
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=name,
provider=provider_name,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
@@ -616,13 +606,13 @@ def test_upload_with_custom_config_then_change(
)
# Check inside the database and check that custom_config is the same as the original
provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not provider:
db_provider = fetch_existing_llm_provider(name=name, db_session=db_session)
if not db_provider:
assert False, "Provider not found in the database"
assert provider.custom_config == custom_config, (
assert db_provider.custom_config == custom_config, (
f"Expected custom_config {custom_config}, "
f"but got {provider.custom_config}"
f"but got {db_provider.custom_config}"
)
finally:
db_session.rollback()
@@ -642,11 +632,10 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
}
try:
put_llm_provider(
view = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -665,9 +654,9 @@ def test_preserves_masked_sensitive_custom_config_on_provider_update(
with patch("onyx.server.manage.llm.api.MULTI_TENANT", False):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=view.id,
name=name,
provider=provider,
default_model_name=default_model_name,
custom_config={
"vertex_credentials": _mask_string(
original_custom_config["vertex_credentials"]
@@ -706,7 +695,7 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
) -> None:
"""LLM test should restore masked sensitive custom config values before invocation."""
name = f"test-provider-vertex-test-{uuid4().hex[:8]}"
provider = LlmProviderNames.VERTEX_AI.value
provider_name = LlmProviderNames.VERTEX_AI.value
default_model_name = "gemini-2.5-pro"
original_custom_config = {
"vertex_credentials": '{"type":"service_account","private_key":"REAL_PRIVATE_KEY"}',
@@ -719,11 +708,10 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
return ""
try:
put_llm_provider(
provider = put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
provider=provider_name,
custom_config=original_custom_config,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -742,14 +730,9 @@ def test_preserves_masked_sensitive_custom_config_on_test_request(
with patch("onyx.server.manage.llm.api.test_llm", side_effect=capture_test_llm):
run_llm_config_test(
LLMTestRequest(
name=name,
provider=provider,
default_model_name=default_model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name, is_visible=True
)
],
id=provider.id,
provider=provider_name,
model=default_model_name,
api_key_changed=False,
custom_config_changed=True,
custom_config={

View File

@@ -15,9 +15,11 @@ import pytest
from sqlalchemy.orm import Session
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import fetch_auto_mode_providers
from onyx.db.llm import fetch_default_llm_model
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_llm_provider_view
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import sync_auto_mode_models
from onyx.db.llm import update_default_provider
@@ -135,7 +137,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name=expected_default_model,
model_configurations=[], # No model configs provided
),
is_creation=True,
@@ -163,13 +164,8 @@ class TestAutoModeSyncFeature:
if mc.name in all_expected_models:
assert mc.is_visible is True, f"Model '{mc.name}' should be visible"
# Verify the default model was set correctly
assert (
provider.default_model_name == expected_default_model
), f"Default model should be '{expected_default_model}'"
# Step 4: Set the provider as default
update_default_provider(provider.id, db_session)
update_default_provider(provider.id, expected_default_model, db_session)
# Step 5: Fetch the default provider and verify
default_model = fetch_default_llm_model(db_session)
@@ -238,7 +234,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -317,7 +312,6 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=False, # Not in auto mode initially
default_model_name="gpt-4",
model_configurations=initial_models,
),
is_creation=True,
@@ -326,13 +320,13 @@ class TestAutoModeSyncFeature:
)
# Verify initial state: all models are visible
provider = fetch_existing_llm_provider(
initial_provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
assert initial_provider is not None
assert initial_provider.is_auto_mode is False
for mc in provider.model_configurations:
for mc in initial_provider.model_configurations:
assert (
mc.is_visible is True
), f"Initial model '{mc.name}' should be visible"
@@ -344,12 +338,12 @@ class TestAutoModeSyncFeature:
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=initial_provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None, # Not changing API key
api_key_changed=False,
is_auto_mode=True, # Now enabling auto mode
default_model_name=auto_mode_default,
model_configurations=[], # Auto mode will sync from config
),
is_creation=False, # This is an update
@@ -360,15 +354,15 @@ class TestAutoModeSyncFeature:
# Step 3: Verify model visibility after auto mode transition
# Expire session cache to force fresh fetch after sync_auto_mode_models committed
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
provider_view = fetch_llm_provider_view(
provider_name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is True
assert provider_view is not None
assert provider_view.is_auto_mode is True
# Build a map of model name -> visibility
model_visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
mc.name: mc.is_visible for mc in provider_view.model_configurations
}
# Models in auto mode config should be visible
@@ -388,9 +382,6 @@ class TestAutoModeSyncFeature:
model_visibility[model_name] is False
), f"Model '{model_name}' not in auto config should NOT be visible"
# Verify the default model was updated
assert provider.default_model_name == auto_mode_default
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
@@ -432,8 +423,12 @@ class TestAutoModeSyncFeature:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
model_configurations=[
ModelConfigurationUpsertRequest(
name="gpt-4o",
is_visible=True,
)
],
),
is_creation=True,
_=_create_mock_admin(),
@@ -535,7 +530,6 @@ class TestAutoModeSyncFeature:
api_key=provider_1_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_1_default_model,
model_configurations=[],
),
is_creation=True,
@@ -549,7 +543,7 @@ class TestAutoModeSyncFeature:
name=provider_1_name, db_session=db_session
)
assert provider_1 is not None
update_default_provider(provider_1.id, db_session)
update_default_provider(provider_1.id, provider_1_default_model, db_session)
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
@@ -563,7 +557,6 @@ class TestAutoModeSyncFeature:
api_key=provider_2_api_key,
api_key_changed=True,
is_auto_mode=True,
default_model_name=provider_2_default_model,
model_configurations=[],
),
is_creation=True,
@@ -584,7 +577,7 @@ class TestAutoModeSyncFeature:
name=provider_2_name, db_session=db_session
)
assert provider_2 is not None
update_default_provider(provider_2.id, db_session)
update_default_provider(provider_2.id, provider_2_default_model, db_session)
# Step 5: Verify provider 2 is now the default
db_session.expire_all()
@@ -644,7 +637,6 @@ class TestAutoModeMissingFlows:
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
default_model_name="gpt-4o",
model_configurations=[],
),
is_creation=True,
@@ -701,3 +693,364 @@ class TestAutoModeMissingFlows:
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
class TestAutoModeTransitionsAndResync:
"""Tests for auto/manual transitions, config evolution, and sync idempotency."""
def test_auto_to_manual_mode_preserves_models_and_stops_syncing(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Disabling auto mode should preserve the current model list and
prevent future syncs from altering visibility.
Steps:
1. Create provider in auto mode — models synced from config.
2. Update provider to manual mode (is_auto_mode=False).
3. Verify all models remain with unchanged visibility.
4. Call sync_auto_mode_models with a *different* config.
5. Verify fetch_auto_mode_providers excludes this provider, so the
periodic task would never call sync on it.
"""
initial_config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
try:
# Step 1: Create in auto mode
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=initial_config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility_before = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_before == {"gpt-4o": True, "gpt-4o-mini": True}
# Step 2: Switch to manual mode
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
id=provider.id,
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_key_changed=False,
is_auto_mode=False,
model_configurations=[
ModelConfigurationUpsertRequest(name="gpt-4o", is_visible=True),
ModelConfigurationUpsertRequest(
name="gpt-4o-mini", is_visible=True
),
],
),
is_creation=False,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 3: Models unchanged
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
assert provider.is_auto_mode is False
visibility_after = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility_after == visibility_before
# Step 4-5: Provider excluded from auto mode queries
auto_providers = fetch_auto_mode_providers(db_session)
auto_provider_ids = {p.id for p in auto_providers}
assert provider.id not in auto_provider_ids
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_resync_adds_new_and_hides_removed_models(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the GitHub config changes between syncs, a subsequent sync
should add newly listed models and hide models that were removed.
Steps:
1. Create provider in auto mode with config v1: [gpt-4o, gpt-4o-mini].
2. Sync with config v2: [gpt-4o, gpt-4-turbo] (gpt-4o-mini removed,
gpt-4-turbo added).
3. Verify gpt-4o still visible, gpt-4o-mini hidden, gpt-4-turbo added
and visible.
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4-turbo"],
)
try:
# Step 1: Create with config v1
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Re-sync with config v2
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 3: Verify
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
# gpt-4o: still in config -> visible
assert visibility["gpt-4o"] is True
# gpt-4o-mini: removed from config -> hidden (not deleted)
assert "gpt-4o-mini" in visibility, "Removed model should still exist in DB"
assert visibility["gpt-4o-mini"] is False
# gpt-4-turbo: newly added -> visible
assert visibility["gpt-4-turbo"] is True
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_sync_is_idempotent(
self,
db_session: Session,
provider_name: str,
) -> None:
"""Running sync twice with the same config should produce zero
changes on the second call."""
config = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini", "gpt-4-turbo"],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
# First explicit sync (may report changes if creation already synced)
sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
# Snapshot state after first sync
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
snapshot = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
# Second sync — should be a no-op
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config,
)
assert (
changes == 0
), f"Expected 0 changes on idempotent re-sync, got {changes}"
# State should be identical
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
current = {
mc.name: (mc.is_visible, mc.display_name)
for mc in provider.model_configurations
}
assert current == snapshot
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)
def test_default_model_hidden_when_removed_from_config(
self,
db_session: Session,
provider_name: str,
) -> None:
"""When the current default model is removed from the config, sync
should hide it. The default model flow row should still exist (it
points at the ModelConfiguration), but the model is no longer visible.
Steps:
1. Create provider with config: default=gpt-4o, additional=[gpt-4o-mini].
2. Set gpt-4o as the global default.
3. Re-sync with config: default=gpt-4o-mini (gpt-4o removed entirely).
4. Verify gpt-4o is hidden, gpt-4o-mini is visible, and
fetch_default_llm_model still returns a result (the flow row persists).
"""
config_v1 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o",
additional_models=["gpt-4o-mini"],
)
config_v2 = _create_mock_llm_recommendations(
provider=LlmProviderNames.OPENAI,
default_model_name="gpt-4o-mini",
additional_models=[],
)
try:
with patch(
"onyx.server.manage.llm.api.fetch_llm_recommendations_from_github",
return_value=config_v1,
):
put_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=provider_name,
provider=LlmProviderNames.OPENAI,
api_key="sk-test-key-00000000000000000000000000000000000",
api_key_changed=True,
is_auto_mode=True,
model_configurations=[],
),
is_creation=True,
_=_create_mock_admin(),
db_session=db_session,
)
# Step 2: Set gpt-4o as global default
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
update_default_provider(provider.id, "gpt-4o", db_session)
default_before = fetch_default_llm_model(db_session)
assert default_before is not None
assert default_before.name == "gpt-4o"
# Step 3: Re-sync with config v2 (gpt-4o removed)
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
changes = sync_auto_mode_models(
db_session=db_session,
provider=provider,
llm_recommendations=config_v2,
)
assert changes > 0
# Step 4: Verify visibility
db_session.expire_all()
provider = fetch_existing_llm_provider(
name=provider_name, db_session=db_session
)
assert provider is not None
visibility = {
mc.name: mc.is_visible for mc in provider.model_configurations
}
assert visibility["gpt-4o"] is False, "Removed default should be hidden"
assert visibility["gpt-4o-mini"] is True, "New default should be visible"
# The LLMModelFlow row for gpt-4o still exists (is_default=True),
# but the model is hidden. fetch_default_llm_model filters on
# is_visible=True, so it should NOT return gpt-4o.
db_session.expire_all()
default_after = fetch_default_llm_model(db_session)
assert (
default_after is None or default_after.name != "gpt-4o"
), "Hidden model should not be returned as the default"
finally:
db_session.rollback()
_cleanup_provider(db_session, provider_name)

View File

@@ -64,7 +64,6 @@ def _create_provider(
name=name,
provider=provider,
api_key="sk-ant-api03-...",
default_model_name="claude-3-5-sonnet-20240620",
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -154,7 +153,9 @@ def test_user_sends_message_to_private_provider(
)
_create_provider(db_session, LlmProviderNames.GOOGLE, "private-provider", False)
update_default_provider(public_provider_id, db_session)
update_default_provider(
public_provider_id, "claude-3-5-sonnet-20240620", db_session
)
try:
# Create chat session

View File

@@ -42,7 +42,6 @@ def _create_llm_provider_and_model(
name=provider_name,
provider="openai",
api_key="test-api-key",
default_model_name=model_name,
model_configurations=[
ModelConfigurationUpsertRequest(
name=model_name,

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
from onyx.configs.constants import FederatedConnectorSource
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.federated.slack_search import fetch_and_cache_channel_metadata
from onyx.db.models import DocumentSet
from onyx.db.models import FederatedConnector
@@ -55,11 +54,6 @@ def _create_test_persona_with_slack_config(db_session: Session) -> Persona | Non
persona = Persona(
name=f"test_slack_persona_{unique_id}",
description="Test persona for Slack federated search",
chunks_above=0,
chunks_below=0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question based on the provided context.",
)
@@ -434,7 +428,6 @@ class TestSlackBotFederatedSearch:
name=f"test-llm-provider-{uuid4().hex[:8]}",
provider=LlmProviderNames.OPENAI,
api_key=api_key,
default_model_name="gpt-4o",
is_public=True,
model_configurations=[
ModelConfigurationUpsertRequest(
@@ -448,7 +441,7 @@ class TestSlackBotFederatedSearch:
db_session=db_session,
)
update_default_provider(provider_view.id, db_session)
update_default_provider(provider_view.id, "gpt-4o", db_session)
def _teardown_common_mocks(self, patches: list) -> None:
"""Stop all patches"""
@@ -819,11 +812,6 @@ def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
persona = Persona(
name=f"test_eager_load_persona_{unique_id}",
description="Test persona for eager loading test",
chunks_above=0,
chunks_below=0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question.",
)

View File

@@ -21,7 +21,6 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.enums import MCPTransport
@@ -47,12 +46,6 @@ def _create_test_persona_with_mcp_tool(
persona = Persona(
name=f"Test MCP Persona {uuid4().hex[:8]}",
description="Test persona with MCP tools",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a helpful assistant",
task_prompt="Answer the user's question",
tools=tools,

View File

@@ -17,7 +17,6 @@ import pytest
from sqlalchemy.orm import Session
from onyx.chat.emitter import get_default_emitter
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.models import OAuthAccount
from onyx.db.models import OAuthConfig
from onyx.db.models import Persona
@@ -57,12 +56,6 @@ def _create_test_persona(db_session: Session, user: User, tools: list[Tool]) ->
persona = Persona(
name=f"Test Persona {uuid4().hex[:8]}",
description="Test persona",
num_chunks=10.0,
chunks_above=0,
chunks_below=0,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.NO_DECAY,
system_prompt="You are a helpful assistant",
task_prompt="Answer the user's question",
tools=tools,

View File

@@ -4,10 +4,12 @@ from uuid import uuid4
import requests
from onyx.llm.constants import LlmProviderNames
from onyx.server.manage.llm.models import DefaultModel
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import LLMProviderView
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.test_models import DATestLLMProvider
from tests.integration.common_utils.test_models import DATestUser
@@ -32,7 +34,6 @@ class LLMProviderManager:
llm_provider = LLMProviderUpsertRequest(
name=name or f"test-provider-{uuid4()}",
provider=provider or LlmProviderNames.OPENAI,
default_model_name=default_model_name or "gpt-4o-mini",
api_key=api_key or os.environ["OPENAI_API_KEY"],
api_base=api_base,
api_version=api_version,
@@ -65,7 +66,7 @@ class LLMProviderManager:
name=response_data["name"],
provider=response_data["provider"],
api_key=response_data["api_key"],
default_model_name=response_data["default_model_name"],
default_model_name=default_model_name or "gpt-4o-mini",
is_public=response_data["is_public"],
is_auto_mode=response_data.get("is_auto_mode", False),
groups=response_data["groups"],
@@ -75,9 +76,19 @@ class LLMProviderManager:
)
if set_as_default:
if default_model_name is None:
default_model_name = "gpt-4o-mini"
set_default_response = requests.post(
f"{API_SERVER_URL}/admin/llm/provider/{llm_response.json()['id']}/default",
headers=user_performing_action.headers,
f"{API_SERVER_URL}/admin/llm/default",
json={
"provider_id": response_data["id"],
"model_name": default_model_name,
},
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
set_default_response.raise_for_status()
@@ -104,7 +115,7 @@ class LLMProviderManager:
headers=user_performing_action.headers,
)
response.raise_for_status()
return [LLMProviderView(**ug) for ug in response.json()]
return [LLMProviderView(**p) for p in response.json()["providers"]]
@staticmethod
def verify(
@@ -113,7 +124,11 @@ class LLMProviderManager:
verify_deleted: bool = False,
) -> None:
all_llm_providers = LLMProviderManager.get_all(user_performing_action)
default_model = LLMProviderManager.get_default_model(user_performing_action)
for fetched_llm_provider in all_llm_providers:
model_names = [
model.name for model in fetched_llm_provider.model_configurations
]
if llm_provider.id == fetched_llm_provider.id:
if verify_deleted:
raise ValueError(
@@ -126,11 +141,30 @@ class LLMProviderManager:
if (
fetched_llm_groups == llm_provider_groups
and llm_provider.provider == fetched_llm_provider.provider
and llm_provider.default_model_name
== fetched_llm_provider.default_model_name
and (
default_model is None or default_model.model_name in model_names
)
and llm_provider.is_public == fetched_llm_provider.is_public
and set(fetched_llm_provider.personas) == set(llm_provider.personas)
):
return
if not verify_deleted:
raise ValueError(f"LLM Provider {llm_provider.id} not found")
@staticmethod
def get_default_model(
user_performing_action: DATestUser | None = None,
) -> DefaultModel | None:
response = requests.get(
f"{API_SERVER_URL}/admin/llm/provider",
headers=(
user_performing_action.headers
if user_performing_action
else GENERAL_HEADERS
),
)
response.raise_for_status()
default_text = response.json().get("default_text")
if default_text is None:
return None
return DefaultModel(**default_text)

View File

@@ -3,7 +3,6 @@ from uuid import uuid4
import requests
from onyx.context.search.enums import RecencyBiasSetting
from onyx.server.features.persona.models import FullPersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
@@ -20,11 +19,7 @@ class PersonaManager:
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
num_chunks: float = 5,
llm_relevance_filter: bool = True,
is_public: bool = True,
llm_filter_extraction: bool = True,
recency_bias: RecencyBiasSetting = RecencyBiasSetting.AUTO,
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -35,6 +30,7 @@ class PersonaManager:
label_ids: list[int] | None = None,
user_file_ids: list[str] | None = None,
display_priority: int | None = None,
featured: bool = False,
) -> DATestPersona:
name = name or f"test-persona-{uuid4()}"
description = description or f"Description for {name}"
@@ -47,11 +43,7 @@ class PersonaManager:
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
document_set_ids=document_set_ids or [],
tool_ids=tool_ids or [],
llm_model_provider_override=llm_model_provider_override,
@@ -61,6 +53,7 @@ class PersonaManager:
label_ids=label_ids or [],
user_file_ids=user_file_ids or [],
display_priority=display_priority,
featured=featured,
)
response = requests.post(
@@ -75,11 +68,7 @@ class PersonaManager:
id=persona_data["id"],
name=name,
description=description,
num_chunks=num_chunks,
llm_relevance_filter=llm_relevance_filter,
is_public=is_public,
llm_filter_extraction=llm_filter_extraction,
recency_bias=recency_bias,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
@@ -90,6 +79,7 @@ class PersonaManager:
users=users or [],
groups=groups or [],
label_ids=label_ids or [],
featured=featured,
)
@staticmethod
@@ -100,11 +90,7 @@ class PersonaManager:
description: str | None = None,
system_prompt: str | None = None,
task_prompt: str | None = None,
num_chunks: float | None = None,
llm_relevance_filter: bool | None = None,
is_public: bool | None = None,
llm_filter_extraction: bool | None = None,
recency_bias: RecencyBiasSetting | None = None,
datetime_aware: bool = False,
document_set_ids: list[int] | None = None,
tool_ids: list[int] | None = None,
@@ -113,6 +99,7 @@ class PersonaManager:
users: list[str] | None = None,
groups: list[int] | None = None,
label_ids: list[int] | None = None,
featured: bool | None = None,
) -> DATestPersona:
system_prompt = system_prompt or f"System prompt for {persona.name}"
task_prompt = task_prompt or f"Task prompt for {persona.name}"
@@ -123,13 +110,7 @@ class PersonaManager:
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
num_chunks=num_chunks or persona.num_chunks,
llm_relevance_filter=llm_relevance_filter or persona.llm_relevance_filter,
is_public=persona.is_public if is_public is None else is_public,
llm_filter_extraction=(
llm_filter_extraction or persona.llm_filter_extraction
),
recency_bias=recency_bias or persona.recency_bias,
document_set_ids=document_set_ids or persona.document_set_ids,
tool_ids=tool_ids or persona.tool_ids,
llm_model_provider_override=(
@@ -141,6 +122,7 @@ class PersonaManager:
users=[UUID(user) for user in (users or persona.users)],
groups=groups or persona.groups,
label_ids=label_ids or persona.label_ids,
featured=featured if featured is not None else persona.featured,
)
response = requests.patch(
@@ -155,16 +137,12 @@ class PersonaManager:
id=updated_persona_data["id"],
name=updated_persona_data["name"],
description=updated_persona_data["description"],
num_chunks=updated_persona_data["num_chunks"],
llm_relevance_filter=updated_persona_data["llm_relevance_filter"],
is_public=updated_persona_data["is_public"],
llm_filter_extraction=updated_persona_data["llm_filter_extraction"],
recency_bias=recency_bias or persona.recency_bias,
system_prompt=system_prompt,
task_prompt=task_prompt,
datetime_aware=datetime_aware,
document_set_ids=updated_persona_data["document_sets"],
tool_ids=updated_persona_data["tools"],
document_set_ids=[ds["id"] for ds in updated_persona_data["document_sets"]],
tool_ids=[t["id"] for t in updated_persona_data["tools"]],
llm_model_provider_override=updated_persona_data[
"llm_model_provider_override"
],
@@ -173,7 +151,8 @@ class PersonaManager:
],
users=[user["email"] for user in updated_persona_data["users"]],
groups=updated_persona_data["groups"],
label_ids=updated_persona_data["labels"],
label_ids=[label["id"] for label in updated_persona_data["labels"]],
featured=updated_persona_data["featured"],
)
@staticmethod
@@ -222,32 +201,13 @@ class PersonaManager:
fetched_persona.description,
)
)
if fetched_persona.num_chunks != persona.num_chunks:
mismatches.append(
("num_chunks", persona.num_chunks, fetched_persona.num_chunks)
)
if fetched_persona.llm_relevance_filter != persona.llm_relevance_filter:
mismatches.append(
(
"llm_relevance_filter",
persona.llm_relevance_filter,
fetched_persona.llm_relevance_filter,
)
)
if fetched_persona.is_public != persona.is_public:
mismatches.append(
("is_public", persona.is_public, fetched_persona.is_public)
)
if (
fetched_persona.llm_filter_extraction
!= persona.llm_filter_extraction
):
if fetched_persona.featured != persona.featured:
mismatches.append(
(
"llm_filter_extraction",
persona.llm_filter_extraction,
fetched_persona.llm_filter_extraction,
)
("featured", persona.featured, fetched_persona.featured)
)
if (
fetched_persona.llm_model_provider_override

View File

@@ -10,7 +10,6 @@ from pydantic import Field
from onyx.auth.schemas import UserRole
from onyx.configs.constants import MessageType
from onyx.configs.constants import QAFeedbackType
from onyx.context.search.enums import RecencyBiasSetting
from onyx.context.search.models import SavedSearchDoc
from onyx.context.search.models import SearchDoc
from onyx.db.enums import AccessType
@@ -128,7 +127,7 @@ class DATestLLMProvider(BaseModel):
name: str
provider: str
api_key: str
default_model_name: str
default_model_name: str | None = None
is_public: bool
is_auto_mode: bool = False
groups: list[int]
@@ -162,11 +161,7 @@ class DATestPersona(BaseModel):
id: int
name: str
description: str
num_chunks: float
llm_relevance_filter: bool
is_public: bool
llm_filter_extraction: bool
recency_bias: RecencyBiasSetting
document_set_ids: list[int]
tool_ids: list[int]
llm_model_provider_override: str | None
@@ -174,6 +169,7 @@ class DATestPersona(BaseModel):
users: list[str]
groups: list[int]
label_ids: list[int]
featured: bool = False
# Embedded prompt fields (no longer separate prompt_ids)
system_prompt: str | None = None

View File

@@ -8,7 +8,6 @@ from collections.abc import Generator
import pytest
from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.discord_bot import bulk_create_channel_configs
from onyx.db.discord_bot import create_discord_bot_config
from onyx.db.discord_bot import create_guild_config
@@ -36,14 +35,8 @@ def _create_test_persona(db_session: Session, persona_id: int, name: str) -> Per
id=persona_id,
name=name,
description="Test persona for Discord bot tests",
num_chunks=5.0,
chunks_above=1,
chunks_below=1,
llm_relevance_filter=False,
llm_filter_extraction=False,
recency_bias=RecencyBiasSetting.FAVOR_RECENT,
is_visible=True,
is_default_persona=False,
featured=False,
deleted=False,
builtin_persona=False,
)

View File

@@ -42,12 +42,10 @@ def _create_provider_with_api(
llm_provider_data = {
"name": name,
"provider": provider_type,
"default_model_name": default_model,
"api_key": "test-api-key-for-auto-mode-testing",
"api_base": None,
"api_version": None,
"custom_config": None,
"fast_default_model_name": default_model,
"is_public": True,
"is_auto_mode": is_auto_mode,
"groups": [],
@@ -72,7 +70,7 @@ def _get_provider_by_id(admin_user: DATestUser, provider_id: int) -> dict:
headers=admin_user.headers,
)
response.raise_for_status()
for provider in response.json():
for provider in response.json()["providers"]:
if provider["id"] == provider_id:
return provider
raise ValueError(f"Provider with id {provider_id} not found")
@@ -219,15 +217,6 @@ def test_auto_mode_provider_gets_synced_from_github_config(
"is_visible"
], "Outdated model should not be visible after sync"
# Verify default model was set from GitHub config
expected_default = (
default_model["name"] if isinstance(default_model, dict) else default_model
)
assert synced_provider["default_model_name"] == expected_default, (
f"Default model should be {expected_default}, "
f"got {synced_provider['default_model_name']}"
)
def test_manual_mode_provider_not_affected_by_auto_sync(
reset: None, # noqa: ARG001
@@ -273,7 +262,3 @@ def test_manual_mode_provider_not_affected_by_auto_sync(
f"Manual mode provider models should not change. "
f"Initial: {initial_models}, Current: {current_models}"
)
assert (
updated_provider["default_model_name"] == custom_model
), f"Manual mode default model should remain {custom_model}"

View File

@@ -4,22 +4,22 @@ import pytest
import requests
from sqlalchemy.orm import Session
from onyx.context.search.enums import RecencyBiasSetting
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import LLMModelFlowType
from onyx.db.llm import can_user_access_llm_provider
from onyx.db.llm import fetch_user_group_ids
from onyx.db.models import LLMModelFlow
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import LLMProvider as LLMProviderModel
from onyx.db.models import LLMProvider__Persona
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import ModelConfiguration
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.llm.constants import LlmProviderNames
from onyx.llm.factory import get_llm_for_persona
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.llm_provider import LLMProviderManager
from tests.integration.common_utils.managers.persona import PersonaManager
@@ -41,24 +41,30 @@ def _create_llm_provider(
is_public: bool,
is_default: bool,
) -> LLMProviderModel:
provider = LLMProviderModel(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
default_model_name=default_model_name,
deployment_name=None,
is_public=is_public,
# Use None instead of False to avoid unique constraint violation
# The is_default_provider column has unique=True, so only one True and one False allowed
is_default_provider=is_default if is_default else None,
is_default_vision_provider=False,
default_vision_model=None,
_provider = upsert_llm_provider(
llm_provider_upsert_request=LLMProviderUpsertRequest(
name=name,
provider=LlmProviderNames.OPENAI,
api_key=None,
api_base=None,
api_version=None,
custom_config=None,
is_public=is_public,
model_configurations=[
ModelConfigurationUpsertRequest(
name=default_model_name,
is_visible=True,
)
],
),
db_session=db_session,
)
db_session.add(provider)
db_session.flush()
if is_default:
update_default_provider(_provider.id, default_model_name, db_session)
provider = db_session.get(LLMProviderModel, _provider.id)
if not provider:
raise ValueError(f"Provider {name} not found")
return provider
@@ -71,12 +77,6 @@ def _create_persona(
persona = Persona(
name=name,
description=f"{name} description",
num_chunks=5,
chunks_above=2,
chunks_below=2,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
llm_model_provider_override=provider_name,
llm_model_version_override="gpt-4o-mini",
system_prompt="System prompt",
@@ -270,24 +270,6 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
provider_name=restricted_provider.name,
)
# Set up ModelConfiguration + LLMModelFlow so get_default_llm() can
# resolve the default provider when the fallback path is triggered.
default_model_config = ModelConfiguration(
llm_provider_id=default_provider.id,
name=default_provider.default_model_name,
is_visible=True,
)
db_session.add(default_model_config)
db_session.flush()
db_session.add(
LLMModelFlow(
model_configuration_id=default_model_config.id,
llm_model_flow_type=LLMModelFlowType.CHAT,
is_default=True,
)
)
db_session.flush()
access_group = UserGroup(name="persona-group")
db_session.add(access_group)
db_session.flush()
@@ -321,13 +303,19 @@ def test_get_llm_for_persona_falls_back_when_access_denied(
persona=persona,
user=admin_model,
)
assert allowed_llm.config.model_name == restricted_provider.default_model_name
assert (
allowed_llm.config.model_name
== restricted_provider.model_configurations[0].name
)
fallback_llm = get_llm_for_persona(
persona=persona,
user=basic_model,
)
assert fallback_llm.config.model_name == default_provider.default_model_name
assert (
fallback_llm.config.model_name
== default_provider.model_configurations[0].name
)
def test_list_llm_provider_basics_excludes_non_public_unrestricted(
@@ -346,6 +334,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
name="public-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)
@@ -365,7 +354,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=basic_user.headers,
)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
provider_names = [p["name"] for p in providers]
# Public provider should be visible
@@ -380,7 +369,7 @@ def test_list_llm_provider_basics_excludes_non_public_unrestricted(
headers=admin_user.headers,
)
assert admin_response.status_code == 200
admin_providers = admin_response.json()
admin_providers = admin_response.json()["providers"]
admin_provider_names = [p["name"] for p in admin_providers]
assert public_provider.name in admin_provider_names
@@ -396,6 +385,7 @@ def test_provider_delete_clears_persona_references(reset: None) -> None: # noqa
name="default-provider",
is_public=True,
set_as_default=True,
default_model_name="gpt-4o",
user_performing_action=admin_user,
)

View File

@@ -107,7 +107,7 @@ def test_authorized_persona_access_returns_filtered_providers(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider since basic_user can access the persona
provider_names = [p["name"] for p in providers]
@@ -140,7 +140,7 @@ def test_persona_id_zero_applies_rbac(
# Should succeed (persona_id=0 refers to default persona, which is public)
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should NOT include the restricted provider since basic_user is not in group2
provider_names = [p["name"] for p in providers]
@@ -182,7 +182,7 @@ def test_admin_can_query_any_persona(
# Should succeed - admins can access any persona
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should include the restricted provider
provider_names = [p["name"] for p in providers]
@@ -223,7 +223,7 @@ def test_public_persona_accessible_to_all(
# Should succeed
assert response.status_code == 200
providers = response.json()
providers = response.json()["providers"]
# Should return the public provider
assert len(providers) > 0

View File

@@ -25,7 +25,7 @@ def test_cold_startup_default_assistant() -> None:
result = db_session.execute(
text(
"""
SELECT id, name, builtin_persona, is_default_persona, deleted
SELECT id, name, builtin_persona, featured, deleted
FROM persona
WHERE builtin_persona = true
ORDER BY id
@@ -40,7 +40,7 @@ def test_cold_startup_default_assistant() -> None:
assert default[0] == 0, "Default assistant should have ID 0"
assert default[1] == "Assistant", "Should be named 'Assistant'"
assert default[2] is True, "Should be builtin"
assert default[3] is True, "Should be default"
assert default[3] is True, "Should be featured"
assert default[4] is False, "Should not be deleted"
# Check tools are properly associated

View File

@@ -195,11 +195,7 @@ def _base_persona_body(**overrides: object) -> dict:
"description": "test",
"system_prompt": "test",
"task_prompt": "",
"num_chunks": 5,
"is_public": True,
"recency_bias": "auto",
"llm_filter_extraction": False,
"llm_relevance_filter": False,
"datetime_aware": False,
"document_set_ids": [],
"tool_ids": [],

View File

@@ -40,7 +40,6 @@ def test_persona_create_update_share_delete(
expected_persona,
name=f"updated-{expected_persona.name}",
description=f"updated-{expected_persona.description}",
num_chunks=expected_persona.num_chunks + 1,
is_public=False,
user_performing_action=admin_user,
)

View File

@@ -31,11 +31,7 @@ def test_update_persona_with_null_label_ids_preserves_labels(
task_prompt=persona.task_prompt or "",
datetime_aware=persona.datetime_aware,
document_set_ids=persona.document_set_ids,
num_chunks=persona.num_chunks,
is_public=persona.is_public,
recency_bias=persona.recency_bias,
llm_filter_extraction=persona.llm_filter_extraction,
llm_relevance_filter=persona.llm_relevance_filter,
llm_model_provider_override=persona.llm_model_provider_override,
llm_model_version_override=persona.llm_model_version_override,
tool_ids=persona.tool_ids,

View File

@@ -31,9 +31,8 @@ def test_unified_assistant(reset: None, admin_user: DATestUser) -> None: # noqa
"search, web browsing, and image generation"
in unified_assistant.description.lower()
)
assert unified_assistant.is_default_persona is True
assert unified_assistant.featured is True
assert unified_assistant.is_visible is True
assert unified_assistant.num_chunks == 25
# Verify tools
tools = unified_assistant.tools

View File

@@ -9,6 +9,19 @@ from redis.exceptions import RedisError
from onyx.server.settings.models import ApplicationStatus
from onyx.server.settings.models import Settings
# The subset of Settings fields these tests make assertions about.
_ASSERT_FIELDS = {
    field_name
    for field_name in (
        "application_status",
        "ee_features_enabled",
        "seat_count",
        "used_seats",
    )
}


def _pick(settings: Settings) -> dict:
    """Project *settings* down to only the fields under test."""
    narrowed = settings.model_dump(include=_ASSERT_FIELDS)
    return narrowed
@pytest.fixture
def base_settings() -> Settings:
@@ -27,17 +40,17 @@ class TestApplyLicenseStatusToSettings:
def test_enforcement_disabled_enables_ee_features(
self, base_settings: Settings
) -> None:
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled.
If we're running the EE apply function, EE code was loaded via
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES, so features should be on.
"""
"""When LICENSE_ENFORCEMENT_ENABLED=False, EE features are enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
assert base_settings.ee_features_enabled is False
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", True)
@@ -46,13 +59,60 @@ class TestApplyLicenseStatusToSettings:
from ee.onyx.server.settings.api import apply_license_status_to_settings
result = apply_license_status_to_settings(base_settings)
assert result.ee_features_enabled is True
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
}
@pytest.mark.parametrize(
"license_status,expected_app_status,expected_ee_enabled",
"license_status,used_seats,seats,expected",
[
(ApplicationStatus.GATED_ACCESS, ApplicationStatus.GATED_ACCESS, False),
(ApplicationStatus.ACTIVE, ApplicationStatus.ACTIVE, True),
(
ApplicationStatus.GATED_ACCESS,
3,
10,
{
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.ACTIVE,
10,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
(
ApplicationStatus.GRACE_PERIOD,
3,
10,
{
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": True,
"seat_count": None,
"used_seats": None,
},
),
],
)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -63,25 +123,80 @@ class TestApplyLicenseStatusToSettings:
self,
mock_get_metadata: MagicMock,
mock_get_tenant: MagicMock,
license_status: ApplicationStatus | None,
expected_app_status: ApplicationStatus,
expected_ee_enabled: bool,
license_status: ApplicationStatus,
used_seats: int,
seats: int,
expected: dict,
base_settings: Settings,
) -> None:
"""Self-hosted: license status controls both application_status and ee_features_enabled."""
from ee.onyx.server.settings.api import apply_license_status_to_settings
mock_get_tenant.return_value = "test_tenant"
if license_status is None:
mock_get_metadata.return_value = None
else:
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_get_metadata.return_value = mock_metadata
mock_metadata = MagicMock()
mock_metadata.status = license_status
mock_metadata.used_seats = used_seats
mock_metadata.seats = seats
mock_get_metadata.return_value = mock_metadata
result = apply_license_status_to_settings(base_settings)
assert result.application_status == expected_app_status
assert result.ee_features_enabled is expected_ee_enabled
assert _pick(result) == expected
    # NOTE: @patch decorators apply bottom-up, so get_cached_license_metadata
    # (last decorator) maps to the first mock parameter.
    @patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
    @patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
    @patch("ee.onyx.server.settings.api.get_current_tenant_id")
    @patch("ee.onyx.server.settings.api.get_cached_license_metadata")
    def test_seat_limit_exceeded_sets_status_and_counts(
        self,
        mock_get_metadata: MagicMock,
        mock_get_tenant: MagicMock,
        base_settings: Settings,
    ) -> None:
        """Seat limit exceeded sets SEAT_LIMIT_EXCEEDED with counts, keeps EE enabled."""
        from ee.onyx.server.settings.api import apply_license_status_to_settings

        mock_get_tenant.return_value = "test_tenant"
        # License itself is ACTIVE, but more seats are in use than purchased.
        mock_metadata = MagicMock()
        mock_metadata.status = ApplicationStatus.ACTIVE
        mock_metadata.used_seats = 15
        mock_metadata.seats = 10
        mock_get_metadata.return_value = mock_metadata

        result = apply_license_status_to_settings(base_settings)

        # Status flips to SEAT_LIMIT_EXCEEDED and the counts are surfaced;
        # EE features remain on since the license is otherwise valid.
        assert _pick(result) == {
            "application_status": ApplicationStatus.SEAT_LIMIT_EXCEEDED,
            "ee_features_enabled": True,
            "seat_count": 10,
            "used_seats": 15,
        }
    @patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
    @patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
    @patch("ee.onyx.server.settings.api.get_current_tenant_id")
    @patch("ee.onyx.server.settings.api.get_cached_license_metadata")
    def test_expired_license_takes_precedence_over_seat_limit(
        self,
        mock_get_metadata: MagicMock,
        mock_get_tenant: MagicMock,
        base_settings: Settings,
    ) -> None:
        """Expired license (GATED_ACCESS) takes precedence over seat limit exceeded."""
        from ee.onyx.server.settings.api import apply_license_status_to_settings

        mock_get_tenant.return_value = "test_tenant"
        # Both conditions hold: the license is gated AND seats are exceeded.
        mock_metadata = MagicMock()
        mock_metadata.status = ApplicationStatus.GATED_ACCESS
        mock_metadata.used_seats = 15
        mock_metadata.seats = 10
        mock_get_metadata.return_value = mock_metadata

        result = apply_license_status_to_settings(base_settings)

        # Gating wins: no seat counts are reported and EE features are off.
        assert _pick(result) == {
            "application_status": ApplicationStatus.GATED_ACCESS,
            "ee_features_enabled": False,
            "seat_count": None,
            "used_seats": None,
        }
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", True)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -105,8 +220,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.GATED_ACCESS
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.GATED_ACCESS,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.ENTERPRISE_EDITION_ENABLED", False)
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@@ -130,8 +249,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.return_value = None
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
@patch("ee.onyx.server.settings.api.LICENSE_ENFORCEMENT_ENABLED", True)
@patch("ee.onyx.server.settings.api.MULTI_TENANT", False)
@@ -150,8 +273,12 @@ class TestApplyLicenseStatusToSettings:
mock_get_metadata.side_effect = RedisError("Connection failed")
result = apply_license_status_to_settings(base_settings)
assert result.application_status == ApplicationStatus.ACTIVE
assert result.ee_features_enabled is False
assert _pick(result) == {
"application_status": ApplicationStatus.ACTIVE,
"ee_features_enabled": False,
"seat_count": None,
"used_seats": None,
}
class TestSettingsDefaultEEDisabled:

View File

@@ -0,0 +1,291 @@
"""Tests for the _impl functions' redis_locking parameter.
Verifies that:
- redis_locking=True acquires/releases Redis locks and clears queued keys
- redis_locking=False skips all Redis operations entirely
- Both paths execute the same business logic (DB lookup, status check)
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.background.celery.tasks.user_file_processing.tasks import (
_delete_user_file_impl,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_process_user_file_impl,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_project_sync_user_file_impl,
)
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
def _mock_session_returning_none() -> MagicMock:
"""Return a mock session whose .get() returns None (file not found)."""
session = MagicMock()
session.get.return_value = None
session.execute.return_value.scalar_one_or_none.return_value = None
return session
# ------------------------------------------------------------------
# _process_user_file_impl
# ------------------------------------------------------------------
class TestProcessUserFileImpl:
    """Redis-locking behavior of ``_process_user_file_impl``.

    The DB session is mocked so ``.get()`` returns ``None`` ("file not
    found"); only the locking wrapper is exercised, not the processing
    logic itself.
    """

    # NOTE: @patch decorators apply bottom-up, so get_redis_client
    # (last decorator) maps to the first mock parameter.
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_acquires_and_releases_lock(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Lock is free: acquire() succeeds and the task owns it afterwards.
        redis_client = MagicMock()
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.owned.return_value = True
        redis_client.lock.return_value = lock
        mock_get_redis.return_value = redis_client
        session = _mock_session_returning_none()
        mock_get_session.return_value.__enter__.return_value = session
        user_file_id = str(uuid4())
        _process_user_file_impl(
            user_file_id=user_file_id,
            tenant_id="test-tenant",
            redis_locking=True,
        )
        # Redis client fetched for the right tenant, a queued key is
        # deleted, and the lock is taken non-blocking then released.
        mock_get_redis.assert_called_once_with(tenant_id="test-tenant")
        redis_client.delete.assert_called_once()
        lock.acquire.assert_called_once_with(blocking=False)
        lock.release.assert_called_once()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_skips_when_lock_held(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Another worker holds the lock: acquire() fails.
        redis_client = MagicMock()
        lock = MagicMock()
        lock.acquire.return_value = False
        redis_client.lock.return_value = lock
        mock_get_redis.return_value = redis_client
        _process_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=True,
        )
        # The task must bail out before ever opening a DB session.
        lock.acquire.assert_called_once()
        mock_get_session.assert_not_called()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_false_skips_redis_entirely(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        session = _mock_session_returning_none()
        mock_get_session.return_value.__enter__.return_value = session
        _process_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=False,
        )
        # No Redis interaction at all, but the DB lookup still happens.
        mock_get_redis.assert_not_called()
        mock_get_session.assert_called_once()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_both_paths_call_db_get(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        """Both redis_locking=True and False should call db_session.get(UserFile, ...)."""
        redis_client = MagicMock()
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.owned.return_value = True
        redis_client.lock.return_value = lock
        mock_get_redis.return_value = redis_client
        session = _mock_session_returning_none()
        mock_get_session.return_value.__enter__.return_value = session
        uid = str(uuid4())
        _process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=True)
        call_count_true = session.get.call_count
        # Reset so the second invocation is counted independently; the
        # context-manager return value must be re-wired after reset_mock().
        session.reset_mock()
        mock_get_session.reset_mock()
        mock_get_session.return_value.__enter__.return_value = session
        _process_user_file_impl(user_file_id=uid, tenant_id="t", redis_locking=False)
        call_count_false = session.get.call_count
        assert call_count_true == call_count_false == 1
# ------------------------------------------------------------------
# _delete_user_file_impl
# ------------------------------------------------------------------
class TestDeleteUserFileImpl:
    """Redis-locking behavior of ``_delete_user_file_impl``.

    The DB session is mocked to report "file not found", so only the
    lock-handling wrapper around the delete logic is exercised.
    """

    @staticmethod
    def _redis_with_lock(acquired: bool) -> tuple[MagicMock, MagicMock]:
        # Build a fake redis client whose .lock() hands back a lock mock
        # that either can (acquired=True) or cannot be taken.
        fake_lock = MagicMock()
        fake_lock.acquire.return_value = acquired
        if acquired:
            fake_lock.owned.return_value = True
        fake_redis = MagicMock()
        fake_redis.lock.return_value = fake_lock
        return fake_redis, fake_lock

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_acquires_and_releases_lock(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        fake_redis, fake_lock = self._redis_with_lock(acquired=True)
        mock_get_redis.return_value = fake_redis
        mock_get_session.return_value.__enter__.return_value = (
            _mock_session_returning_none()
        )

        _delete_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=True,
        )

        # Lock must be taken non-blocking and released on the way out.
        mock_get_redis.assert_called_once()
        fake_lock.acquire.assert_called_once_with(blocking=False)
        fake_lock.release.assert_called_once()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_skips_when_lock_held(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        fake_redis, fake_lock = self._redis_with_lock(acquired=False)
        mock_get_redis.return_value = fake_redis

        _delete_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=True,
        )

        # A held lock means no DB session is ever opened.
        fake_lock.acquire.assert_called_once()
        mock_get_session.assert_not_called()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_false_skips_redis_entirely(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        mock_get_session.return_value.__enter__.return_value = (
            _mock_session_returning_none()
        )

        _delete_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=False,
        )

        # redis_locking=False: no Redis calls, but the DB path still runs.
        mock_get_redis.assert_not_called()
        mock_get_session.assert_called_once()
# ------------------------------------------------------------------
# _project_sync_user_file_impl
# ------------------------------------------------------------------
class TestProjectSyncUserFileImpl:
    """Redis-locking behavior of ``_project_sync_user_file_impl``.

    The session's query result is ``None`` ("file not found"), so only
    the locking wrapper is exercised.
    """

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_acquires_and_releases_lock(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Lock is free: acquire() succeeds and the task owns it afterwards.
        redis_client = MagicMock()
        lock = MagicMock()
        lock.acquire.return_value = True
        lock.owned.return_value = True
        redis_client.lock.return_value = lock
        mock_get_redis.return_value = redis_client
        session = _mock_session_returning_none()
        mock_get_session.return_value.__enter__.return_value = session
        _project_sync_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=True,
        )
        # A queued key is deleted, and the lock is taken non-blocking
        # then released.
        mock_get_redis.assert_called_once()
        redis_client.delete.assert_called_once()
        lock.acquire.assert_called_once_with(blocking=False)
        lock.release.assert_called_once()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_true_skips_when_lock_held(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Another worker holds the lock: acquire() fails.
        redis_client = MagicMock()
        lock = MagicMock()
        lock.acquire.return_value = False
        redis_client.lock.return_value = lock
        mock_get_redis.return_value = redis_client
        _project_sync_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=True,
        )
        # Must bail out before opening a DB session.
        lock.acquire.assert_called_once()
        mock_get_session.assert_not_called()

    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    @patch(f"{TASKS_MODULE}.get_redis_client")
    def test_redis_locking_false_skips_redis_entirely(
        self,
        mock_get_redis: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        session = _mock_session_returning_none()
        mock_get_session.return_value.__enter__.return_value = session
        _project_sync_user_file_impl(
            user_file_id=str(uuid4()),
            tenant_id="test-tenant",
            redis_locking=False,
        )
        # No Redis interaction; the DB lookup still happens.
        mock_get_redis.assert_not_called()
        mock_get_session.assert_called_once()

View File

@@ -0,0 +1,421 @@
"""Tests for no-vector-DB user file processing paths.
Verifies that when DISABLE_VECTOR_DB is True:
- _process_user_file_impl calls _process_user_file_without_vector_db (not indexing)
- _process_user_file_without_vector_db extracts text, counts tokens, stores plaintext,
sets status=COMPLETED and chunk_count=0
- _delete_user_file_impl skips vector DB chunk deletion
- _project_sync_user_file_impl skips vector DB metadata update
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
from onyx.background.celery.tasks.user_file_processing.tasks import (
_delete_user_file_impl,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_process_user_file_impl,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_process_user_file_without_vector_db,
)
from onyx.background.celery.tasks.user_file_processing.tasks import (
_project_sync_user_file_impl,
)
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.db.enums import UserFileStatus
TASKS_MODULE = "onyx.background.celery.tasks.user_file_processing.tasks"
LLM_FACTORY_MODULE = "onyx.llm.factory"
def _make_documents(texts: list[str]) -> list[Document]:
    """Create one single-section ``Document`` per entry in *texts*."""
    documents: list[Document] = []
    for index, section_text in enumerate(texts):
        documents.append(
            Document(
                id=str(uuid4()),
                source=DocumentSource.USER_FILE,
                sections=[TextSection(text=section_text)],
                semantic_identifier=f"test-doc-{index}",
                metadata={},
            )
        )
    return documents
def _make_user_file(
    *,
    status: UserFileStatus = UserFileStatus.PROCESSING,
    file_id: str = "test-file-id",
    name: str = "test.txt",
) -> MagicMock:
    """Return a MagicMock that looks like a ``UserFile`` ORM row."""
    user_file = MagicMock()
    user_file.configure_mock(
        id=uuid4(),
        file_id=file_id,
        status=status,
        token_count=None,
        chunk_count=None,
        last_project_sync_at=None,
        projects=[],
        assistants=[],
        needs_project_sync=True,
        needs_persona_sync=True,
    )
    # ``name`` is special to the Mock constructor, so set it via plain
    # attribute assignment.
    user_file.name = name
    return user_file
# ------------------------------------------------------------------
# _process_user_file_without_vector_db — direct tests
# ------------------------------------------------------------------
class TestProcessUserFileWithoutVectorDb:
    """Direct tests of ``_process_user_file_without_vector_db``.

    The LLM tokenizer and the plaintext store are mocked; the documents
    passed in are built by ``_make_documents``.
    """

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_extracts_and_combines_text(
        self,
        mock_get_llm: MagicMock,  # noqa: ARG002
        mock_get_encode: MagicMock,
        mock_store_plaintext: MagicMock,
    ) -> None:
        mock_encode = MagicMock(return_value=[1, 2, 3, 4, 5])
        mock_get_encode.return_value = mock_encode
        uf = _make_user_file()
        docs = _make_documents(["hello world", "foo bar"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        # Text from every section must end up in the stored plaintext.
        stored_text = mock_store_plaintext.call_args.kwargs["plaintext_content"]
        assert "hello world" in stored_text
        assert "foo bar" in stored_text

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_computes_token_count(
        self,
        mock_get_llm: MagicMock,  # noqa: ARG002
        mock_get_encode: MagicMock,
        mock_store_plaintext: MagicMock,  # noqa: ARG002
    ) -> None:
        # Tokenizer yields 42 tokens; that count must land on the file row.
        mock_encode = MagicMock(return_value=list(range(42)))
        mock_get_encode.return_value = mock_encode
        uf = _make_user_file()
        docs = _make_documents(["some text content"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        assert uf.token_count == 42

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_token_count_falls_back_to_none_on_error(
        self,
        mock_get_llm: MagicMock,
        mock_get_encode: MagicMock,  # noqa: ARG002
        mock_store_plaintext: MagicMock,  # noqa: ARG002
    ) -> None:
        # An unusable LLM must not break processing; token_count stays None.
        mock_get_llm.side_effect = RuntimeError("No LLM configured")
        uf = _make_user_file()
        docs = _make_documents(["text"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        assert uf.token_count is None

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_stores_plaintext(
        self,
        mock_get_llm: MagicMock,  # noqa: ARG002
        mock_get_encode: MagicMock,
        mock_store_plaintext: MagicMock,
    ) -> None:
        mock_get_encode.return_value = MagicMock(return_value=[1])
        uf = _make_user_file()
        docs = _make_documents(["content to store"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        # Exactly one plaintext write, keyed by the user file's id.
        mock_store_plaintext.assert_called_once_with(
            user_file_id=uf.id,
            plaintext_content="content to store",
        )

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_sets_completed_status_and_zero_chunk_count(
        self,
        mock_get_llm: MagicMock,  # noqa: ARG002
        mock_get_encode: MagicMock,
        mock_store_plaintext: MagicMock,  # noqa: ARG002
    ) -> None:
        mock_get_encode.return_value = MagicMock(return_value=[1])
        uf = _make_user_file()
        docs = _make_documents(["text"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        # No chunks are produced without a vector DB; the row is marked
        # COMPLETED, sync-stamped, and committed once.
        assert uf.status == UserFileStatus.COMPLETED
        assert uf.chunk_count == 0
        assert uf.last_project_sync_at is not None
        db_session.add.assert_called_once_with(uf)
        db_session.commit.assert_called_once()

    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func")
    @patch(f"{LLM_FACTORY_MODULE}.get_default_llm")
    def test_preserves_deleting_status(
        self,
        mock_get_llm: MagicMock,  # noqa: ARG002
        mock_get_encode: MagicMock,
        mock_store_plaintext: MagicMock,  # noqa: ARG002
    ) -> None:
        mock_get_encode.return_value = MagicMock(return_value=[1])
        # A file already queued for deletion must not be flipped back to
        # COMPLETED by a racing processing task.
        uf = _make_user_file(status=UserFileStatus.DELETING)
        docs = _make_documents(["text"])
        db_session = MagicMock()
        _process_user_file_without_vector_db(uf, docs, db_session)
        assert uf.status == UserFileStatus.DELETING
        assert uf.chunk_count == 0
# ------------------------------------------------------------------
# _process_user_file_impl — branching on DISABLE_VECTOR_DB
# ------------------------------------------------------------------
class TestProcessImplBranching:
    """``_process_user_file_impl`` dispatch on the DISABLE_VECTOR_DB flag."""

    @patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
    @patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_calls_without_vector_db_when_disabled(
        self,
        mock_get_session: MagicMock,
        mock_with_indexing: MagicMock,
        mock_without_vdb: MagicMock,
    ) -> None:
        # File exists in the DB; the connector yields one document batch.
        uf = _make_user_file()
        session = MagicMock()
        session.get.return_value = uf
        mock_get_session.return_value.__enter__.return_value = session
        connector_mock = MagicMock()
        connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
        with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
            _process_user_file_impl(
                user_file_id=str(uf.id),
                tenant_id="test-tenant",
                redis_locking=False,
            )
        # DISABLE_VECTOR_DB=True routes to the no-vector-DB path only.
        mock_without_vdb.assert_called_once()
        mock_with_indexing.assert_not_called()

    @patch(f"{TASKS_MODULE}._process_user_file_without_vector_db")
    @patch(f"{TASKS_MODULE}._process_user_file_with_indexing")
    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", False)
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_calls_with_indexing_when_vector_db_enabled(
        self,
        mock_get_session: MagicMock,
        mock_with_indexing: MagicMock,
        mock_without_vdb: MagicMock,
    ) -> None:
        uf = _make_user_file()
        session = MagicMock()
        session.get.return_value = uf
        mock_get_session.return_value.__enter__.return_value = session
        connector_mock = MagicMock()
        connector_mock.load_from_state.return_value = [_make_documents(["hello"])]
        with patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock):
            _process_user_file_impl(
                user_file_id=str(uf.id),
                tenant_id="test-tenant",
                redis_locking=False,
            )
        # DISABLE_VECTOR_DB=False routes to the indexing path only.
        mock_with_indexing.assert_called_once()
        mock_without_vdb.assert_not_called()

    @patch(f"{TASKS_MODULE}.run_indexing_pipeline")
    @patch(f"{TASKS_MODULE}.store_user_file_plaintext")
    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_indexing_pipeline_not_called_when_disabled(
        self,
        mock_get_session: MagicMock,
        mock_store_plaintext: MagicMock,  # noqa: ARG002
        mock_run_pipeline: MagicMock,
    ) -> None:
        """End-to-end: verify run_indexing_pipeline is never invoked."""
        uf = _make_user_file()
        session = MagicMock()
        session.get.return_value = uf
        mock_get_session.return_value.__enter__.return_value = session
        connector_mock = MagicMock()
        connector_mock.load_from_state.return_value = [_make_documents(["content"])]
        # Unlike the tests above, the no-vector-DB path runs for real here,
        # so the LLM tokenizer factory must be stubbed too.
        with (
            patch(f"{TASKS_MODULE}.LocalFileConnector", return_value=connector_mock),
            patch(f"{LLM_FACTORY_MODULE}.get_default_llm"),
            patch(
                f"{LLM_FACTORY_MODULE}.get_llm_tokenizer_encode_func",
                return_value=MagicMock(return_value=[1, 2, 3]),
            ),
        ):
            _process_user_file_impl(
                user_file_id=str(uf.id),
                tenant_id="test-tenant",
                redis_locking=False,
            )
        mock_run_pipeline.assert_not_called()
# ------------------------------------------------------------------
# _delete_user_file_impl — vector DB skip
# ------------------------------------------------------------------
class TestDeleteImplNoVectorDb:
    """``_delete_user_file_impl`` when DISABLE_VECTOR_DB is True."""

    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_default_file_store")
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_skips_vector_db_deletion(
        self,
        mock_get_session: MagicMock,
        mock_get_file_store: MagicMock,
    ) -> None:
        uf = _make_user_file(status=UserFileStatus.DELETING)
        session = MagicMock()
        session.get.return_value = uf
        mock_get_session.return_value.__enter__.return_value = session
        mock_get_file_store.return_value = MagicMock()
        with (
            patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_get_indices,
            patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_get_ss,
            patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_vespa_pool,
        ):
            _delete_user_file_impl(
                user_file_id=str(uf.id),
                tenant_id="test-tenant",
                redis_locking=False,
            )
        # None of the vector-DB machinery is touched...
        mock_get_indices.assert_not_called()
        mock_get_ss.assert_not_called()
        mock_vespa_pool.assert_not_called()
        # ...but the DB row is still removed and committed.
        session.delete.assert_called_once_with(uf)
        session.commit.assert_called_once()

    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_default_file_store")
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_still_deletes_file_store_and_db_record(
        self,
        mock_get_session: MagicMock,
        mock_get_file_store: MagicMock,
    ) -> None:
        uf = _make_user_file(status=UserFileStatus.DELETING)
        session = MagicMock()
        session.get.return_value = uf
        mock_get_session.return_value.__enter__.return_value = session
        file_store = MagicMock()
        mock_get_file_store.return_value = file_store
        _delete_user_file_impl(
            user_file_id=str(uf.id),
            tenant_id="test-tenant",
            redis_locking=False,
        )
        # Two file-store deletions expected — presumably the raw file plus
        # its stored plaintext; TODO confirm against the implementation.
        assert file_store.delete_file.call_count == 2
        session.delete.assert_called_once_with(uf)
        session.commit.assert_called_once()
# ------------------------------------------------------------------
# _project_sync_user_file_impl — vector DB skip
# ------------------------------------------------------------------
class TestProjectSyncImplNoVectorDb:
    """Project sync when the vector DB is disabled.

    _project_sync_user_file_impl must not call any Vespa/document-index
    helper, but still has to clear the sync bookkeeping on the UserFile row.
    """

    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_skips_vector_db_update(
        self,
        mock_get_session: MagicMock,
    ) -> None:
        user_file = _make_user_file(status=UserFileStatus.COMPLETED)
        db_session = MagicMock()
        db_session.execute.return_value.scalar_one_or_none.return_value = user_file
        mock_get_session.return_value.__enter__.return_value = db_session

        with (
            patch(f"{TASKS_MODULE}.get_all_document_indices") as mock_indices,
            patch(f"{TASKS_MODULE}.get_active_search_settings") as mock_search_settings,
            patch(f"{TASKS_MODULE}.httpx_init_vespa_pool") as mock_pool,
        ):
            _project_sync_user_file_impl(
                user_file_id=str(user_file.id),
                tenant_id="test-tenant",
                redis_locking=False,
            )

        mock_indices.assert_not_called()
        mock_search_settings.assert_not_called()
        mock_pool.assert_not_called()

    @patch(f"{TASKS_MODULE}.DISABLE_VECTOR_DB", True)
    @patch(f"{TASKS_MODULE}.get_session_with_current_tenant")
    def test_still_clears_sync_flags(
        self,
        mock_get_session: MagicMock,
    ) -> None:
        user_file = _make_user_file(status=UserFileStatus.COMPLETED)
        db_session = MagicMock()
        db_session.execute.return_value.scalar_one_or_none.return_value = user_file
        mock_get_session.return_value.__enter__.return_value = db_session

        _project_sync_user_file_impl(
            user_file_id=str(user_file.id),
            tenant_id="test-tenant",
            redis_locking=False,
        )

        # Sync bookkeeping must be updated even without a vector DB.
        assert user_file.needs_project_sync is False
        assert user_file.needs_persona_sync is False
        assert user_file.last_project_sync_at is not None
        db_session.add.assert_called_once_with(user_file)
        db_session.commit.assert_called_once()

View File

@@ -44,7 +44,6 @@ def _build_provider_view(
id=1,
name="test-provider",
provider=provider,
default_model_name="test-model",
model_configurations=[
ModelConfigurationView(
name="test-model",
@@ -62,7 +61,6 @@ def _build_provider_view(
groups=[],
personas=[],
deployment_name=None,
default_vision_model=None,
)

View File

@@ -0,0 +1,52 @@
"""Tests for startup validation in no-vector-DB mode.
Verifies that DISABLE_VECTOR_DB raises RuntimeError when combined with
incompatible settings (MULTI_TENANT, ENABLE_CRAFT).
"""
from unittest.mock import patch
import pytest
class TestValidateNoVectorDbSettings:
    """validate_no_vector_db_settings() conflict detection.

    Imports of onyx.main are deferred into each test body so the @patch
    decorators are active before the function reads the module attributes.
    """

    @patch("onyx.main.DISABLE_VECTOR_DB", False)
    def test_no_error_when_vector_db_enabled(self) -> None:
        # With the vector DB enabled the validation is a no-op.
        from onyx.main import validate_no_vector_db_settings

        validate_no_vector_db_settings()

    @patch("onyx.main.DISABLE_VECTOR_DB", True)
    @patch("onyx.main.MULTI_TENANT", False)
    @patch("onyx.server.features.build.configs.ENABLE_CRAFT", False)
    def test_no_error_when_no_conflicts(self) -> None:
        # DISABLE_VECTOR_DB alone is fine when no conflicting flag is set.
        from onyx.main import validate_no_vector_db_settings

        validate_no_vector_db_settings()

    @patch("onyx.main.DISABLE_VECTOR_DB", True)
    @patch("onyx.main.MULTI_TENANT", True)
    def test_raises_on_multi_tenant(self) -> None:
        from onyx.main import validate_no_vector_db_settings

        with pytest.raises(RuntimeError, match="MULTI_TENANT"):
            validate_no_vector_db_settings()

    @patch("onyx.main.DISABLE_VECTOR_DB", True)
    @patch("onyx.main.MULTI_TENANT", False)
    @patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
    def test_raises_on_enable_craft(self) -> None:
        from onyx.main import validate_no_vector_db_settings

        with pytest.raises(RuntimeError, match="ENABLE_CRAFT"):
            validate_no_vector_db_settings()

    @patch("onyx.main.DISABLE_VECTOR_DB", True)
    @patch("onyx.main.MULTI_TENANT", True)
    @patch("onyx.server.features.build.configs.ENABLE_CRAFT", True)
    def test_multi_tenant_checked_before_craft(self) -> None:
        """MULTI_TENANT is checked first, so it should be the error raised."""
        from onyx.main import validate_no_vector_db_settings

        with pytest.raises(RuntimeError, match="MULTI_TENANT"):
            validate_no_vector_db_settings()

View File

@@ -0,0 +1,196 @@
"""Tests for tool construction when DISABLE_VECTOR_DB is True.
Verifies that:
- SearchTool.is_available() returns False when vector DB is disabled
- OpenURLTool.is_available() returns False when vector DB is disabled
- The force-add SearchTool block is suppressed when DISABLE_VECTOR_DB
- FileReaderTool.is_available() returns True when vector DB is disabled
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
APP_CONFIGS_MODULE = "onyx.configs.app_configs"
FILE_READER_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
# ------------------------------------------------------------------
# SearchTool.is_available()
# ------------------------------------------------------------------
class TestSearchToolAvailability:
    """SearchTool.is_available() gating on DISABLE_VECTOR_DB."""

    @patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
    def test_unavailable_when_vector_db_disabled(self) -> None:
        # Import deferred so the patch is in effect when the attribute is read.
        from onyx.tools.tool_implementations.search.search_tool import SearchTool

        assert SearchTool.is_available(MagicMock()) is False

    # @patch decorators inject mock arguments bottom-up; the bottom
    # DISABLE_VECTOR_DB patch supplies an explicit value and injects no mock,
    # so the three mock parameters map to connectors/federated/user-files
    # in that order.
    @patch("onyx.db.connector.check_user_files_exist", return_value=True)
    @patch(
        "onyx.tools.tool_implementations.search.search_tool.check_federated_connectors_exist",
        return_value=False,
    )
    @patch(
        "onyx.tools.tool_implementations.search.search_tool.check_connectors_exist",
        return_value=False,
    )
    @patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
    def test_available_when_vector_db_enabled_and_files_exist(
        self,
        mock_connectors: MagicMock,  # noqa: ARG002
        mock_federated: MagicMock,  # noqa: ARG002
        mock_user_files: MagicMock,  # noqa: ARG002
    ) -> None:
        # Only user files exist (connectors patched to False) — that alone
        # is enough for the tool to be available when the vector DB is on.
        from onyx.tools.tool_implementations.search.search_tool import SearchTool

        assert SearchTool.is_available(MagicMock()) is True
# ------------------------------------------------------------------
# OpenURLTool.is_available()
# ------------------------------------------------------------------
class TestOpenURLToolAvailability:
    """OpenURLTool.is_available() mirrors the DISABLE_VECTOR_DB flag."""

    @patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", True)
    def test_unavailable_when_vector_db_disabled(self) -> None:
        # Import deferred so the patch is in effect when the attribute is read.
        from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool

        assert OpenURLTool.is_available(MagicMock()) is False

    @patch(f"{APP_CONFIGS_MODULE}.DISABLE_VECTOR_DB", False)
    def test_available_when_vector_db_enabled(self) -> None:
        from onyx.tools.tool_implementations.open_url.open_url_tool import OpenURLTool

        assert OpenURLTool.is_available(MagicMock()) is True
# ------------------------------------------------------------------
# FileReaderTool.is_available()
# ------------------------------------------------------------------
class TestFileReaderToolAvailability:
    """FileReaderTool availability is the inverse of the other tools:
    it exists precisely to replace search when the vector DB is disabled."""

    @patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", True)
    def test_available_when_vector_db_disabled(self) -> None:
        assert FileReaderTool.is_available(MagicMock()) is True

    @patch(f"{FILE_READER_MODULE}.DISABLE_VECTOR_DB", False)
    def test_unavailable_when_vector_db_enabled(self) -> None:
        assert FileReaderTool.is_available(MagicMock()) is False
# ------------------------------------------------------------------
# Force-add SearchTool suppression
# ------------------------------------------------------------------
class TestForceAddSearchToolGuard:
    """Source-level guard: forced SearchTool addition must be gated on
    DISABLE_VECTOR_DB."""

    def test_force_add_block_checks_disable_vector_db(self) -> None:
        """The force-add SearchTool block in construct_tools should include
        `not DISABLE_VECTOR_DB` so that forced search is also suppressed
        without a vector DB."""
        import inspect

        from onyx.tools.tool_constructor import construct_tools

        src = inspect.getsource(construct_tools)
        failure_message = (
            "construct_tools should reference DISABLE_VECTOR_DB "
            "to suppress force-adding SearchTool"
        )
        assert "DISABLE_VECTOR_DB" in src, failure_message
# ------------------------------------------------------------------
# Persona API — _validate_vector_db_knowledge
# ------------------------------------------------------------------
class TestValidateVectorDbKnowledge:
    """Persona knowledge validation under DISABLE_VECTOR_DB.

    When the vector DB is disabled, personas may only attach user files;
    document sets, hierarchy nodes, and individual documents must all be
    rejected with HTTP 400.

    Fixes over the previous version: the `__import__("pytest")` hack is
    replaced with a plain local `import pytest`, and the triplicated
    request-mock construction is factored into `_make_request`.
    """

    @staticmethod
    def _make_request(
        document_set_ids: list | None = None,
        hierarchy_node_ids: list | None = None,
        document_ids: list | None = None,
    ) -> MagicMock:
        """Build a persona-upsert request mock carrying the given knowledge refs."""
        request = MagicMock()
        request.document_set_ids = document_set_ids or []
        request.hierarchy_node_ids = hierarchy_node_ids or []
        request.document_ids = document_ids or []
        return request

    @patch(
        "onyx.server.features.persona.api.DISABLE_VECTOR_DB",
        True,
    )
    def test_rejects_document_set_ids(self) -> None:
        import pytest
        from fastapi import HTTPException

        from onyx.server.features.persona.api import _validate_vector_db_knowledge

        request = self._make_request(document_set_ids=[1])
        with pytest.raises(HTTPException) as exc_info:
            _validate_vector_db_knowledge(request)
        assert exc_info.value.status_code == 400
        assert "document sets" in exc_info.value.detail

    @patch(
        "onyx.server.features.persona.api.DISABLE_VECTOR_DB",
        True,
    )
    def test_rejects_hierarchy_node_ids(self) -> None:
        import pytest
        from fastapi import HTTPException

        from onyx.server.features.persona.api import _validate_vector_db_knowledge

        request = self._make_request(hierarchy_node_ids=[1])
        with pytest.raises(HTTPException) as exc_info:
            _validate_vector_db_knowledge(request)
        assert exc_info.value.status_code == 400
        assert "hierarchy nodes" in exc_info.value.detail

    @patch(
        "onyx.server.features.persona.api.DISABLE_VECTOR_DB",
        True,
    )
    def test_rejects_document_ids(self) -> None:
        import pytest
        from fastapi import HTTPException

        from onyx.server.features.persona.api import _validate_vector_db_knowledge

        request = self._make_request(document_ids=["doc-abc"])
        with pytest.raises(HTTPException) as exc_info:
            _validate_vector_db_knowledge(request)
        assert exc_info.value.status_code == 400
        assert "documents" in exc_info.value.detail

    @patch(
        "onyx.server.features.persona.api.DISABLE_VECTOR_DB",
        True,
    )
    def test_allows_user_files_only(self) -> None:
        # No document sets / hierarchy nodes / documents: must not raise.
        from onyx.server.features.persona.api import _validate_vector_db_knowledge

        _validate_vector_db_knowledge(self._make_request())

    @patch(
        "onyx.server.features.persona.api.DISABLE_VECTOR_DB",
        False,
    )
    def test_allows_everything_when_vector_db_enabled(self) -> None:
        # With the vector DB enabled, all knowledge types are permitted.
        from onyx.server.features.persona.api import _validate_vector_db_knowledge

        request = self._make_request(
            document_set_ids=[1, 2],
            hierarchy_node_ids=[3],
            document_ids=["doc-x"],
        )
        _validate_vector_db_knowledge(request)

View File

@@ -0,0 +1,237 @@
"""Tests for the FileReaderTool.
Verifies:
- Tool definition schema is well-formed
- File ID validation (allowlist, UUID format)
- Character range extraction and clamping
- Error handling for missing parameters and non-text files
- is_available() reflects DISABLE_VECTOR_DB
"""
from unittest.mock import MagicMock
from unittest.mock import patch
from uuid import uuid4
import pytest
from onyx.file_store.models import ChatFileType
from onyx.file_store.models import InMemoryChatFile
from onyx.server.query_and_chat.placement import Placement
from onyx.tools.models import ToolCallException
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FILE_ID_FIELD
from onyx.tools.tool_implementations.file_reader.file_reader_tool import FileReaderTool
from onyx.tools.tool_implementations.file_reader.file_reader_tool import MAX_NUM_CHARS
from onyx.tools.tool_implementations.file_reader.file_reader_tool import NUM_CHARS_FIELD
from onyx.tools.tool_implementations.file_reader.file_reader_tool import (
START_CHAR_FIELD,
)
TOOL_MODULE = "onyx.tools.tool_implementations.file_reader.file_reader_tool"
_PLACEMENT = Placement(turn_index=0)
def _make_tool(
    user_file_ids: list | None = None,
    chat_file_ids: list | None = None,
) -> FileReaderTool:
    """Construct a FileReaderTool wired to a mock emitter.

    The two allowlists restrict which file ids the tool may read;
    either may be omitted.
    """
    return FileReaderTool(
        tool_id=99,
        emitter=MagicMock(),
        user_file_ids=user_file_ids or [],
        chat_file_ids=chat_file_ids or [],
    )
def _text_file(content: str, filename: str = "test.txt") -> InMemoryChatFile:
    """Wrap *content* as an in-memory plain-text chat file."""
    encoded = content.encode("utf-8")
    return InMemoryChatFile(
        file_id="some-file-id",
        content=encoded,
        file_type=ChatFileType.PLAIN_TEXT,
        filename=filename,
    )
# ------------------------------------------------------------------
# Tool metadata
# ------------------------------------------------------------------
class TestToolMetadata:
    """Schema-level checks on the read_file tool definition."""

    def test_tool_name(self) -> None:
        assert _make_tool().name == "read_file"

    def test_tool_definition_schema(self) -> None:
        definition = _make_tool().tool_definition()
        assert definition["type"] == "function"

        function_spec = definition["function"]
        assert function_spec["name"] == "read_file"

        parameters = function_spec["parameters"]
        properties = parameters["properties"]
        for field in (FILE_ID_FIELD, START_CHAR_FIELD, NUM_CHARS_FIELD):
            assert field in properties
        # Only the file id is mandatory; the range fields are optional.
        assert parameters["required"] == [FILE_ID_FIELD]
# ------------------------------------------------------------------
# File ID validation
# ------------------------------------------------------------------
class TestFileIdValidation:
    """Allowlist and UUID-format checks in _validate_file_id."""

    def test_rejects_invalid_uuid(self) -> None:
        with pytest.raises(ToolCallException, match="Invalid file_id"):
            _make_tool()._validate_file_id("not-a-uuid")

    def test_rejects_file_not_in_allowlist(self) -> None:
        allowed_id = uuid4()
        tool = _make_tool(user_file_ids=[allowed_id])
        # A well-formed UUID that is not in either allowlist must be rejected.
        with pytest.raises(ToolCallException, match="not in available files"):
            tool._validate_file_id(str(uuid4()))

    def test_accepts_user_file_id(self) -> None:
        user_file_id = uuid4()
        tool = _make_tool(user_file_ids=[user_file_id])
        assert tool._validate_file_id(str(user_file_id)) == user_file_id

    def test_accepts_chat_file_id(self) -> None:
        chat_file_id = uuid4()
        tool = _make_tool(chat_file_ids=[chat_file_id])
        assert tool._validate_file_id(str(chat_file_id)) == chat_file_id
# ------------------------------------------------------------------
# run() — character range extraction
# ------------------------------------------------------------------
class TestRun:
    """run() behavior: range extraction, clamping, hints, and error paths."""

    @patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
    @patch(f"{TOOL_MODULE}.load_user_file")
    def test_returns_full_content_by_default(
        self,
        mock_load_user_file: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # With no start_char/num_chars, the whole file body is returned.
        uid = uuid4()
        content = "Hello, world!"
        mock_load_user_file.return_value = _text_file(content)
        mock_get_session.return_value.__enter__.return_value = MagicMock()
        tool = _make_tool(user_file_ids=[uid])
        resp = tool.run(
            placement=_PLACEMENT,
            override_kwargs=MagicMock(),
            **{FILE_ID_FIELD: str(uid)},
        )
        assert content in resp.llm_facing_response

    @patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
    @patch(f"{TOOL_MODULE}.load_user_file")
    def test_respects_start_char_and_num_chars(
        self,
        mock_load_user_file: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # start_char=4, num_chars=6 over "abcdefghijklmnop" -> "efghij".
        uid = uuid4()
        content = "abcdefghijklmnop"
        mock_load_user_file.return_value = _text_file(content)
        mock_get_session.return_value.__enter__.return_value = MagicMock()
        tool = _make_tool(user_file_ids=[uid])
        resp = tool.run(
            placement=_PLACEMENT,
            override_kwargs=MagicMock(),
            **{FILE_ID_FIELD: str(uid), START_CHAR_FIELD: 4, NUM_CHARS_FIELD: 6},
        )
        assert "efghij" in resp.llm_facing_response

    @patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
    @patch(f"{TOOL_MODULE}.load_user_file")
    def test_clamps_num_chars_to_max(
        self,
        mock_load_user_file: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Requesting more than MAX_NUM_CHARS is silently clamped to the cap.
        uid = uuid4()
        content = "x" * (MAX_NUM_CHARS + 500)
        mock_load_user_file.return_value = _text_file(content)
        mock_get_session.return_value.__enter__.return_value = MagicMock()
        tool = _make_tool(user_file_ids=[uid])
        resp = tool.run(
            placement=_PLACEMENT,
            override_kwargs=MagicMock(),
            **{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: MAX_NUM_CHARS + 9999},
        )
        assert f"Characters 0-{MAX_NUM_CHARS}" in resp.llm_facing_response

    @patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
    @patch(f"{TOOL_MODULE}.load_user_file")
    def test_includes_continuation_hint(
        self,
        mock_load_user_file: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # A truncated read tells the LLM where to resume.
        uid = uuid4()
        content = "x" * 100
        mock_load_user_file.return_value = _text_file(content)
        mock_get_session.return_value.__enter__.return_value = MagicMock()
        tool = _make_tool(user_file_ids=[uid])
        resp = tool.run(
            placement=_PLACEMENT,
            override_kwargs=MagicMock(),
            **{FILE_ID_FIELD: str(uid), NUM_CHARS_FIELD: 10},
        )
        assert "use start_char=10 to continue reading" in resp.llm_facing_response

    def test_raises_on_missing_file_id(self) -> None:
        # file_id is the only required argument; omitting it is an error.
        tool = _make_tool()
        with pytest.raises(ToolCallException, match="Missing required"):
            tool.run(
                placement=_PLACEMENT,
                override_kwargs=MagicMock(),
            )

    @patch(f"{TOOL_MODULE}.get_session_with_current_tenant")
    @patch(f"{TOOL_MODULE}.load_user_file")
    def test_raises_on_non_text_file(
        self,
        mock_load_user_file: MagicMock,
        mock_get_session: MagicMock,
    ) -> None:
        # Binary/image files cannot be read as text and must be rejected.
        uid = uuid4()
        mock_load_user_file.return_value = InMemoryChatFile(
            file_id="img",
            content=b"\x89PNG",
            file_type=ChatFileType.IMAGE,
            filename="photo.png",
        )
        mock_get_session.return_value.__enter__.return_value = MagicMock()
        tool = _make_tool(user_file_ids=[uid])
        with pytest.raises(ToolCallException, match="not a text file"):
            tool.run(
                placement=_PLACEMENT,
                override_kwargs=MagicMock(),
                **{FILE_ID_FIELD: str(uid)},
            )
# ------------------------------------------------------------------
# is_available()
# ------------------------------------------------------------------
class TestIsAvailable:
    """FileReaderTool is only offered when the vector DB is disabled."""

    @patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", True)
    def test_available_when_vector_db_disabled(self) -> None:
        assert FileReaderTool.is_available(MagicMock()) is True

    @patch(f"{TOOL_MODULE}.DISABLE_VECTOR_DB", False)
    def test_unavailable_when_vector_db_enabled(self) -> None:
        assert FileReaderTool.is_available(MagicMock()) is False

View File

@@ -40,8 +40,6 @@ const TRAY_MENU_OPEN_APP_ID: &str = "tray_open_app";
const TRAY_MENU_OPEN_CHAT_ID: &str = "tray_open_chat";
const TRAY_MENU_SHOW_IN_BAR_ID: &str = "tray_show_in_menu_bar";
const TRAY_MENU_QUIT_ID: &str = "tray_quit";
const MENU_SHOW_MENU_BAR_ID: &str = "show_menu_bar";
const MENU_HIDE_DECORATIONS_ID: &str = "hide_window_decorations";
const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
(() => {
if (window.__ONYX_CHAT_LINK_INTERCEPT_INSTALLED__) {
@@ -173,75 +171,25 @@ const CHAT_LINK_INTERCEPT_SCRIPT: &str = r##"
})();
"##;
#[cfg(not(target_os = "macos"))]
const MENU_KEY_HANDLER_SCRIPT: &str = r#"
(() => {
if (window.__ONYX_MENU_KEY_HANDLER__) return;
window.__ONYX_MENU_KEY_HANDLER__ = true;
let altHeld = false;
function invoke(cmd) {
const fn_ =
window.__TAURI__?.core?.invoke || window.__TAURI_INTERNALS__?.invoke;
if (typeof fn_ === 'function') fn_(cmd);
}
document.addEventListener('keydown', (e) => {
if (e.key === 'Alt') {
if (!altHeld) {
altHeld = true;
invoke('show_menu_bar_temporarily');
}
return;
}
if (e.altKey && e.key === 'F1') {
e.preventDefault();
e.stopPropagation();
altHeld = false;
invoke('toggle_menu_bar');
return;
}
}, true);
document.addEventListener('keyup', (e) => {
if (e.key === 'Alt' && altHeld) {
altHeld = false;
invoke('hide_menu_bar_temporary');
}
}, true);
})();
"#;
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AppConfig {
/// The Onyx server URL (default: https://cloud.onyx.app)
pub server_url: String,
/// Optional: Custom window title
#[serde(default = "default_window_title")]
pub window_title: String,
#[serde(default = "default_show_menu_bar")]
pub show_menu_bar: bool,
#[serde(default)]
pub hide_window_decorations: bool,
}
fn default_window_title() -> String {
"Onyx".to_string()
}
fn default_show_menu_bar() -> bool {
true
}
impl Default for AppConfig {
fn default() -> Self {
Self {
server_url: DEFAULT_SERVER_URL.to_string(),
window_title: default_window_title(),
show_menu_bar: true,
hide_window_decorations: false,
}
}
}
@@ -299,7 +247,6 @@ struct ConfigState {
config: RwLock<AppConfig>,
config_initialized: RwLock<bool>,
app_base_url: RwLock<Option<Url>>,
menu_temporarily_visible: RwLock<bool>,
}
fn focus_main_window(app: &AppHandle) {
@@ -354,7 +301,6 @@ fn trigger_new_window(app: &AppHandle) {
inject_titlebar(window.clone());
}
apply_settings_to_window(&handle, &window);
let _ = window.set_focus();
}
});
@@ -631,15 +577,18 @@ async fn new_window(app: AppHandle, state: tauri::State<'_, ConfigState>) -> Res
#[cfg(target_os = "linux")]
let builder = builder.background_color(tauri::window::Color(0x1a, 0x1a, 0x2e, 0xff));
let window = builder.build().map_err(|e| e.to_string())?;
#[cfg(target_os = "macos")]
{
let window = builder.build().map_err(|e| e.to_string())?;
// Apply vibrancy effect and inject titlebar
let _ = apply_vibrancy(&window, NSVisualEffectMaterial::Sidebar, None, None);
inject_titlebar(window.clone());
}
apply_settings_to_window(&app, &window);
#[cfg(not(target_os = "macos"))]
{
let _window = builder.build().map_err(|e| e.to_string())?;
}
Ok(())
}
@@ -675,155 +624,6 @@ async fn start_drag_window(window: tauri::Window) -> Result<(), String> {
window.start_dragging().map_err(|e| e.to_string())
}
// ============================================================================
// Window Settings
// ============================================================================
fn find_check_menu_item(
app: &AppHandle,
id: &str,
) -> Option<CheckMenuItem<tauri::Wry>> {
let menu = app.menu()?;
for item in menu.items().ok()? {
if let Some(submenu) = item.as_submenu() {
for sub_item in submenu.items().ok()? {
if let Some(check) = sub_item.as_check_menuitem() {
if check.id().as_ref() == id {
return Some(check.clone());
}
}
}
}
}
None
}
fn apply_settings_to_window(app: &AppHandle, window: &tauri::WebviewWindow) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
let config = state.config.read().unwrap();
if !config.show_menu_bar {
let _ = window.hide_menu();
}
if config.hide_window_decorations {
let _ = window.set_decorations(false);
}
}
fn handle_menu_bar_toggle(app: &AppHandle) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
let show = {
let mut config = state.config.write().unwrap();
config.show_menu_bar = !config.show_menu_bar;
let _ = save_config(&config);
config.show_menu_bar
};
*state.menu_temporarily_visible.write().unwrap() = false;
for (_, window) in app.webview_windows() {
if show {
let _ = window.show_menu();
} else {
let _ = window.hide_menu();
}
}
}
fn handle_decorations_toggle(app: &AppHandle) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
let hide = {
let mut config = state.config.write().unwrap();
config.hide_window_decorations = !config.hide_window_decorations;
let _ = save_config(&config);
config.hide_window_decorations
};
for (_, window) in app.webview_windows() {
let _ = window.set_decorations(!hide);
}
}
#[tauri::command]
fn toggle_menu_bar(app: AppHandle) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
let new_value = {
let mut config = state.config.write().unwrap();
config.show_menu_bar = !config.show_menu_bar;
let _ = save_config(&config);
config.show_menu_bar
};
*state.menu_temporarily_visible.write().unwrap() = false;
if let Some(check) = find_check_menu_item(&app, MENU_SHOW_MENU_BAR_ID) {
let _ = check.set_checked(new_value);
}
for (_, window) in app.webview_windows() {
if new_value {
let _ = window.show_menu();
} else {
let _ = window.hide_menu();
}
}
}
#[tauri::command]
fn show_menu_bar_temporarily(app: AppHandle) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
if state.config.read().unwrap().show_menu_bar {
return;
}
let mut temp = state.menu_temporarily_visible.write().unwrap();
if *temp {
return;
}
*temp = true;
drop(temp);
for (_, window) in app.webview_windows() {
let _ = window.show_menu();
}
}
#[tauri::command]
fn hide_menu_bar_temporary(app: AppHandle) {
if cfg!(target_os = "macos") {
return;
}
let state = app.state::<ConfigState>();
let mut temp = state.menu_temporarily_visible.write().unwrap();
if !*temp {
return;
}
*temp = false;
drop(temp);
if state.config.read().unwrap().show_menu_bar {
return;
}
for (_, window) in app.webview_windows() {
let _ = window.hide_menu();
}
}
// ============================================================================
// Menu Setup
// ============================================================================
@@ -867,59 +667,6 @@ fn setup_app_menu(app: &AppHandle) -> tauri::Result<()> {
menu.prepend(&file_menu)?;
}
#[cfg(not(target_os = "macos"))]
{
let config = app.state::<ConfigState>();
let config_guard = config.config.read().unwrap();
let show_menu_bar_item = CheckMenuItem::with_id(
app,
MENU_SHOW_MENU_BAR_ID,
"Show Menu Bar",
true,
config_guard.show_menu_bar,
None::<&str>,
)?;
let hide_decorations_item = CheckMenuItem::with_id(
app,
MENU_HIDE_DECORATIONS_ID,
"Hide Window Decorations",
true,
config_guard.hide_window_decorations,
None::<&str>,
)?;
drop(config_guard);
if let Some(window_menu) = menu
.items()?
.into_iter()
.filter_map(|item| item.as_submenu().cloned())
.find(|submenu| submenu.text().ok().as_deref() == Some("Window"))
{
window_menu.append(&show_menu_bar_item)?;
window_menu.append(&hide_decorations_item)?;
} else {
let window_menu = SubmenuBuilder::new(app, "Window")
.item(&show_menu_bar_item)
.item(&hide_decorations_item)
.build()?;
let items = menu.items()?;
let help_idx = items
.iter()
.position(|item| {
item.as_submenu()
.and_then(|s| s.text().ok())
.as_deref()
== Some("Help")
})
.unwrap_or(items.len());
menu.insert(&window_menu, help_idx)?;
}
}
if let Some(help_menu) = menu
.get(HELP_SUBMENU_ID)
.and_then(|item| item.as_submenu().cloned())
@@ -1054,7 +801,6 @@ fn main() {
config: RwLock::new(config),
config_initialized: RwLock::new(config_initialized),
app_base_url: RwLock::new(None),
menu_temporarily_visible: RwLock::new(false),
})
.invoke_handler(tauri::generate_handler![
get_server_url,
@@ -1070,18 +816,13 @@ fn main() {
go_forward,
new_window,
reset_config,
start_drag_window,
toggle_menu_bar,
show_menu_bar_temporarily,
hide_menu_bar_temporary
start_drag_window
])
.on_menu_event(|app, event| match event.id().as_ref() {
"open_docs" => open_docs(),
"new_chat" => trigger_new_chat(app),
"new_window" => trigger_new_window(app),
"open_settings" => open_settings(app),
"show_menu_bar" => handle_menu_bar_toggle(app),
"hide_window_decorations" => handle_decorations_toggle(app),
_ => {}
})
.setup(move |app| {
@@ -1114,8 +855,6 @@ fn main() {
#[cfg(target_os = "macos")]
inject_titlebar(window.clone());
apply_settings_to_window(&app_handle, &window);
let _ = window.set_focus();
}
@@ -1124,27 +863,7 @@ fn main() {
.on_page_load(|webview: &Webview, _payload: &PageLoadPayload| {
inject_chat_link_intercept(webview);
#[cfg(not(target_os = "macos"))]
{
let _ = webview.eval(MENU_KEY_HANDLER_SCRIPT);
let app = webview.app_handle();
let state = app.state::<ConfigState>();
let config = state.config.read().unwrap();
let temp_visible = *state.menu_temporarily_visible.read().unwrap();
let label = webview.label().to_string();
if !config.show_menu_bar && !temp_visible {
if let Some(win) = app.get_webview_window(&label) {
let _ = win.hide_menu();
}
}
if config.hide_window_decorations {
if let Some(win) = app.get_webview_window(&label) {
let _ = win.set_decorations(false);
}
}
}
// Re-inject titlebar after every navigation/page load (macOS only)
#[cfg(target_os = "macos")]
let _ = webview.eval(TITLEBAR_SCRIPT);
})

View File

@@ -1,6 +1,6 @@
/* Hoverable — item transitions */
.hoverable-item {
transition: opacity 200ms ease-in-out;
transition: opacity 150ms ease-in-out;
}
.hoverable-item[data-hoverable-variant="opacity-on-hover"] {

View File

@@ -11,7 +11,7 @@ import { DraggableTable } from "@/components/table/DraggableTable";
import {
deletePersona,
personaComparator,
togglePersonaDefault,
togglePersonaFeatured,
togglePersonaVisibility,
} from "./lib";
import { FiEdit2 } from "react-icons/fi";
@@ -27,8 +27,8 @@ function PersonaTypeDisplay({ persona }: { persona: Persona }) {
return <Text as="p">Built-In</Text>;
}
if (persona.is_default_persona) {
return <Text as="p">Default</Text>;
if (persona.featured) {
return <Text as="p">Featured</Text>;
}
if (persona.is_public) {
@@ -152,9 +152,9 @@ export function PersonasTable({
const handleToggleDefault = async () => {
if (personaToToggleDefault) {
const response = await togglePersonaDefault(
const response = await togglePersonaFeatured(
personaToToggleDefault.id,
personaToToggleDefault.is_default_persona
personaToToggleDefault.featured
);
if (response.ok) {
refreshPersonas();
@@ -180,7 +180,7 @@ export function PersonasTable({
{defaultModalOpen &&
personaToToggleDefault &&
(() => {
const isDefault = personaToToggleDefault.is_default_persona;
const isDefault = personaToToggleDefault.featured;
const title = isDefault
? "Remove Featured Agent"
@@ -252,7 +252,7 @@ export function PersonasTable({
</p>,
<PersonaTypeDisplay key={persona.id} persona={persona} />,
<div
key="is_default_persona"
key="featured"
onClick={() => {
openDefaultModal(persona);
}}
@@ -261,13 +261,13 @@ export function PersonasTable({
`}
>
<div className="my-auto flex-none w-22">
{!persona.is_default_persona ? (
{!persona.featured ? (
<div className="text-error">Not Featured</div>
) : (
"Featured"
)}
</div>
<Checkbox checked={persona.is_default_persona} />
<Checkbox checked={persona.featured} />
</div>,
<div
key="is_visible"

View File

@@ -53,7 +53,7 @@ export interface MinimalPersonaSnapshot {
is_public: boolean;
is_visible: boolean;
display_priority: number | null;
is_default_persona: boolean;
featured: boolean;
builtin_persona: boolean;
labels?: PersonaLabel[];
@@ -64,7 +64,6 @@ export interface Persona extends MinimalPersonaSnapshot {
user_file_ids: string[];
users: MinimalUserSnapshot[];
groups: number[];
num_chunks?: number;
// Hierarchy nodes (folders, spaces, channels) attached for scoped search
hierarchy_nodes?: HierarchyNodeSnapshot[];
// Individual documents attached for scoped search
@@ -79,8 +78,6 @@ export interface Persona extends MinimalPersonaSnapshot {
export interface FullPersona extends Persona {
search_start_date: string | null;
llm_relevance_filter?: boolean;
llm_filter_extraction?: boolean;
}
export interface PersonaLabel {

View File

@@ -11,11 +11,7 @@ interface PersonaUpsertRequest {
task_prompt: string;
datetime_aware: boolean;
document_set_ids: number[];
num_chunks: number | null;
is_public: boolean;
recency_bias: string;
llm_filter_extraction: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
@@ -26,7 +22,7 @@ interface PersonaUpsertRequest {
uploaded_image_id: string | null;
icon_name: string | null;
search_start_date: Date | null;
is_default_persona: boolean;
featured: boolean;
display_priority: number | null;
label_ids: number[] | null;
user_file_ids: string[] | null;
@@ -45,9 +41,7 @@ export interface PersonaUpsertParameters {
task_prompt: string;
datetime_aware: boolean;
document_set_ids: number[];
num_chunks: number | null;
is_public: boolean;
llm_relevance_filter: boolean | null;
llm_model_provider_override: string | null;
llm_model_version_override: string | null;
starter_messages: StarterMessage[] | null;
@@ -58,7 +52,7 @@ export interface PersonaUpsertParameters {
search_start_date: Date | null;
uploaded_image_id: string | null;
icon_name: string | null;
is_default_persona: boolean;
featured: boolean;
label_ids: number[] | null;
user_file_ids: string[];
// Hierarchy nodes (folders, spaces, channels) for scoped search
@@ -73,7 +67,6 @@ function buildPersonaUpsertRequest({
system_prompt,
task_prompt,
document_set_ids,
num_chunks,
is_public,
groups,
datetime_aware,
@@ -86,8 +79,7 @@ function buildPersonaUpsertRequest({
document_ids,
icon_name,
uploaded_image_id,
is_default_persona,
llm_relevance_filter,
featured,
llm_model_provider_override,
llm_model_version_override,
starter_messages,
@@ -100,7 +92,6 @@ function buildPersonaUpsertRequest({
system_prompt,
task_prompt,
document_set_ids,
num_chunks,
is_public,
uploaded_image_id,
icon_name,
@@ -110,10 +101,7 @@ function buildPersonaUpsertRequest({
remove_image,
search_start_date,
datetime_aware,
is_default_persona: is_default_persona ?? false,
recency_bias: "base_decay",
llm_filter_extraction: false,
llm_relevance_filter: llm_relevance_filter ?? null,
featured: featured ?? false,
llm_model_provider_override: llm_model_provider_override ?? null,
llm_model_version_override: llm_model_version_override ?? null,
starter_messages: starter_messages ?? null,
@@ -226,17 +214,17 @@ export function personaComparator(
return closerToZeroNegativesFirstComparator(a.id, b.id);
}
export async function togglePersonaDefault(
export async function togglePersonaFeatured(
personaId: number,
isDefault: boolean
featured: boolean
) {
const response = await fetch(`/api/admin/persona/${personaId}/default`, {
const response = await fetch(`/api/admin/persona/${personaId}/featured`, {
method: "PATCH",
headers: {
"Content-Type": "application/json",
},
body: JSON.stringify({
is_default_persona: !isDefault,
featured: !featured,
}),
credentials: "include",
});

View File

@@ -2,7 +2,7 @@
import { ModalCreationInterface } from "@/refresh-components/contexts/ModalContext";
import { ImageProvider } from "@/app/admin/configuration/image-generation/constants";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderView } from "@/interfaces/llm";
import { ImageGenerationConfigView } from "@/lib/configuration/imageConfigurationService";
import { getImageGenForm } from "./forms";

View File

@@ -7,7 +7,7 @@ import { Select } from "@/refresh-components/cards";
import { useCreateModal } from "@/refresh-components/contexts/ModalContext";
import { toast } from "@/hooks/useToast";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderResponse, LLMProviderView } from "@/interfaces/llm";
import {
IMAGE_PROVIDER_GROUPS,
ImageProvider,
@@ -23,13 +23,14 @@ import Message from "@/refresh-components/messages/Message";
export default function ImageGenerationContent() {
const {
data: llmProviders = [],
data: llmProviderResponse,
error: llmError,
mutate: refetchProviders,
} = useSWR<LLMProviderView[]>(
} = useSWR<LLMProviderResponse<LLMProviderView>>(
"/api/admin/llm/provider?include_image_gen=true",
errorHandlingFetcher
);
const llmProviders = llmProviderResponse?.providers ?? [];
const {
data: configs = [],

View File

@@ -1,6 +1,6 @@
import { FormikProps } from "formik";
import { ImageProvider } from "../constants";
import { LLMProviderView } from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderView } from "@/interfaces/llm";
import {
ImageGenerationConfigView,
ImageGenerationCredentials,

View File

@@ -1,84 +0,0 @@
"use client";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useSWR from "swr";
import { Callout } from "@/components/ui/callout";
import Text from "@/refresh-components/texts/Text";
import Title from "@/components/ui/title";
import { ThreeDotsLoader } from "@/components/Loading";
import { LLMProviderView } from "./interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "./constants";
import { OpenAIForm } from "./forms/OpenAIForm";
import { AnthropicForm } from "./forms/AnthropicForm";
import { OllamaForm } from "./forms/OllamaForm";
import { AzureForm } from "./forms/AzureForm";
import { BedrockForm } from "./forms/BedrockForm";
import { VertexAIForm } from "./forms/VertexAIForm";
import { OpenRouterForm } from "./forms/OpenRouterForm";
import { getFormForExistingProvider } from "./forms/getForm";
import { CustomForm } from "./forms/CustomForm";
// Admin panel section for LLM setup: lists every configured provider
// (default provider sorted to the top) and renders an "add provider" form
// for each supported provider type below the list.
export function LLMConfiguration() {
  const { data: existingLlmProviders } = useSWR<LLMProviderView[]>(
    LLM_PROVIDERS_ADMIN_URL,
    errorHandlingFetcher
  );
  // SWR has not resolved yet — show a loader rather than an empty state.
  if (!existingLlmProviders) {
    return <ThreeDotsLoader />;
  }
  // When no provider exists yet, the first one created below should be
  // marked as the org-wide default.
  const isFirstProvider = existingLlmProviders.length === 0;
  return (
    <>
      <Title className="mb-2">Enabled LLM Providers</Title>
      {existingLlmProviders.length > 0 ? (
        <>
          <Text as="p" className="mb-4">
            If multiple LLM providers are enabled, the default provider will be
            used for all &quot;Default&quot; Assistants. For user-created
            Assistants, you can select the LLM provider/model that best fits the
            use case!
          </Text>
          <div className="flex flex-col gap-y-4">
            {/* Copy before sorting so the SWR cache array is not mutated;
                the default provider floats to the top, others keep order. */}
            {[...existingLlmProviders]
              .sort((a, b) => {
                if (a.is_default_provider && !b.is_default_provider) return -1;
                if (!a.is_default_provider && b.is_default_provider) return 1;
                return 0;
              })
              .map((llmProvider) => (
                <div key={llmProvider.id}>
                  {getFormForExistingProvider(llmProvider)}
                </div>
              ))}
          </div>
        </>
      ) : (
        <Callout type="warning" title="No LLM providers configured yet">
          Please set one up below in order to start using Onyx!
        </Callout>
      )}
      <Title className="mb-2 mt-6">Add LLM Provider</Title>
      <Text as="p" className="mb-4">
        Add a new LLM provider by either selecting from one of the default
        providers or by specifying your own custom LLM provider.
      </Text>
      <div className="flex flex-col gap-y-4">
        <OpenAIForm shouldMarkAsDefault={isFirstProvider} />
        <AnthropicForm shouldMarkAsDefault={isFirstProvider} />
        <OllamaForm shouldMarkAsDefault={isFirstProvider} />
        <AzureForm shouldMarkAsDefault={isFirstProvider} />
        <BedrockForm shouldMarkAsDefault={isFirstProvider} />
        <VertexAIForm shouldMarkAsDefault={isFirstProvider} />
        <OpenRouterForm shouldMarkAsDefault={isFirstProvider} />
        <CustomForm shouldMarkAsDefault={isFirstProvider} />
      </div>
    </>
  );
}

View File

@@ -1,13 +1,14 @@
"use client";
import { ArrayHelpers, FieldArray, FormikProps, useField } from "formik";
import { ModelConfiguration } from "./interfaces";
import { ModelConfiguration } from "@/interfaces/llm";
import { ManualErrorMessage, TextFormField } from "@/components/Field";
import { useEffect, useState } from "react";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { Button } from "@opal/components";
import { SvgX } from "@opal/icons";
import Text from "@/refresh-components/texts/Text";
function ModelConfigurationRow({
name,
index,

View File

@@ -1,44 +0,0 @@
import { LLMProviderName, LLMProviderView } from "../interfaces";
import { AnthropicForm } from "./AnthropicForm";
import { OpenAIForm } from "./OpenAIForm";
import { OllamaForm } from "./OllamaForm";
import { AzureForm } from "./AzureForm";
import { VertexAIForm } from "./VertexAIForm";
import { OpenRouterForm } from "./OpenRouterForm";
import { CustomForm } from "./CustomForm";
import { BedrockForm } from "./BedrockForm";
// A provider counts as "real" OpenAI only when it points at the stock
// endpoint: an API key is present, no custom base URL is set, and no
// extra custom-config entries exist. Anything else is an OpenAI-compatible
// proxy and should use the custom form instead.
export function detectIfRealOpenAIProvider(provider: LLMProviderView) {
  if (provider.provider !== LLMProviderName.OPENAI) {
    return false;
  }
  const customConfigKeys = Object.keys(provider.custom_config || {});
  // Truthy-chain preserved: a falsy api_key short-circuits the result.
  return (
    provider.api_key && !provider.api_base && customConfigKeys.length === 0
  );
}
export const getFormForExistingProvider = (provider: LLMProviderView) => {
switch (provider.provider) {
case LLMProviderName.OPENAI:
// "openai" as a provider name can be used for litellm proxy / any OpenAI-compatible provider
if (detectIfRealOpenAIProvider(provider)) {
return <OpenAIForm existingLlmProvider={provider} />;
} else {
return <CustomForm existingLlmProvider={provider} />;
}
case LLMProviderName.ANTHROPIC:
return <AnthropicForm existingLlmProvider={provider} />;
case LLMProviderName.OLLAMA_CHAT:
return <OllamaForm existingLlmProvider={provider} />;
case LLMProviderName.AZURE:
return <AzureForm existingLlmProvider={provider} />;
case LLMProviderName.VERTEX_AI:
return <VertexAIForm existingLlmProvider={provider} />;
case LLMProviderName.BEDROCK:
return <BedrockForm existingLlmProvider={provider} />;
case LLMProviderName.OPENROUTER:
return <OpenRouterForm existingLlmProvider={provider} />;
default:
return <CustomForm existingLlmProvider={provider} />;
}
};

View File

@@ -1,14 +1,7 @@
"use client";
import { AdminPageTitle } from "@/components/admin/Title";
import { LLMConfiguration } from "./LLMConfiguration";
import { SvgCpu } from "@opal/icons";
export default function Page() {
return (
<>
<AdminPageTitle title="LLM Setup" icon={SvgCpu} />
import LLMConfigurationPage from "@/refresh-pages/admin/LLMConfigurationPage";
<LLMConfiguration />
</>
);
export default function Page() {
return <LLMConfigurationPage />;
}

View File

@@ -24,7 +24,7 @@ import {
BedrockFetchParams,
OllamaFetchParams,
OpenRouterFetchParams,
} from "./interfaces";
} from "@/interfaces/llm";
import { SvgAws, SvgOpenrouter } from "@opal/icons";
// Aggregator providers that host models from multiple vendors
@@ -106,8 +106,9 @@ export const getProviderIcon = (
return CPUIcon;
};
export const isAnthropic = (provider: string, modelName: string) =>
provider === "anthropic" || modelName.toLowerCase().includes("claude");
export const isAnthropic = (provider: string, modelName?: string) =>
provider === LLMProviderName.ANTHROPIC ||
!!modelName?.toLowerCase().includes("claude");
/**
* Fetches Bedrock models directly without any form state dependencies.
@@ -153,6 +154,7 @@ export const fetchBedrockModels = async (
is_visible: false,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: false,
}));
return { models };
@@ -205,6 +207,7 @@ export const fetchOllamaModels = async (
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: false,
}));
return { models };
@@ -262,6 +265,7 @@ export const fetchOpenRouterModels = async (
is_visible: true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
supports_reasoning: false,
}));
return { models };

View File

@@ -25,7 +25,7 @@ import { ModelOption } from "@/components/embedding/ModelSelector";
import {
EMBEDDING_MODELS_ADMIN_URL,
EMBEDDING_PROVIDERS_ADMIN_URL,
} from "@/app/admin/configuration/llm/constants";
} from "@/lib/llmConfig/constants";
import { AdvancedSearchConfiguration } from "@/app/admin/embeddings/interfaces";
import Button from "@/refresh-components/buttons/Button";

View File

@@ -14,7 +14,7 @@ import {
import {
EMBEDDING_PROVIDERS_ADMIN_URL,
LLM_PROVIDERS_ADMIN_URL,
} from "@/app/admin/configuration/llm/constants";
} from "@/lib/llmConfig/constants";
import { mutate } from "swr";
import { testEmbedding } from "@/app/admin/embeddings/pages/utils";
import { SvgSettings } from "@opal/icons";

View File

@@ -11,7 +11,7 @@ import {
EmbeddingProvider,
getFormattedProviderName,
} from "@/components/embedding/interfaces";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import { EMBEDDING_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import Modal from "@/refresh-components/Modal";
import { SvgSettings } from "@opal/icons";
export interface ProviderCreationModalProps {

View File

@@ -15,7 +15,7 @@ import {
import NumberInput from "../../connectors/[connector]/pages/ConnectorInput/NumberInput";
import { StringOrNumberOption } from "@/components/Dropdown";
import useSWR from "swr";
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "../../configuration/llm/constants";
import { LLM_CONTEXTUAL_COST_ADMIN_URL } from "@/lib/llmConfig/constants";
import { errorHandlingFetcher } from "@/lib/fetcher";
import Button from "@/refresh-components/buttons/Button";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";

View File

@@ -74,10 +74,10 @@ export default function ProjectChatSessionList() {
<div className="flex gap-3 min-w-0 w-full">
<div className="flex h-full w-fit pt-1 pl-1">
{(() => {
const personaIdToDefault =
currentProjectDetails?.persona_id_to_is_default || {};
const isDefault = personaIdToDefault[chat.persona_id];
if (isDefault === false) {
const personaIdToFeatured =
currentProjectDetails?.persona_id_to_featured || {};
const isFeatured = personaIdToFeatured[chat.persona_id];
if (isFeatured === false) {
const assistant = assistants.find(
(a) => a.id === chat.persona_id
);

View File

@@ -18,7 +18,7 @@ import SourceTag from "@/refresh-components/buttons/source-tag/SourceTag";
import { citationsToSourceInfoArray } from "@/refresh-components/buttons/source-tag/sourceTagUtils";
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
import LLMPopover from "@/refresh-components/popovers/LLMPopover";
import { parseLlmDescriptor } from "@/lib/llm/utils";
import { parseLlmDescriptor } from "@/lib/llmConfig/utils";
import { LlmManager } from "@/lib/hooks";
import { Message } from "@/app/app/interfaces";
import { SvgThumbsDown, SvgThumbsUp } from "@opal/icons";

View File

@@ -59,7 +59,7 @@ export enum UserFileStatus {
export type ProjectDetails = {
project: Project;
files?: ProjectFile[];
persona_id_to_is_default?: Record<number, boolean>;
persona_id_to_featured?: Record<number, boolean>;
};
export async function fetchProjects(): Promise<Project[]> {

View File

@@ -20,7 +20,7 @@ export function constructMiniFiedPersona(name: string, id: number): Persona {
owner: null,
starter_messages: null,
builtin_persona: false,
is_default_persona: false,
featured: false,
users: [],
groups: [],
user_file_ids: [],

View File

@@ -11,7 +11,7 @@ import Text from "@/refresh-components/texts/Text";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import Switch from "@/refresh-components/inputs/Switch";
import LineItem from "@/refresh-components/buttons/LineItem";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import {
BuildLlmSelection,
BUILD_MODE_PROVIDERS,

View File

@@ -1,5 +1,5 @@
import { useMemo, useState, useCallback } from "react";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import {
BuildLlmSelection,
getBuildLlmSelection,

View File

@@ -7,7 +7,7 @@ import { usePreProvisionPolling } from "@/app/craft/hooks/usePreProvisionPolling
import { CRAFT_SEARCH_PARAM_NAMES } from "@/app/craft/services/searchParams";
import { CRAFT_PATH } from "@/app/craft/v1/constants";
import { getBuildUserPersona } from "@/app/craft/onboarding/constants";
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
import { useLLMProviders } from "@/hooks/useLLMProviders";
import { checkPreProvisionedSession } from "@/app/craft/services/apiServices";
interface UseBuildSessionControllerProps {

View File

@@ -18,8 +18,8 @@ import {
getBuildLlmSelection,
getDefaultLlmSelection,
} from "@/app/craft/onboarding/constants";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { LLM_PROVIDERS_ADMIN_URL } from "@/app/admin/configuration/llm/constants";
import { LLMProviderDescriptor } from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
import {
buildInitialValues,
testApiKeyHelper,

View File

@@ -5,10 +5,7 @@ import { cn } from "@/lib/utils";
import { Disabled } from "@/refresh-components/Disabled";
import Text from "@/refresh-components/texts/Text";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import {
LLMProviderName,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderName, LLMProviderDescriptor } from "@/interfaces/llm";
// Provider configurations
export type ProviderKey = "anthropic" | "openai" | "openrouter";

View File

@@ -19,13 +19,12 @@ const LLM_SELECTION_PRIORITY = [
interface MinimalLlmProvider {
name: string;
provider: string;
default_model_name: string;
is_default_provider: boolean | null;
model_configurations: { name: string; is_visible: boolean }[];
}
/**
* Get the best default LLM selection based on available providers.
* Priority: Anthropic > OpenAI > OpenRouter > system default > first available
* Priority: Anthropic > OpenAI > OpenRouter > first available
*/
export function getDefaultLlmSelection(
llmProviders: MinimalLlmProvider[] | undefined
@@ -44,23 +43,16 @@ export function getDefaultLlmSelection(
}
}
// Fallback: use the default provider's default model
const defaultProvider = llmProviders.find((p) => p.is_default_provider);
if (defaultProvider) {
return {
providerName: defaultProvider.name,
provider: defaultProvider.provider,
modelName: defaultProvider.default_model_name,
};
}
// Final fallback: first available provider
// Fallback: first available provider, use its first visible model
const firstProvider = llmProviders[0];
if (firstProvider) {
const firstModel = firstProvider.model_configurations.find(
(m) => m.is_visible
);
return {
providerName: firstProvider.name,
provider: firstProvider.provider,
modelName: firstProvider.default_model_name,
modelName: firstModel?.name ?? "",
};
}

View File

@@ -2,8 +2,8 @@
import { useCallback, useState, useMemo, useEffect } from "react";
import { useUser } from "@/providers/UserProvider";
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
import { LLMProviderName } from "@/app/admin/configuration/llm/interfaces";
import { useLLMProviders } from "@/hooks/useLLMProviders";
import { LLMProviderName } from "@/interfaces/llm";
import {
OnboardingModalMode,
OnboardingModalController,
@@ -18,9 +18,7 @@ import { useBuildSessionStore } from "@/app/craft/hooks/useBuildSessionStore";
// Check if all 3 build mode providers are configured (anthropic, openai, openrouter)
function checkAllProvidersConfigured(
llmProviders:
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined
llmProviders: import("@/interfaces/llm").LLMProviderDescriptor[] | undefined
): boolean {
if (!llmProviders || llmProviders.length === 0) {
return false;
@@ -35,9 +33,7 @@ function checkAllProvidersConfigured(
// Check if at least one provider is configured
function checkHasAnyProvider(
llmProviders:
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined
llmProviders: import("@/interfaces/llm").LLMProviderDescriptor[] | undefined
): boolean {
return !!(llmProviders && llmProviders.length > 0);
}

View File

@@ -1,4 +1,8 @@
import { WorkArea, Level } from "./constants";
import type {
LLMProviderDescriptor,
LLMProviderResponse,
} from "@/interfaces/llm";
export interface BuildUserInfo {
firstName: string;
@@ -33,9 +37,7 @@ export interface OnboardingModalController {
close: () => void;
// Data needed for modal
llmProviders:
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined;
llmProviders: LLMProviderDescriptor[] | undefined;
initialValues: {
firstName: string;
lastName: string;
@@ -54,7 +56,6 @@ export interface OnboardingModalController {
completeUserInfo: (info: BuildUserInfo) => Promise<void>;
completeLlmSetup: () => Promise<void>;
refetchLlmProviders: () => Promise<
| import("@/app/admin/configuration/llm/interfaces").LLMProviderDescriptor[]
| undefined
LLMProviderResponse<LLMProviderDescriptor> | undefined
>;
}

View File

@@ -46,7 +46,7 @@ import Switch from "@/refresh-components/inputs/Switch";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import NotAllowedModal from "@/app/craft/onboarding/components/NotAllowedModal";
import { useOnboarding } from "@/app/craft/onboarding/BuildOnboardingProvider";
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
import { useLLMProviders } from "@/hooks/useLLMProviders";
import { useUser } from "@/providers/UserProvider";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import {

View File

@@ -150,7 +150,8 @@ export default async function RootLayout({
// middleware returns 402 for all non-allowlisted API calls, preventing data
// leakage. The user sees a brief loading state before being redirected.
const content =
productGating === ApplicationStatus.GATED_ACCESS ? (
productGating === ApplicationStatus.GATED_ACCESS ||
productGating === ApplicationStatus.SEAT_LIMIT_EXCEEDED ? (
<GatedContentWrapper>{children}</GatedContentWrapper>
) : (
children

View File

@@ -27,6 +27,7 @@ const SETTINGS_LAYOUT_PREFIXES = [
"/admin/document-index-migration",
"/admin/discord-bot",
"/admin/theme",
"/admin/configuration/llm",
];
export function ClientLayout({

View File

@@ -2,7 +2,7 @@
import {
WellKnownLLMProviderDescriptor,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
} from "@/interfaces/llm";
import React, {
createContext,
useContext,
@@ -11,9 +11,9 @@ import React, {
useCallback,
} from "react";
import { useUser } from "@/providers/UserProvider";
import { useLLMProviders } from "@/lib/hooks/useLLMProviders";
import { useLLMProviders } from "@/hooks/useLLMProviders";
import { useLLMProviderOptions } from "@/lib/hooks/useLLMProviderOptions";
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llm/svc";
import { testDefaultProvider as testDefaultProviderSvc } from "@/lib/llmConfig/svc";
interface ProviderContextType {
shouldShowConfigurationNeeded: boolean;

View File

@@ -9,10 +9,12 @@ import { logout } from "@/lib/user";
import { loadStripe } from "@stripe/stripe-js";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
import { useLicense } from "@/hooks/useLicense";
import { useSettingsContext } from "@/providers/SettingsProvider";
import { ApplicationStatus } from "@/interfaces/settings";
import Text from "@/refresh-components/texts/Text";
import { SvgLock } from "@opal/icons";
const linkClassName = "text-action-link-05 hover:text-action-link-06";
const linkClassName = "text-action-link-05 hover:text-action-link-06 underline";
const fetchStripePublishableKey = async (): Promise<string> => {
const response = await fetch("/api/tenants/stripe-publishable-key");
@@ -40,15 +42,30 @@ export default function AccessRestricted() {
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
const { data: license } = useLicense();
const settings = useSettingsContext();
const isSeatLimitExceeded =
settings.settings.application_status ===
ApplicationStatus.SEAT_LIMIT_EXCEEDED;
const hadPreviousLicense = license?.has_license === true;
const showRenewalMessage = NEXT_PUBLIC_CLOUD_ENABLED || hadPreviousLicense;
const initialModalMessage = showRenewalMessage
? NEXT_PUBLIC_CLOUD_ENABLED
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
: "Your access to Onyx has been temporarily suspended due to a lapse in your license."
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated.";
function getSeatLimitMessage() {
const { used_seats, seat_count } = settings.settings;
const counts =
used_seats != null && seat_count != null
? ` (${used_seats} users / ${seat_count} seats)`
: "";
return `Your organization has exceeded its licensed seat count${counts}. Access is restricted until the number of users is reduced or your license is upgraded.`;
}
const initialModalMessage = isSeatLimitExceeded
? getSeatLimitMessage()
: showRenewalMessage
? NEXT_PUBLIC_CLOUD_ENABLED
? "Your access to Onyx has been temporarily suspended due to a lapse in your subscription."
: "Your access to Onyx has been temporarily suspended due to a lapse in your license."
: "An Enterprise license is required to use Onyx. Your data is protected and will be available once a license is activated.";
const handleResubscribe = async () => {
setIsLoading(true);
@@ -80,7 +97,32 @@ export default function AccessRestricted() {
<Text text03>{initialModalMessage}</Text>
{NEXT_PUBLIC_CLOUD_ENABLED ? (
{isSeatLimitExceeded ? (
<>
<Text text03>
If you are an administrator, you can manage users on the{" "}
<Link className={linkClassName} href="/admin/users">
User Management
</Link>{" "}
page or upgrade your license on the{" "}
<Link className={linkClassName} href="/admin/billing">
Admin Billing
</Link>{" "}
page.
</Text>
<div className="flex flex-row gap-2">
<Button
onClick={async () => {
await logout();
window.location.reload();
}}
>
Log out
</Button>
</div>
</>
) : NEXT_PUBLIC_CLOUD_ENABLED ? (
<>
<Text text03>
To reinstate your access and continue benefiting from Onyx&apos;s
@@ -127,7 +169,7 @@ export default function AccessRestricted() {
sign up through Stripe or reach out to{" "}
<a className={linkClassName} href="mailto:support@onyx.app">
support@onyx.app
</a>
</a>{" "}
for billing assistance.
</Text>

View File

@@ -1,8 +1,8 @@
"use client";
import { useMemo } from "react";
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { parseLlmDescriptor, structureValue } from "@/lib/llmConfig/utils";
import { DefaultModel, LLMProviderDescriptor } from "@/interfaces/llm";
import { getProviderIcon } from "@/app/admin/configuration/llm/utils";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import { createIcon } from "@/components/icons/icons";
@@ -23,6 +23,7 @@ export interface LLMSelectorProps {
name?: string;
userSettings?: boolean;
llmProviders: LLMProviderDescriptor[];
defaultText?: DefaultModel | null;
currentLlm: string | null;
onSelect: (value: string | null) => void;
requiresImageGeneration?: boolean;
@@ -33,6 +34,7 @@ export default function LLMSelector({
name,
userSettings,
llmProviders,
defaultText,
currentLlm,
onSelect,
requiresImageGeneration,
@@ -139,11 +141,11 @@ export default function LLMSelector({
});
}, [llmOptions]);
const defaultProvider = llmProviders.find(
(llmProvider) => llmProvider.is_default_provider
);
const defaultProvider = defaultText
? llmProviders.find((p) => p.id === defaultText.provider_id)
: undefined;
const defaultModelName = defaultProvider?.default_model_name;
const defaultModelName = defaultText?.model_name;
const defaultModelConfig = defaultProvider?.model_configurations.find(
(m) => m.name === defaultModelName
);

View File

@@ -109,13 +109,11 @@ export function usePinnedAgents() {
const serverPinnedAgents = useMemo(() => {
if (agents.length === 0) return [];
// If pinned_assistants is null/undefined (never set), show default personas
// If pinned_assistants is null/undefined (never set), show featured personas
// If it's an empty array (user explicitly unpinned all), show nothing
const pinnedIds = user?.preferences.pinned_assistants;
if (pinnedIds === null || pinnedIds === undefined) {
return agents.filter(
(agent) => agent.is_default_persona && agent.id !== 0
);
return agents.filter((agent) => agent.featured && agent.id !== 0);
}
return pinnedIds

View File

@@ -42,7 +42,7 @@ import {
getFinalLLM,
modelSupportsImageInput,
structureValue,
} from "@/lib/llm/utils";
} from "@/lib/llmConfig/utils";
import {
CurrentMessageFIFO,
updateCurrentMessageFIFO,

View File

@@ -0,0 +1,150 @@
"use client";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import {
LLMProviderDescriptor,
LLMProviderResponse,
LLMProviderView,
WellKnownLLMProviderDescriptor,
} from "@/interfaces/llm";
import { LLM_PROVIDERS_ADMIN_URL } from "@/lib/llmConfig/constants";
/**
* Fetches configured LLM providers accessible to the current user.
*
* Hits the **non-admin** endpoints which return `LLMProviderDescriptor`
* (no `id` or sensitive fields like `api_key`). Use this hook in
* user-facing UI (chat, popovers, onboarding) where you need the list
* of providers and their visible models but don't need admin-level details.
*
* The backend wraps the provider list in an `LLMProviderResponse` envelope
* that also carries the global default text and vision models. This hook
* unwraps `.providers` for convenience while still exposing the defaults.
*
* **Endpoints:**
* - No `personaId` → `GET /api/llm/provider`
* Returns all public providers plus restricted providers the user can
* access via group membership.
* - With `personaId` → `GET /api/llm/persona/{personaId}/providers`
* Returns providers scoped to a specific persona, respecting RBAC
* restrictions. Use this when displaying model options for a particular
* assistant.
*
* @param personaId - Optional persona ID for RBAC-scoped providers.
*
* @returns
* - `llmProviders` — The array of provider descriptors, or `undefined`
* while loading.
* - `defaultText` — The global (or persona-overridden) default text model.
* - `defaultVision` — The global (or persona-overridden) default vision model.
* - `isLoading` — `true` until the first successful response or error.
* - `error` — The SWR error object, if any.
* - `refetch` — SWR `mutate` function to trigger a revalidation.
*/
export function useLLMProviders(personaId?: number) {
  // Persona-scoped endpoint when a personaId is given, otherwise the
  // general user-facing provider listing.
  const endpoint =
    personaId === undefined
      ? "/api/llm/provider"
      : `/api/llm/persona/${personaId}/providers`;

  const swr = useSWR<LLMProviderResponse<LLMProviderDescriptor>>(
    endpoint,
    errorHandlingFetcher,
    {
      revalidateOnFocus: false,
      dedupingInterval: 60000,
    }
  );

  // Unwrap the response envelope; surface the default models alongside
  // the provider list for convenience.
  return {
    llmProviders: swr.data?.providers,
    defaultText: swr.data?.default_text ?? null,
    defaultVision: swr.data?.default_vision ?? null,
    isLoading: !swr.error && !swr.data,
    error: swr.error,
    refetch: swr.mutate,
  };
}
/**
* Fetches configured LLM providers via the **admin** endpoint.
*
* Hits `GET /api/admin/llm/provider` which returns `LLMProviderView` —
* the full provider object including `id`, `api_key` (masked),
* group/persona assignments, and all other admin-visible fields.
*
* Use this hook on admin pages (e.g. the LLM Configuration page) where
* you need provider IDs for mutations (setting defaults, editing, deleting)
* or need to display admin-only metadata. **Do not use in user-facing UI**
* — use `useLLMProviders` instead.
*
* @returns
* - `llmProviders` — The array of full provider views, or `undefined`
* while loading.
* - `defaultText` — The global default text model.
* - `defaultVision` — The global default vision model.
* - `isLoading` — `true` until the first successful response or error.
* - `error` — The SWR error object, if any.
* - `refetch` — SWR `mutate` function to trigger a revalidation.
*/
export function useAdminLLMProviders() {
  // Admin endpoint: returns full LLMProviderView objects rather than the
  // slimmer user-facing descriptors.
  const response = useSWR<LLMProviderResponse<LLMProviderView>>(
    LLM_PROVIDERS_ADMIN_URL,
    errorHandlingFetcher,
    {
      revalidateOnFocus: false,
      dedupingInterval: 60000,
    }
  );

  const { data, error, mutate } = response;

  // Unwrap the LLMProviderResponse envelope for callers.
  return {
    llmProviders: data?.providers,
    defaultText: data?.default_text ?? null,
    defaultVision: data?.default_vision ?? null,
    isLoading: !error && !data,
    error,
    refetch: mutate,
  };
}
/**
* Fetches the catalog of well-known (built-in) LLM providers.
*
* Hits `GET /api/admin/llm/built-in/options` which returns the static
* list of provider descriptors that Onyx ships with out of the box
* (OpenAI, Anthropic, Vertex AI, Bedrock, Azure, Ollama, OpenRouter,
* etc.). Each descriptor includes the provider's known models and the
* recommended default model.
*
* Used primarily on the LLM Configuration page and onboarding flows
* to show which providers are available to set up, and to pre-populate
* model lists before the user has entered credentials.
*
* @returns
* - `wellKnownLLMProviders` — The array of built-in provider descriptors,
* or `null` while loading.
* - `isLoading` — `true` until the first successful response or error.
* - `error` — The SWR error object, if any.
* - `mutate` — SWR `mutate` function to trigger a revalidation.
*/
export function useWellKnownLLMProviders() {
  const swrState = useSWR<WellKnownLLMProviderDescriptor[]>(
    "/api/admin/llm/built-in/options",
    errorHandlingFetcher,
    {
      revalidateOnFocus: false,
      dedupingInterval: 60000,
    }
  );

  // Normalize "not loaded yet" to null so callers can branch on it directly.
  return {
    wellKnownLLMProviders: swrState.data ?? null,
    isLoading: swrState.isLoading,
    error: swrState.error,
    mutate: swrState.mutate,
  };
}

View File

@@ -13,8 +13,8 @@ export interface ModelConfiguration {
name: string;
is_visible: boolean;
max_input_tokens: number | null;
supports_image_input: boolean | null;
supports_reasoning?: boolean;
supports_image_input: boolean;
supports_reasoning: boolean;
display_name?: string;
provider_display_name?: string;
vendor?: string;
@@ -30,7 +30,6 @@ export interface SimpleKnownModel {
export interface WellKnownLLMProviderDescriptor {
name: string;
known_models: ModelConfiguration[];
recommended_default_model: SimpleKnownModel | null;
}
@@ -40,44 +39,31 @@ export interface LLMModelDescriptor {
maxTokens: number;
}
export interface LLMProvider {
export interface LLMProviderView {
id: number;
name: string;
provider: string;
api_key: string | null;
api_base: string | null;
api_version: string | null;
custom_config: { [key: string]: string } | null;
default_model_name: string;
is_public: boolean;
is_auto_mode: boolean;
groups: number[];
personas: number[];
deployment_name: string | null;
default_vision_model: string | null;
is_default_vision_provider: boolean | null;
model_configurations: ModelConfiguration[];
}
export interface LLMProviderView extends LLMProvider {
id: number;
is_default_provider: boolean | null;
}
export interface VisionProvider extends LLMProviderView {
vision_models: string[];
}
export interface LLMProviderDescriptor {
id: number;
name: string;
provider: string;
provider_display_name?: string;
default_model_name: string;
is_default_provider: boolean | null;
is_default_vision_provider?: boolean | null;
default_vision_model?: string | null;
is_public?: boolean;
groups?: number[];
personas?: number[];
provider_display_name: string;
model_configurations: ModelConfiguration[];
}
@@ -102,9 +88,22 @@ export interface BedrockModelResponse {
supports_image_input: boolean;
}
export interface DefaultModel {
provider_id: number;
model_name: string;
}
export interface LLMProviderResponse<T> {
providers: T[];
default_text: DefaultModel | null;
default_vision: DefaultModel | null;
}
export interface LLMProviderFormProps {
existingLlmProvider?: LLMProviderView;
shouldMarkAsDefault?: boolean;
open?: boolean;
onOpenChange?: (open: boolean) => void;
}
// Param types for model fetching functions - use snake_case to match API structure

View File

@@ -2,6 +2,7 @@ export enum ApplicationStatus {
PAYMENT_REMINDER = "payment_reminder",
GATED_ACCESS = "gated_access",
ACTIVE = "active",
SEAT_LIMIT_EXCEEDED = "seat_limit_exceeded",
}
export enum QueryHistoryType {
@@ -49,6 +50,10 @@ export interface Settings {
// True when user has a valid license, False for community edition
ee_features_enabled?: boolean;
// Seat usage - populated when seat limit is exceeded
seat_count?: number | null;
used_seats?: number | null;
// OpenSearch migration
opensearch_indexing_enabled?: boolean;

View File

@@ -211,10 +211,10 @@ export async function updateAgentFeaturedStatus(
isFeatured: boolean
): Promise<string | null> {
try {
const response = await fetch(`/api/admin/persona/${agentId}/default`, {
const response = await fetch(`/api/admin/persona/${agentId}/featured`, {
method: "PATCH",
headers: { "Content-Type": "application/json" },
body: JSON.stringify({ is_default_persona: isFeatured }),
body: JSON.stringify({ featured: isFeatured }),
});
if (response.ok) {

View File

@@ -18,7 +18,8 @@ export type ApplicationStatus =
| "active"
| "payment_reminder"
| "gated_access"
| "expired";
| "expired"
| "seat_limit_exceeded";
/**
* Billing status from Stripe subscription.

Some files were not shown because too many files have changed in this diff Show More