Compare commits

..

13 Commits

Author SHA1 Message Date
Jamison Lahman
55b24d72b4 fix(fe): redirect to status page after deleting connector (#9620) 2026-03-25 17:24:41 +00:00
Raunak Bhagat
3321a84c7d fix(sidebar): fix icon alignment for user-avatar-popover (#9615)
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2026-03-25 17:07:50 +00:00
SubashMohan
54bf32a5f8 fix: use persisted source functions when toggling search tool (#9548) 2026-03-25 16:50:25 +00:00
Nikolas Garza
4bb6b76be6 feat(groups): switchover to /admin/groups and rewrite e2e tests (#9545) 2026-03-25 08:11:13 +00:00
SubashMohan
db94562474 feat: Group-based permissions — Phase 1 schema (AccountType, Permission, PermissionGrant) (#9547) 2026-03-25 06:24:43 +00:00
Nikolas Garza
582d4642c1 feat(metrics): add task lifecycle and per-connector Prometheus metrics (#9602) 2026-03-25 06:02:43 +00:00
Nikolas Garza
3caaecdb0e feat(groups): polish edit page table and delete UX (#9544) 2026-03-25 04:57:50 +00:00
Nikolas Garza
039b69806b feat(metrics): add queue depth and connector health Prometheus collectors (#9590) 2026-03-25 03:53:26 +00:00
Evan Lohn
63971d4958 fix: confluence client retries (#9605) 2026-03-25 03:32:29 +00:00
Nikolas Garza
ffd897f380 feat(metrics): add reusable Prometheus metrics server for celery workers (#9589) 2026-03-25 01:47:06 +00:00
Evan Lohn
4745069232 fix: no more lazy queries per search call (#9578) 2026-03-25 01:38:35 +00:00
Nikolas Garza
386782f188 feat(groups): add edit group page (#9543) 2026-03-25 01:22:57 +00:00
Raunak Bhagat
ff009c4129 fix: Fix tag widths (#9618) 2026-03-25 01:18:51 +00:00
90 changed files with 4549 additions and 2464 deletions

12
.vscode/launch.json vendored
View File

@@ -117,7 +117,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "API Server Console"
"consoleTitle": "API Server Console",
"justMyCode": false
},
{
"name": "Slack Bot",
@@ -268,7 +269,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery heavy Console"
"consoleTitle": "Celery heavy Console",
"justMyCode": false
},
{
"name": "Celery kg_processing",
@@ -355,7 +357,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery user_file_processing Console"
"consoleTitle": "Celery user_file_processing Console",
"justMyCode": false
},
{
"name": "Celery docfetching",
@@ -413,7 +416,8 @@
"presentation": {
"group": "2"
},
"consoleTitle": "Celery docprocessing Console"
"consoleTitle": "Celery docprocessing Console",
"justMyCode": false
},
{
"name": "Celery beat",

View File

@@ -0,0 +1,109 @@
"""group_permissions_phase1
Revision ID: 25a5501dc766
Revises: b728689f45b1
Create Date: 2026-03-23 11:41:25.557442
"""
from alembic import op
import fastapi_users_db_sqlalchemy
import sqlalchemy as sa
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
# revision identifiers, used by Alembic.
revision = "25a5501dc766"
down_revision = "b728689f45b1"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Apply the Phase-1 group-permissions schema.

    Adds user.account_type, user_group.is_default, the permission_grant
    table, and a supporting index on user__user_group(user_id).
    """
    # 1. Add account_type column to user table (nullable for now).
    # TODO(subash): backfill account_type for existing rows and add NOT NULL.
    op.add_column(
        "user",
        sa.Column(
            "account_type",
            # native_enum=False stores the enum as VARCHAR, so adding new
            # members later does not require an ALTER TYPE.
            sa.Enum(AccountType, native_enum=False),
            nullable=True,
        ),
    )
    # 2. Add is_default column to user_group table
    op.add_column(
        "user_group",
        sa.Column(
            "is_default",
            sa.Boolean(),
            nullable=False,
            # server_default so existing rows get a value without a backfill.
            server_default=sa.false(),
        ),
    )
    # 3. Create permission_grant table
    op.create_table(
        "permission_grant",
        sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
        sa.Column("group_id", sa.Integer(), nullable=False),
        sa.Column(
            "permission",
            sa.Enum(Permission, native_enum=False),
            nullable=False,
        ),
        sa.Column(
            "grant_source",
            sa.Enum(GrantSource, native_enum=False),
            nullable=False,
        ),
        sa.Column(
            "granted_by",
            # GUID matches the fastapi-users UUID primary key on "user".
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column(
            "granted_at",
            sa.DateTime(timezone=True),
            server_default=sa.func.now(),
            nullable=False,
        ),
        sa.Column(
            "is_deleted",
            sa.Boolean(),
            nullable=False,
            server_default=sa.false(),
        ),
        sa.PrimaryKeyConstraint("id"),
        # Grants disappear with their group...
        sa.ForeignKeyConstraint(
            ["group_id"],
            ["user_group.id"],
            ondelete="CASCADE",
        ),
        # ...but survive deletion of the granting user (granter becomes NULL).
        sa.ForeignKeyConstraint(
            ["granted_by"],
            ["user.id"],
            ondelete="SET NULL",
        ),
        # At most one grant row per (group, permission) pair.
        sa.UniqueConstraint(
            "group_id", "permission", name="uq_permission_grant_group_permission"
        ),
    )
    # 4. Index on user__user_group(user_id) — existing composite PK
    # has user_group_id as leading column; user-filtered queries need this
    op.create_index(
        "ix_user__user_group_user_id",
        "user__user_group",
        ["user_id"],
    )
def downgrade() -> None:
    """Revert the Phase-1 group-permissions schema.

    Drops objects in reverse creation order so FK dependencies are
    released before the referenced columns/tables are removed.
    """
    op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
    op.drop_table("permission_grant")
    op.drop_column("user_group", "is_default")
    op.drop_column("user", "account_type")

View File

@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
ordered: bool = True,
get_editable: bool = True,
) -> Sequence[TokenRateLimit]:
stmt = select(TokenRateLimit)
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
stmt = (
select(TokenRateLimit)
.join(
TokenRateLimit__UserGroup,
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
)
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
)
stmt = _add_user_filters(stmt, user, get_editable)
if enabled_only:

View File

@@ -56,7 +56,7 @@ def _run_single_search(
chunk_search_request=chunk_search_request,
document_index=document_index,
user=user,
persona=None, # No persona for direct search
persona_search_info=None,
db_session=db_session,
)

View File

@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -34,6 +42,8 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -48,6 +58,36 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docfetching")
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
import onyx.background.celery.apps.app_base as app_base
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
from onyx.db.engine.sql_engine import SqlEngine
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
from onyx.server.metrics.metrics_server import start_metrics_server
from onyx.utils.logger import setup_logger
from shared_configs.configs import MULTI_TENANT
@@ -35,6 +43,8 @@ def on_task_prerun(
**kwds: Any,
) -> None:
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
on_celery_task_prerun(task_id, task)
on_indexing_task_prerun(task_id, task, kwargs)
@signals.task_postrun.connect
@@ -49,6 +59,36 @@ def on_task_postrun(
**kwds: Any,
) -> None:
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
on_celery_task_postrun(task_id, task, state)
on_indexing_task_postrun(task_id, task, kwargs, state)
@signals.task_retry.connect
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_retry signal doesn't pass task_id in kwargs; get it from
# the sender (the task instance) via sender.request.id.
task_id = getattr(getattr(sender, "request", None), "id", None)
on_celery_task_retry(task_id, sender)
@signals.task_revoked.connect
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
task_name = getattr(sender, "name", None) or str(sender)
on_celery_task_revoked(kwargs.get("task_id"), task_name)
@signals.task_rejected.connect
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
# task_rejected sends the Consumer as sender, not the task instance.
# The task name must be extracted from the Celery message headers.
message = kwargs.get("message")
task_name: str | None = None
if message is not None:
headers = getattr(message, "headers", None) or {}
task_name = headers.get("task")
if task_name is None:
task_name = "unknown"
on_celery_task_rejected(None, task_name)
@celeryd_init.connect
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
start_metrics_server("docprocessing")
app_base.on_worker_ready(sender, **kwargs)
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
app_base.on_worker_shutdown(sender, **kwargs)
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
# effectively a no-op in normal operation. It remains as a safety net in case
# the pool type is ever changed to prefork. Prometheus metrics are safe in
# thread-pool mode since all threads share the same process memory and can
# update the same Counter/Gauge/Histogram objects directly.
@worker_process_init.connect
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
SqlEngine.reset_engine()

View File

@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
app_base.on_celeryd_init(sender, conf, **kwargs)
# Set by on_worker_init so on_worker_ready knows whether to start the server.
_prometheus_collectors_ok: bool = False
@worker_init.connect
def on_worker_init(sender: Any, **kwargs: Any) -> None:
global _prometheus_collectors_ok
logger.info("worker_init signal received.")
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.wait_for_redis(sender, **kwargs)
app_base.wait_for_db(sender, **kwargs)
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
# Less startup checks in multi-tenant case
if MULTI_TENANT:
return
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
app_base.on_secondary_worker_init(sender, **kwargs)
def _setup_prometheus_collectors(sender: Any) -> bool:
"""Register Prometheus collectors that need Redis/DB access.
Passes the Celery app so the queue depth collector can obtain a fresh
broker Redis client on each scrape (rather than holding a stale reference).
Returns True if registration succeeded, False otherwise.
"""
try:
from onyx.server.metrics.indexing_pipeline_setup import (
setup_indexing_pipeline_metrics,
)
setup_indexing_pipeline_metrics(sender.app)
logger.info("Prometheus indexing pipeline collectors registered")
return True
except Exception:
logger.exception("Failed to register Prometheus indexing pipeline collectors")
return False
@worker_ready.connect
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
if _prometheus_collectors_ok:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring")
else:
logger.warning(
"Skipping Prometheus metrics server — collector registration failed"
)
app_base.on_worker_ready(sender, **kwargs)

View File

@@ -123,7 +123,7 @@ class OnyxConfluence:
self.shared_base_kwargs: dict[str, str | int | bool] = {
"api_version": "cloud" if is_cloud else "latest",
"backoff_and_retry": True,
"backoff_and_retry": False,
"cloud": is_cloud,
}
if timeout:
@@ -456,7 +456,7 @@ class OnyxConfluence:
return attr(*args, **kwargs)
except HTTPError as e:
delay_until = _handle_http_error(e, attempt)
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)

View File

@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
# and applying our own retries in a more specific set of circumstances
return confluence_call(*args, **kwargs)
except requests.HTTPError as e:
delay_until = _handle_http_error(e, attempt)
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
logger.warning(
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
)
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
return cast(F, wrapped_call)
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
MIN_DELAY = 2
MAX_DELAY = 60
STARTING_DELAY = 5
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
raise e
if e.response.status_code >= 500:
if attempt >= max_retries - 1:
raise e
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
logger.warning(
f"Server error {e.response.status_code}. "
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
)
return math.ceil(time.monotonic() + delay)
if (
e.response.status_code != 429
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()

View File

@@ -401,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
section in addition to the match_highlights."""
content: str
class PersonaSearchInfo(BaseModel):
    """Snapshot of persona data needed by the search pipeline.

    Extracted from the ORM Persona before the DB session is released so that
    SearchTool and search_pipeline never lazy-load relationships post-commit.
    """

    # Names of the persona's attached document sets.
    document_set_names: list[str]
    # Earliest document date the persona allows; None means no cutoff.
    search_start_date: datetime | None
    # IDs of documents explicitly attached to the persona.
    attached_document_ids: list[str]
    # IDs of hierarchy nodes attached to the persona.
    hierarchy_node_ids: list[int]

View File

@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.preprocessing.access_filters import (
build_access_filters_for_user,
)
from onyx.context.search.retrieval.search_runner import search_chunks
from onyx.context.search.utils import inference_section_from_chunks
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
@@ -247,8 +247,8 @@ def search_pipeline(
document_index: DocumentIndex,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Used for default filters and settings
persona: Persona | None,
# Pre-extracted persona search configuration (None when no persona)
persona_search_info: PersonaSearchInfo | None,
db_session: Session | None = None,
auto_detect_filters: bool = False,
llm: LLM | None = None,
@@ -263,24 +263,18 @@ def search_pipeline(
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
) -> list[InferenceChunk]:
persona_document_sets: list[str] | None = (
[persona_document_set.name for persona_document_set in persona.document_sets]
if persona
else None
persona_search_info.document_set_names if persona_search_info else None
)
persona_time_cutoff: datetime | None = (
persona.search_start_date if persona else None
persona_search_info.search_start_date if persona_search_info else None
)
# Extract assistant knowledge filters from persona
attached_document_ids: list[str] | None = (
[doc.id for doc in persona.attached_documents]
if persona and persona.attached_documents
persona_search_info.attached_document_ids or None
if persona_search_info
else None
)
hierarchy_node_ids: list[int] | None = (
[node.id for node in persona.hierarchy_nodes]
if persona and persona.hierarchy_nodes
else None
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
)
filters = _build_index_filters(

View File

@@ -64,6 +64,9 @@ def get_chat_session_by_id(
joinedload(ChatSession.persona).options(
selectinload(Persona.tools),
selectinload(Persona.user_files),
selectinload(Persona.document_sets),
selectinload(Persona.attached_documents),
selectinload(Persona.hierarchy_nodes),
),
joinedload(ChatSession.project),
)

View File

@@ -750,3 +750,31 @@ def resync_cc_pair(
)
db_session.commit()
# ── Metrics query helpers ──────────────────────────────────────────────
def get_connector_health_for_metrics(
    db_session: Session,
) -> list:  # Returns list of Row tuples
    """Return connector health data for Prometheus metrics.

    Each row is (cc_pair_id, status, in_repeated_error_state,
    last_successful_index_time, name, source).
    """
    # Column order here defines the Row tuple layout consumed by the
    # metrics collector — do not reorder.
    health_query = db_session.query(
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.status,
        ConnectorCredentialPair.in_repeated_error_state,
        ConnectorCredentialPair.last_successful_index_time,
        ConnectorCredentialPair.name,
        Connector.source,
    )
    health_query = health_query.join(
        Connector,
        ConnectorCredentialPair.connector_id == Connector.id,
    )
    return health_query.all()

View File

@@ -1,4 +1,31 @@
from __future__ import annotations
from enum import Enum as PyEnum
from typing import ClassVar
class AccountType(str, PyEnum):
    """
    What kind of account this is — determines whether the user
    enters the group-based permission system.

    STANDARD + SERVICE_ACCOUNT → participate in group system
    BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
    """

    # Regular human user — full participant in the group system.
    STANDARD = "standard"
    # Automation account (e.g. Slack/Discord bot identity).
    BOT = "bot"
    # User mirrored in for external-permission sync only.
    EXT_PERM_USER = "ext_perm_user"
    # Non-human API account — also participates in the group system.
    SERVICE_ACCOUNT = "service_account"
    # Unauthenticated access.
    ANONYMOUS = "anonymous"
class GrantSource(str, PyEnum):
    """How a permission grant was created."""

    # Granted manually by a user through the UI/API.
    USER = "user"
    # Provisioned via SCIM synchronization.
    SCIM = "scim"
    # Created automatically by the system (e.g. defaults for new groups).
    SYSTEM = "system"
class IndexingStatus(str, PyEnum):
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
class HookFailStrategy(str, PyEnum):
HARD = "hard" # exception propagates, pipeline aborts
SOFT = "soft" # log error, return original input, pipeline continues
class Permission(str, PyEnum):
    """
    Permission tokens for group-based authorization.

    19 tokens total. full_admin_panel_access is an override —
    if present, any permission check passes.
    """

    # Basic (auto-granted to every new group)
    BASIC_ACCESS = "basic"

    # Read tokens — implied only, never granted directly
    # (see Permission.IMPLIED below; PermissionGrant rejects them).
    READ_CONNECTORS = "read:connectors"
    READ_DOCUMENT_SETS = "read:document_sets"
    READ_AGENTS = "read:agents"
    READ_USERS = "read:users"

    # Add / Manage pairs
    ADD_AGENTS = "add:agents"
    MANAGE_AGENTS = "manage:agents"
    MANAGE_DOCUMENT_SETS = "manage:document_sets"
    ADD_CONNECTORS = "add:connectors"
    MANAGE_CONNECTORS = "manage:connectors"
    MANAGE_LLMS = "manage:llms"

    # Toggle tokens
    READ_AGENT_ANALYTICS = "read:agent_analytics"
    MANAGE_ACTIONS = "manage:actions"
    READ_QUERY_HISTORY = "read:query_history"
    MANAGE_USER_GROUPS = "manage:user_groups"
    CREATE_USER_API_KEYS = "create:user_api_keys"
    CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
    CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"

    # Override — any permission check passes
    FULL_ADMIN_PANEL_ACCESS = "admin"

    # Permissions that are implied by other grants and must never be stored
    # directly in the permission_grant table.
    # Annotation only — a plain assignment here would become an enum member,
    # so the value is attached after class creation below.
    IMPLIED: ClassVar[frozenset[Permission]]


# Assigned post-class-creation: EnumMeta only blocks reassignment of
# members, so attaching a non-member ClassVar this way is safe.
Permission.IMPLIED = frozenset(
    {
        Permission.READ_CONNECTORS,
        Permission.READ_DOCUMENT_SETS,
        Permission.READ_AGENTS,
        Permission.READ_USERS,
    }
)

View File

@@ -2,6 +2,8 @@ from collections.abc import Sequence
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import NamedTuple
from typing import TYPE_CHECKING
from typing import TypeVarTuple
from sqlalchemy import and_
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
if TYPE_CHECKING:
from onyx.configs.constants import DocumentSource
# from sqlalchemy.sql.selectable import Select
# Comment out unused imports that cause mypy errors
@@ -972,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
stmt = stmt.offset(page * page_size).limit(page_size)
return list(db_session.scalars(stmt).all())
# ── Metrics query helpers ──────────────────────────────────────────────
class ActiveIndexAttemptMetric(NamedTuple):
    """Row returned by get_active_index_attempts_for_metrics."""

    # Current (non-terminal) status of the attempt group.
    status: IndexingStatus
    # Source type of the underlying connector.
    source: "DocumentSource"
    # ConnectorCredentialPair.id the attempts belong to.
    cc_pair_id: int
    # Display name of the cc pair (nullable in the DB).
    cc_pair_name: str | None
    # Number of attempts in this (status, source, cc_pair) bucket.
    attempt_count: int
def get_active_index_attempts_for_metrics(
    db_session: Session,
) -> list[ActiveIndexAttemptMetric]:
    """Return non-terminal index attempts grouped by status, source, and connector.

    Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
    """
    # Local import to avoid a circular dependency at module import time.
    from onyx.db.models import Connector

    terminal_statuses = [
        status for status in IndexingStatus if status.is_terminal()
    ]

    # Same columns are selected and grouped on; count() fills the last slot.
    grouping_columns = (
        IndexAttempt.status,
        Connector.source,
        ConnectorCredentialPair.id,
        ConnectorCredentialPair.name,
    )
    active_query = (
        db_session.query(*grouping_columns, func.count())
        .join(
            ConnectorCredentialPair,
            IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
        )
        .join(
            Connector,
            ConnectorCredentialPair.connector_id == Connector.id,
        )
        .filter(IndexAttempt.status.notin_(terminal_statuses))
        .group_by(*grouping_columns)
    )
    return [ActiveIndexAttemptMetric(*row) for row in active_query.all()]
def get_failed_attempt_counts_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Return {cc_pair_id: failed_attempt_count} for all connectors.

    When ``since`` is provided, only attempts created after that timestamp
    are counted. Defaults to the last 90 days to avoid unbounded historical
    aggregation.
    """
    cutoff = (
        since
        if since is not None
        else datetime.now(timezone.utc) - timedelta(days=90)
    )
    failed_counts_query = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            func.count(),
        )
        .filter(IndexAttempt.status == IndexingStatus.FAILED)
        .filter(IndexAttempt.time_created >= cutoff)
        .group_by(IndexAttempt.connector_credential_pair_id)
    )
    # Rows are (cc_pair_id, count) pairs, so dict() builds the mapping directly.
    return dict(failed_counts_query.all())
def get_docs_indexed_by_cc_pair(
    db_session: Session,
    since: datetime | None = None,
) -> dict[int, int]:
    """Return {cc_pair_id: total_new_docs_indexed} across successful attempts.

    Only counts attempts with status SUCCESS to avoid inflating counts with
    partial results from failed attempts. When ``since`` is provided, only
    attempts created after that timestamp are included.
    """
    cutoff = (
        since
        if since is not None
        else datetime.now(timezone.utc) - timedelta(days=90)
    )
    indexed_rows = (
        db_session.query(
            IndexAttempt.connector_credential_pair_id,
            # coalesce: new_docs_indexed may be NULL on individual attempts.
            func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
        )
        .filter(IndexAttempt.status == IndexingStatus.SUCCESS)
        .filter(IndexAttempt.time_created >= cutoff)
        .group_by(IndexAttempt.connector_credential_pair_id)
        .all()
    )
    totals: dict[int, int] = {}
    for cc_pair_id, doc_total in indexed_rows:
        # SUM may come back as Decimal/None depending on the backend.
        totals[cc_pair_id] = int(doc_total or 0)
    return totals

View File

@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
from sqlalchemy.types import TypeDecorator
from sqlalchemy import PrimaryKeyConstraint
from onyx.db.enums import AccountType
from onyx.auth.schemas import UserRole
from onyx.configs.constants import (
ANONYMOUS_USER_UUID,
@@ -78,6 +79,8 @@ from onyx.db.enums import (
MCPAuthenticationPerformer,
MCPTransport,
MCPServerStatus,
Permission,
GrantSource,
LLMModelFlowType,
ThemePreference,
DefaultAppMode,
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
)
"""
Preferences probably should be in a separate table at some point, but for now
@@ -3971,6 +3977,8 @@ class SamlAccount(Base):
class User__UserGroup(Base):
__tablename__ = "user__user_group"
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
user_group_id: Mapped[int] = mapped_column(
@@ -3981,6 +3989,48 @@ class User__UserGroup(Base):
)
class PermissionGrant(Base):
    """A permission granted to a user group.

    Soft-deleted via is_deleted rather than removed, preserving an audit
    trail of who granted what and when.
    """

    __tablename__ = "permission_grant"
    __table_args__ = (
        # At most one grant row per (group, permission) pair.
        UniqueConstraint(
            "group_id", "permission", name="uq_permission_grant_group_permission"
        ),
    )

    id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
    # Grants are removed with their group (ON DELETE CASCADE).
    group_id: Mapped[int] = mapped_column(
        ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
    )
    permission: Mapped[Permission] = mapped_column(
        Enum(Permission, native_enum=False), nullable=False
    )
    grant_source: Mapped[GrantSource] = mapped_column(
        Enum(GrantSource, native_enum=False), nullable=False
    )
    # User who created the grant; nulled out if that user is deleted.
    granted_by: Mapped[UUID | None] = mapped_column(
        ForeignKey("user.id", ondelete="SET NULL"), nullable=True
    )
    granted_at: Mapped[datetime.datetime] = mapped_column(
        DateTime(timezone=True), server_default=func.now(), nullable=False
    )
    # Soft-delete flag; server_default keeps DB-side inserts consistent.
    is_deleted: Mapped[bool] = mapped_column(
        Boolean, nullable=False, default=False, server_default=text("false")
    )

    group: Mapped["UserGroup"] = relationship(
        "UserGroup", back_populates="permission_grants"
    )

    @validates("permission")
    def _validate_permission(self, _key: str, value: Permission) -> Permission:
        # Implied read tokens are derived from other grants and must never
        # be persisted directly (see Permission.IMPLIED).
        if value in Permission.IMPLIED:
            raise ValueError(
                f"{value!r} is an implied permission and cannot be granted directly"
            )
        return value
class UserGroup__ConnectorCredentialPair(Base):
__tablename__ = "user_group__connector_credential_pair"
@@ -4075,6 +4125,8 @@ class UserGroup(Base):
is_up_for_deletion: Mapped[bool] = mapped_column(
Boolean, nullable=False, default=False
)
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
# Last time a user updated this user group
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
@@ -4118,6 +4170,9 @@ class UserGroup(Base):
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
)
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
)
"""Tables related to Token Rate Limiting

View File

@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
def get_default_behavior_persona(db_session: Session) -> Persona | None:
def get_default_behavior_persona(
    db_session: Session,
    eager_load_for_tools: bool = False,
) -> Persona | None:
    """Fetch the default-behavior persona (DEFAULT_PERSONA_ID), or None.

    When eager_load_for_tools is True, the relationships the search path
    reads (tools, document sets, attached documents, hierarchy nodes) are
    loaded up front so no lazy loads happen after the session is released.
    """
    query = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
    if eager_load_for_tools:
        eager_options = (
            selectinload(Persona.tools),
            selectinload(Persona.document_sets),
            selectinload(Persona.attached_documents),
            selectinload(Persona.hierarchy_nodes),
        )
        query = query.options(*eager_options)
    return db_session.scalars(query).first()

View File

@@ -91,8 +91,6 @@ class HookResponse(BaseModel):
# Nullable to match the DB column — endpoint_url is required on creation but
# future hook point types may not use an external endpoint (e.g. built-in handlers).
endpoint_url: str | None
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
api_key_masked: str | None
fail_strategy: HookFailStrategy
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
is_active: bool

View File

@@ -62,9 +62,6 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
name=hook.name,
hook_point=hook.hook_point,
endpoint_url=hook.endpoint_url,
api_key_masked=(
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
),
fail_strategy=hook.fail_strategy,
timeout_seconds=hook.timeout_seconds,
is_active=hook.is_active,

View File

@@ -0,0 +1,207 @@
"""Generic Celery task lifecycle Prometheus metrics.
Provides signal handlers that track task started/completed/failed counts,
active task gauge, task duration histograms, and retry/reject/revoke counts.
These fire for ALL tasks on the worker — no per-connector enrichment
(see indexing_task_metrics.py for that).
Usage in a worker app module:
from onyx.server.metrics.celery_task_metrics import (
on_celery_task_prerun,
on_celery_task_postrun,
on_celery_task_retry,
on_celery_task_revoked,
on_celery_task_rejected,
)
# Call from the worker's existing signal handlers
"""
import threading
import time
from celery import Task
from prometheus_client import Counter
from prometheus_client import Gauge
from prometheus_client import Histogram
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Lifecycle counters/gauges/histograms for all tasks on this worker.
TASK_STARTED = Counter(
    "onyx_celery_task_started_total",
    "Total Celery tasks started",
    ["task_name", "queue"],
)
TASK_COMPLETED = Counter(
    "onyx_celery_task_completed_total",
    "Total Celery tasks completed",
    ["task_name", "queue", "outcome"],
)
TASK_DURATION = Histogram(
    "onyx_celery_task_duration_seconds",
    "Celery task execution duration in seconds",
    ["task_name", "queue"],
    # Buckets span 1s to 1h — sized for background/indexing workloads.
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)
TASKS_ACTIVE = Gauge(
    "onyx_celery_tasks_active",
    "Currently executing Celery tasks",
    ["task_name", "queue"],
)
TASK_RETRIED = Counter(
    "onyx_celery_task_retried_total",
    "Total Celery tasks retried",
    ["task_name", "queue"],
)
# Revoked/rejected signals don't reliably carry queue info, so these two
# are labeled by task name only.
TASK_REVOKED = Counter(
    "onyx_celery_task_revoked_total",
    "Total Celery tasks revoked (cancelled)",
    ["task_name"],
)
TASK_REJECTED = Counter(
    "onyx_celery_task_rejected_total",
    "Total Celery tasks rejected by worker",
    ["task_name"],
)

# task_id → (monotonic start time, metric labels)
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
# Lock protecting _task_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_task_start_times_lock = threading.Lock()
# Entries older than this are evicted on each prerun to prevent unbounded
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
_MAX_START_TIME_AGE_SECONDS = 3600  # 1 hour
def _evict_stale_start_times() -> None:
    """Drop _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.

    Caller must already hold _task_start_times_lock.
    """
    cutoff = time.monotonic() - _MAX_START_TIME_AGE_SECONDS
    expired = [
        task_id
        for task_id, (started_at, _) in _task_start_times.items()
        if started_at < cutoff
    ]
    for task_id in expired:
        removed = _task_start_times.pop(task_id, None)
        if removed is None:
            continue
        # These tasks started but never reported completion (killed, OOM,
        # etc.) — bring the active gauge back down for their label set.
        gauge = TASKS_ACTIVE.labels(**removed[1])
        if gauge._value.get() > 0:
            gauge.dec()
def _get_task_labels(task: Task) -> dict[str, str]:
    """Build the {task_name, queue} label dict for a Celery Task instance."""
    queue = "unknown"
    try:
        info = task.request.delivery_info
        if info:
            queue = info.get("routing_key") or "unknown"
    except AttributeError:
        # request / delivery_info may be missing outside a real task context.
        pass
    return {"task_name": task.name or "unknown", "queue": queue}
def on_celery_task_prerun(
    task_id: str | None,
    task: Task | None,
) -> None:
    """Record task start. Call from the worker's task_prerun signal handler."""
    if task_id is None or task is None:
        return
    try:
        metric_labels = _get_task_labels(task)
        TASK_STARTED.labels(**metric_labels).inc()
        TASKS_ACTIVE.labels(**metric_labels).inc()
        with _task_start_times_lock:
            # Opportunistically clean out entries left behind by tasks that
            # died without a matching postrun, then record this start.
            _evict_stale_start_times()
            _task_start_times[task_id] = (time.monotonic(), metric_labels)
    except Exception:
        # Metrics must never break task execution — log and move on.
        logger.debug("Failed to record celery task prerun metrics", exc_info=True)
def on_celery_task_postrun(
    task_id: str | None,
    task: Task | None,
    state: str | None,
) -> None:
    """Record task completion. Call from the worker's task_postrun signal handler.

    Args:
        task_id: Celery task id, used to look up the start time recorded by
            on_celery_task_prerun.
        task: The Task instance from the signal.
        state: Final Celery state string; anything other than "SUCCESS" is
            counted as a failure.
    """
    if task is None or task_id is None:
        return
    try:
        labels = _get_task_labels(task)
        outcome = "success" if state == "SUCCESS" else "failure"
        TASK_COMPLETED.labels(**labels, outcome=outcome).inc()

        with _task_start_times_lock:
            entry = _task_start_times.pop(task_id, None)

        # Prefer the labels captured at prerun for the gauge decrement and
        # the duration observation: delivery_info can resolve differently at
        # postrun time, and mismatched labels would leave TASKS_ACTIVE
        # permanently inflated for the prerun series.
        gauge_labels = entry[1] if entry is not None else labels

        # Guard against going below 0 if postrun fires without a matching
        # prerun (e.g. after a worker restart or stale entry eviction).
        active_gauge = TASKS_ACTIVE.labels(**gauge_labels)
        if active_gauge._value.get() > 0:
            active_gauge.dec()

        if entry is not None:
            start_time, start_labels = entry
            TASK_DURATION.labels(**start_labels).observe(
                time.monotonic() - start_time
            )
    except Exception:
        logger.debug("Failed to record celery task postrun metrics", exc_info=True)
def on_celery_task_retry(
    _task_id: str | None,
    task: Task | None,
) -> None:
    """Record a task retry. Call from the worker's task_retry signal handler."""
    if task is None:
        return
    try:
        TASK_RETRIED.labels(**_get_task_labels(task)).inc()
    except Exception:
        # Metrics must never break task handling — log and move on.
        logger.debug("Failed to record celery task retry metrics", exc_info=True)
def on_celery_task_revoked(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Record task revocation. The revoked signal doesn't provide a Task
    instance, only the task name via sender."""
    # No queue label is available here, so TASK_REVOKED is labeled by
    # task_name alone (see the metric definition above).
    if task_name is None:
        return
    try:
        TASK_REVOKED.labels(task_name=task_name).inc()
    except Exception:
        # Metrics must never break task handling — log and move on.
        logger.debug("Failed to record celery task revoked metrics", exc_info=True)
def on_celery_task_rejected(
    _task_id: str | None,
    task_name: str | None = None,
) -> None:
    """Record task rejection.

    Like the revoked signal, no Task instance is available — only the
    task name — so the counter carries a task_name label only.
    """
    if task_name is None:
        return
    try:
        TASK_REJECTED.labels(task_name=task_name).inc()
    except Exception:
        # Metrics must never break task handling — log and move on.
        logger.debug("Failed to record celery task rejected metrics", exc_info=True)

View File

@@ -0,0 +1,528 @@
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
These collectors query Redis and Postgres at scrape time (the Collector pattern),
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
monitoring celery worker which already has Redis and DB access.
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
stale, which is fine for monitoring dashboards.
"""
import json
import threading
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timezone
from typing import Any
from prometheus_client.core import GaugeMetricFamily
from prometheus_client.registry import Collector
from redis import Redis
from onyx.background.celery.celery_redis import celery_get_queue_length
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
from onyx.configs.constants import OnyxCeleryQueues
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default cache TTL in seconds. Scrapes hitting within this window return
# the previous result without re-querying Redis/Postgres.
_DEFAULT_CACHE_TTL = 30.0

# Maps internal Celery queue names to the short, stable `queue` label value
# exported on the gauges below.
_QUEUE_LABEL_MAP: dict[str, str] = {
    OnyxCeleryQueues.PRIMARY: "primary",
    OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
    OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
    OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
    OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
    OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
    OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
    OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
    OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
    OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
    OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
    OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
    OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
    OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
    OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
    OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
    OnyxCeleryQueues.MONITORING: "monitoring",
    OnyxCeleryQueues.SANDBOX: "sandbox",
    OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
}

# Queues where prefetched (unacked) task counts are meaningful
_UNACKED_QUEUES: list[str] = [
    OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
    OnyxCeleryQueues.DOCPROCESSING,
]
class _CachedCollector(Collector):
    """Base collector with TTL-based caching.

    Subclasses implement ``_collect_fresh()`` to query the actual data
    source. ``collect()`` serves the cached result while the TTL has not
    expired, so frequent Prometheus scrapes don't re-query Redis/Postgres.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        self._cache_ttl = cache_ttl
        self._cached_result: list[GaugeMetricFamily] | None = None
        self._last_collect_time: float = 0.0
        self._lock = threading.Lock()

    def collect(self) -> list[GaugeMetricFamily]:
        with self._lock:
            now = time.monotonic()
            cached = self._cached_result
            if cached is not None and now - self._last_collect_time < self._cache_ttl:
                return cached
            try:
                fresh = self._collect_fresh()
            except Exception:
                logger.exception(f"Error in {type(self).__name__}.collect()")
                # Serve the stale cache on error rather than nothing —
                # avoids metrics vanishing during transient failures.
                return cached if cached is not None else []
            self._cached_result = fresh
            self._last_collect_time = now
            return fresh

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        raise NotImplementedError

    def describe(self) -> list[GaugeMetricFamily]:
        # An empty describe() keeps the registry from invoking collect()
        # at registration time.
        return []
class QueueDepthCollector(_CachedCollector):
    """Reads Celery queue lengths from the broker Redis on each scrape.

    Uses a Redis client factory (callable) rather than a stored client
    reference so the connection is always fresh from Celery's pool.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Injected later via set_redis_factory(); until then _collect_fresh()
        # yields no metrics.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Build depth / unacked / oldest-age gauges from broker Redis."""
        if self._get_redis is None:
            return []
        redis_client = self._get_redis()
        depth = GaugeMetricFamily(
            "onyx_queue_depth",
            "Number of tasks waiting in Celery queue",
            labels=["queue"],
        )
        unacked = GaugeMetricFamily(
            "onyx_queue_unacked",
            "Number of prefetched (unacked) tasks for queue",
            labels=["queue"],
        )
        queue_age = GaugeMetricFamily(
            "onyx_queue_oldest_task_age_seconds",
            "Age of the oldest task in the queue (seconds since enqueue)",
            labels=["queue"],
        )
        now = time.time()
        for queue_name, label in _QUEUE_LABEL_MAP.items():
            length = celery_get_queue_length(queue_name, redis_client)
            depth.add_metric([label], length)
            # Peek at the oldest message to get its age
            if length > 0:
                age = self._get_oldest_message_age(redis_client, queue_name, now)
                if age is not None:
                    queue_age.add_metric([label], age)
        # Unacked counts only for the queues where prefetch is meaningful.
        for queue_name in _UNACKED_QUEUES:
            label = _QUEUE_LABEL_MAP[queue_name]
            task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
            unacked.add_metric([label], len(task_ids))
        return [depth, unacked, queue_age]

    @staticmethod
    def _get_oldest_message_age(
        redis_client: Redis, queue_name: str, now: float
    ) -> float | None:
        """Peek at the oldest (tail) message in a Redis list queue
        and extract its timestamp to compute age.

        Note: If the Celery message contains neither ``properties.timestamp``
        nor ``headers.timestamp``, no age metric is emitted for this queue.
        This can happen with custom task producers or non-standard Celery
        protocol versions. The metric will simply be absent rather than
        inaccurate, which is the safest behavior for alerting.
        """
        try:
            # Index -1 is the queue tail, which holds the oldest message.
            raw: bytes | str | None = redis_client.lindex(queue_name, -1)  # type: ignore[assignment]
            if raw is None:
                return None
            msg = json.loads(raw)
            # Check for ETA tasks first — they are intentionally delayed,
            # so reporting their queue age would be misleading.
            headers = msg.get("headers", {})
            if headers.get("eta") is not None:
                return None
            # Celery v2 protocol: timestamp in properties
            props = msg.get("properties", {})
            ts = props.get("timestamp")
            if ts is not None:
                return now - float(ts)
            # Fallback: some Celery configurations place the timestamp in
            # headers instead of properties.
            ts = headers.get("timestamp")
            if ts is not None:
                return now - float(ts)
        except Exception:
            # Best-effort: malformed or non-JSON payloads simply yield no age.
            pass
        return None
class IndexAttemptCollector(_CachedCollector):
    """Queries Postgres for index attempt state on each scrape.

    Emits one gauge sample per (status, source, tenant, connector, cc_pair)
    combination for non-terminal index attempts across all tenants.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # configure() must run (after the DB engine is initialized) before
        # any metrics are emitted; until then _collect_fresh() yields nothing.
        self._configured: bool = False
        self._terminal_statuses: list = []

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        # Imported lazily so this module can be imported before the DB
        # layer is ready.
        from onyx.db.enums import IndexingStatus

        self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Query each tenant's DB for active (non-terminal) index attempts."""
        if not self._configured:
            return []
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_active_index_attempts_for_metrics
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        attempts_gauge = GaugeMetricFamily(
            "onyx_index_attempts_active",
            "Number of non-terminal index attempts",
            labels=[
                "status",
                "source",
                "tenant_id",
                "connector_name",
                "cc_pair_id",
            ],
        )
        tenant_ids = get_all_tenant_ids()
        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            # Scope all DB access below to this tenant; always reset in the
            # finally so one tenant's failure can't leak into the next.
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    rows = get_active_index_attempts_for_metrics(session)
                    for status, source, cc_id, cc_name, count in rows:
                        # Fall back to a synthetic name for unnamed cc_pairs.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        attempts_gauge.add_metric(
                            [
                                status.value,
                                source.value,
                                tid,
                                name_val,
                                str(cc_id),
                            ],
                            count,
                        )
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
        return [attempts_gauge]
class ConnectorHealthCollector(_CachedCollector):
    """Queries Postgres for connector health state on each scrape.

    Emits staleness, error-state, per-status counts, and per-connector
    doc/error totals for every tenant's connector-credential pairs.
    """

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # configure() must run (after the DB engine is initialized) before
        # any metrics are emitted; until then _collect_fresh() yields nothing.
        self._configured: bool = False

    def configure(self) -> None:
        """Call once DB engine is initialized."""
        self._configured = True

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Query each tenant's DB and build all connector-health gauges."""
        if not self._configured:
            return []
        # Imported lazily so this module can be imported before the DB
        # layer is ready.
        from onyx.db.connector_credential_pair import (
            get_connector_health_for_metrics,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant
        from onyx.db.engine.tenant_utils import get_all_tenant_ids
        from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
        from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

        staleness_gauge = GaugeMetricFamily(
            "onyx_connector_last_success_age_seconds",
            "Seconds since last successful index for this connector",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        error_state_gauge = GaugeMetricFamily(
            "onyx_connector_in_error_state",
            "Whether the connector is in a repeated error state (1=yes, 0=no)",
            labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
        )
        by_status_gauge = GaugeMetricFamily(
            "onyx_connectors_by_status",
            "Number of connectors grouped by status",
            labels=["tenant_id", "status"],
        )
        error_total_gauge = GaugeMetricFamily(
            "onyx_connectors_in_error_total",
            "Total number of connectors in repeated error state",
            labels=["tenant_id"],
        )
        per_connector_labels = [
            "tenant_id",
            "source",
            "cc_pair_id",
            "connector_name",
        ]
        docs_success_gauge = GaugeMetricFamily(
            "onyx_connector_docs_indexed",
            "Total new documents indexed (90-day rolling sum) per connector",
            labels=per_connector_labels,
        )
        docs_error_gauge = GaugeMetricFamily(
            "onyx_connector_error_count",
            "Total number of failed index attempts per connector",
            labels=per_connector_labels,
        )
        now = datetime.now(tz=timezone.utc)
        tenant_ids = get_all_tenant_ids()
        for tid in tenant_ids:
            # Defensive guard — get_all_tenant_ids() should never yield None,
            # but we guard here for API stability in case the contract changes.
            if tid is None:
                continue
            # Scope all DB access below to this tenant; always reset in the
            # finally so one tenant's failure can't leak into the next.
            token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
            try:
                with get_session_with_current_tenant() as session:
                    pairs = get_connector_health_for_metrics(session)
                    error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
                    docs_by_cc = get_docs_indexed_by_cc_pair(session)
                    # Per-tenant aggregates accumulated over all cc_pairs.
                    status_counts: dict[str, int] = {}
                    error_count = 0
                    for (
                        cc_id,
                        status,
                        in_error,
                        last_success,
                        cc_name,
                        source,
                    ) in pairs:
                        cc_id_str = str(cc_id)
                        source_val = source.value
                        # Fall back to a synthetic name for unnamed cc_pairs.
                        name_val = cc_name or f"cc_pair_{cc_id}"
                        label_vals = [tid, source_val, cc_id_str, name_val]
                        if last_success is not None:
                            # Both `now` and `last_success` are timezone-aware
                            # (the DB column uses DateTime(timezone=True)),
                            # so subtraction is safe.
                            age = (now - last_success).total_seconds()
                            staleness_gauge.add_metric(label_vals, age)
                        error_state_gauge.add_metric(
                            label_vals,
                            1.0 if in_error else 0.0,
                        )
                        if in_error:
                            error_count += 1
                        docs_success_gauge.add_metric(
                            label_vals,
                            docs_by_cc.get(cc_id, 0),
                        )
                        docs_error_gauge.add_metric(
                            label_vals,
                            error_counts_by_cc.get(cc_id, 0),
                        )
                        status_val = status.value
                        status_counts[status_val] = status_counts.get(status_val, 0) + 1
                    for status_val, count in status_counts.items():
                        by_status_gauge.add_metric([tid, status_val], count)
                    error_total_gauge.add_metric([tid], error_count)
            finally:
                CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
        return [
            staleness_gauge,
            error_state_gauge,
            by_status_gauge,
            error_total_gauge,
            docs_success_gauge,
            docs_error_gauge,
        ]
class RedisHealthCollector(_CachedCollector):
    """Collects Redis server health metrics (memory, clients, etc.)."""

    def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
        super().__init__(cache_ttl)
        # Injected later via set_redis_factory(); until then _collect_fresh()
        # yields no metrics.
        self._get_redis: Callable[[], Redis] | None = None

    def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
        """Set a callable that returns a broker Redis client on demand."""
        self._get_redis = factory

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Run Redis INFO and expose memory/client gauges."""
        if self._get_redis is None:
            return []
        client = self._get_redis()

        used = GaugeMetricFamily(
            "onyx_redis_memory_used_bytes",
            "Redis used memory in bytes",
        )
        peak = GaugeMetricFamily(
            "onyx_redis_memory_peak_bytes",
            "Redis peak used memory in bytes",
        )
        frag_ratio = GaugeMetricFamily(
            "onyx_redis_memory_fragmentation_ratio",
            "Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
        )
        clients = GaugeMetricFamily(
            "onyx_redis_connected_clients",
            "Number of connected Redis clients",
        )
        try:
            mem: dict = client.info("memory")  # type: ignore[assignment]
            used.add_metric([], mem.get("used_memory", 0))
            peak.add_metric([], mem.get("used_memory_peak", 0))
            ratio = mem.get("mem_fragmentation_ratio")
            if ratio is not None:
                frag_ratio.add_metric([], ratio)
            conn_info: dict = client.info("clients")  # type: ignore[assignment]
            clients.add_metric([], conn_info.get("connected_clients", 0))
        except Exception:
            # Best-effort: on Redis errors, emit the families with no samples
            # rather than failing the whole scrape.
            logger.debug("Failed to collect Redis health metrics", exc_info=True)
        return [used, peak, frag_ratio, clients]
class WorkerHealthCollector(_CachedCollector):
    """Collects Celery worker count and process count via inspect ping.

    Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
    command that takes a couple seconds to complete.

    Maintains a set of known worker short-names so that when a worker
    stops responding, we emit ``up=0`` instead of silently dropping the
    metric (which would make ``absent()``-style alerts impossible).
    """

    # Remove a worker from _known_workers after this many consecutive
    # missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
    _MAX_CONSECUTIVE_MISSES = 10

    def __init__(self, cache_ttl: float = 60.0) -> None:
        super().__init__(cache_ttl)
        self._celery_app: Any | None = None
        # worker short-name → consecutive miss count; reset to 0 each time
        # the worker responds, evicted after _MAX_CONSECUTIVE_MISSES misses.
        self._known_workers: dict[str, int] = {}

    def set_celery_app(self, app: Any) -> None:
        """Set the Celery app instance for inspect commands."""
        self._celery_app = app

    def _collect_fresh(self) -> list[GaugeMetricFamily]:
        """Broadcast a ping and report live / dead worker gauges."""
        if self._celery_app is None:
            return []
        active_workers = GaugeMetricFamily(
            "onyx_celery_active_worker_count",
            "Number of active Celery workers responding to ping",
        )
        worker_up = GaugeMetricFamily(
            "onyx_celery_worker_up",
            "Whether a specific Celery worker is alive (1=up, 0=down)",
            labels=["worker"],
        )
        try:
            inspector = self._celery_app.control.inspect(timeout=3.0)
            ping_result = inspector.ping() or {}
            active_workers.add_metric([], len(ping_result))
            # Strip hostname suffixes ("worker@host" → "worker") for
            # cleaner label values.
            responding = {name.split("@")[0] for name in ping_result}
            # Workers that answered are (re)registered with a zero miss count.
            for short_name in responding:
                self._known_workers[short_name] = 0
            # Non-responders accrue misses; evict once they've been silent
            # for _MAX_CONSECUTIVE_MISSES consecutive collects.
            for short_name in list(self._known_workers):
                if short_name in responding:
                    continue
                self._known_workers[short_name] += 1
                if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
                    del self._known_workers[short_name]
            for short_name in sorted(self._known_workers):
                worker_up.add_metric([short_name], 1 if short_name in responding else 0)
        except Exception:
            logger.debug("Failed to collect worker health metrics", exc_info=True)
        return [active_workers, worker_up]

View File

@@ -0,0 +1,113 @@
"""Setup function for indexing pipeline Prometheus collectors.
Called once by the monitoring celery worker after Redis and DB are ready.
"""
import threading
from collections.abc import Callable
from typing import Any

from celery import Celery
from prometheus_client.registry import REGISTRY
from redis import Redis

from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Module-level singletons — these are lightweight objects (no connections or DB
# state) until configure() / set_redis_factory() is called. Keeping them at
# module level ensures they survive the lifetime of the worker process and are
# only registered with the Prometheus registry once.
_queue_collector = QueueDepthCollector()
_attempt_collector = IndexAttemptCollector()
_connector_collector = ConnectorHealthCollector()
_redis_health_collector = RedisHealthCollector()
_worker_health_collector = WorkerHealthCollector()
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
    """Create a factory that returns a cached broker Redis client.

    Reuses a single connection across scrapes to avoid leaking connections,
    and reconnects automatically if the cached connection becomes stale.

    The returned factory is shared by several collectors and the metrics
    HTTP endpoint serves scrapes on threads, so cache access is serialized
    with a lock — without it, concurrent scrapes could each open a broker
    connection and clobber the cache, leaking the losers' connections.

    Args:
        celery_app: Celery application whose broker connection supplies the
            underlying Redis client.

    Returns:
        A zero-argument callable producing a live broker Redis client.
    """
    _lock = threading.Lock()
    _cached_client: list[Redis | None] = [None]
    # Keep a reference to the Kombu Connection so we can close it on
    # reconnect (the raw Redis client outlives the Kombu wrapper).
    _cached_kombu_conn: list[Any] = [None]

    def _close_client(client: Redis) -> None:
        """Best-effort close of a Redis client."""
        try:
            client.close()
        except Exception:
            logger.debug("Failed to close stale Redis client", exc_info=True)

    def _close_kombu_conn() -> None:
        """Best-effort close of the cached Kombu Connection."""
        conn = _cached_kombu_conn[0]
        if conn is not None:
            try:
                conn.close()
            except Exception:
                logger.debug("Failed to close Kombu connection", exc_info=True)
        _cached_kombu_conn[0] = None

    def _get_broker_redis() -> Redis:
        with _lock:
            client = _cached_client[0]
            if client is not None:
                try:
                    # Cheap liveness probe; raises if the connection died.
                    client.ping()
                    return client
                except Exception:
                    logger.debug("Cached Redis client stale, reconnecting")
                    _close_client(client)
                    _cached_client[0] = None
                    _close_kombu_conn()
            # Get a fresh Redis client from the broker connection.
            # We hold this client long-term (cached above) rather than using
            # a context manager, because we need it to persist across
            # scrapes. The caching logic above ensures we only ever hold one
            # connection, and we close it explicitly on reconnect.
            conn = celery_app.broker_connection()
            # kombu's Channel exposes .client at runtime (the underlying
            # Redis client) but the type stubs don't declare it.
            new_client: Redis = conn.channel().client  # type: ignore[attr-defined]
            _cached_client[0] = new_client
            _cached_kombu_conn[0] = conn
            return new_client

    return _get_broker_redis
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
    """Register all indexing pipeline collectors with the default registry.

    Args:
        celery_app: The Celery application instance. Used to obtain a fresh
            broker Redis client on each scrape for queue depth metrics.
    """
    broker_redis = _make_broker_redis_factory(celery_app)

    # Wire up data sources before registering so the first scrape after
    # registration already produces metrics.
    _queue_collector.set_redis_factory(broker_redis)
    _redis_health_collector.set_redis_factory(broker_redis)
    _worker_health_collector.set_celery_app(celery_app)
    _attempt_collector.configure()
    _connector_collector.configure()

    all_collectors = [
        _queue_collector,
        _attempt_collector,
        _connector_collector,
        _redis_health_collector,
        _worker_health_collector,
    ]
    for collector in all_collectors:
        try:
            REGISTRY.register(collector)
        except ValueError:
            # register() raises ValueError on duplicate registration; this
            # is expected if setup runs twice in one process.
            logger.debug("Collector already registered: %s", type(collector).__name__)

View File

@@ -0,0 +1,253 @@
"""Per-connector Prometheus metrics for indexing tasks.
Enriches the two primary indexing tasks (docfetching_proxy_task and
docprocessing_task) with connector-level labels: source, tenant_id,
and cc_pair_id.
Note: connector_name is intentionally excluded from push-based per-task
counters because it is a user-defined free-form string that can create
unbounded cardinality. The pull-based collectors on the monitoring worker
(see indexing_pipeline.py) include connector_name since they have bounded
cardinality (one series per connector, not per task execution).
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
Connectors never change source type, and names change rarely, so the
cache is safe to hold for the worker's lifetime.
Usage in a worker app module:
from onyx.server.metrics.indexing_task_metrics import (
on_indexing_task_prerun,
on_indexing_task_postrun,
)
"""
import threading
import time
from dataclasses import dataclass
from celery import Task
from prometheus_client import Counter
from prometheus_client import Histogram
from onyx.configs.constants import OnyxCeleryTask
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
@dataclass(frozen=True)
class ConnectorInfo:
    """Cached connector metadata for metric labels."""

    # Connector source type value, used as the `source` metric label.
    source: str
    # Connector-credential pair name; cached but not used as a push-metric
    # label (see the module docstring on cardinality).
    name: str


# Sentinel used when the cc_pair cannot be resolved (DB unavailable or
# row missing); never cached, so later lookups can retry.
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")

# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
# deployments where different tenants can share the same cc_pair_id value.
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
# Lock protecting _connector_cache — multiple thread-pool workers may
# resolve connectors concurrently.
_connector_cache_lock = threading.Lock()

# Only enrich these task types with per-connector labels
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
    {
        OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
        OnyxCeleryTask.DOCPROCESSING_TASK,
    }
)

# connector_name is intentionally excluded — see module docstring.
INDEXING_TASK_STARTED = Counter(
    "onyx_indexing_task_started_total",
    "Indexing tasks started per connector",
    ["task_name", "source", "tenant_id", "cc_pair_id"],
)
INDEXING_TASK_COMPLETED = Counter(
    "onyx_indexing_task_completed_total",
    "Indexing tasks completed per connector",
    [
        "task_name",
        "source",
        "tenant_id",
        "cc_pair_id",
        "outcome",
    ],
)
INDEXING_TASK_DURATION = Histogram(
    "onyx_indexing_task_duration_seconds",
    "Indexing task duration by connector type",
    ["task_name", "source", "tenant_id"],
    # Buckets span 1s – 1h.
    buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
)

# task_id → monotonic start time (for indexing tasks only)
_indexing_start_times: dict[str, float] = {}
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
# run concurrently on thread-pool workers.
_indexing_start_times_lock = threading.Lock()
def _evict_stale_start_times() -> None:
    """Drop _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.

    Caller must already hold _indexing_start_times_lock.
    """
    cutoff = time.monotonic() - _MAX_START_TIME_AGE_SECONDS
    expired = [
        task_id
        for task_id, started_at in _indexing_start_times.items()
        if started_at < cutoff
    ]
    # These tasks likely died before postrun fired; dropping the entries
    # keeps the map bounded.
    for task_id in expired:
        _indexing_start_times.pop(task_id, None)
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
    """Resolve cc_pair_id to ConnectorInfo, using cache when possible.

    On a cache miss, performs a single DB query with an eager connector
    load. Failures — and cc_pairs that don't exist yet — return
    _UNKNOWN_CONNECTOR without caching, so subsequent calls can retry once
    the DB is available / the row exists.

    The cache key includes the current tenant id read from
    CURRENT_TENANT_ID_CONTEXTVAR (set by the Celery tenant-aware middleware
    before task execution, matching kwargs["tenant_id"] at dispatch time),
    so tenants sharing a cc_pair_id value cannot poison each other's entries.
    """
    tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
    cache_key = (tenant_id, cc_pair_id)

    with _connector_cache_lock:
        hit = _connector_cache.get(cache_key)
    if hit is not None:
        return hit

    try:
        from onyx.db.connector_credential_pair import (
            get_connector_credential_pair_from_id,
        )
        from onyx.db.engine.sql_engine import get_session_with_current_tenant

        with get_session_with_current_tenant() as db_session:
            cc_pair = get_connector_credential_pair_from_id(
                db_session,
                cc_pair_id,
                eager_load_connector=True,
            )
            if cc_pair is None:
                # Lookup succeeded but the row doesn't exist — don't cache,
                # it may appear later (race with connector creation).
                return _UNKNOWN_CONNECTOR
            # Read attributes while the session is still open.
            resolved = ConnectorInfo(
                source=cc_pair.connector.source.value,
                name=cc_pair.name,
            )
    except Exception:
        logger.debug(
            f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
            exc_info=True,
        )
        return _UNKNOWN_CONNECTOR

    with _connector_cache_lock:
        _connector_cache[cache_key] = resolved
    return resolved
def on_indexing_task_prerun(
    task_id: str | None,
    task: Task | None,
    kwargs: dict | None,
) -> None:
    """Record per-connector metrics at task start.

    Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
    all other tasks.
    """
    if task is None or task_id is None or kwargs is None:
        return
    task_name = task.name or ""
    if task_name not in _INDEXING_TASK_NAMES:
        return
    try:
        cc_pair_id = kwargs.get("cc_pair_id")
        if cc_pair_id is None:
            return
        tenant_id = str(kwargs.get("tenant_id", "unknown"))
        connector = _resolve_connector(cc_pair_id)
        INDEXING_TASK_STARTED.labels(
            task_name=task_name,
            source=connector.source,
            tenant_id=tenant_id,
            cc_pair_id=str(cc_pair_id),
        ).inc()
        with _indexing_start_times_lock:
            # Opportunistically drop entries for tasks that died without a
            # postrun, then record this start for duration measurement.
            _evict_stale_start_times()
            _indexing_start_times[task_id] = time.monotonic()
    except Exception:
        logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
def on_indexing_task_postrun(
    task_id: str | None,
    task: Task | None,
    kwargs: dict | None,
    state: str | None,
) -> None:
    """Record per-connector completion metrics.

    Only fires for tasks in _INDEXING_TASK_NAMES.
    """
    if task is None or task_id is None or kwargs is None:
        return
    task_name = task.name or ""
    if task_name not in _INDEXING_TASK_NAMES:
        return
    try:
        cc_pair_id = kwargs.get("cc_pair_id")
        if cc_pair_id is None:
            return
        tenant_id = str(kwargs.get("tenant_id", "unknown"))
        connector = _resolve_connector(cc_pair_id)
        INDEXING_TASK_COMPLETED.labels(
            task_name=task_name,
            source=connector.source,
            tenant_id=tenant_id,
            cc_pair_id=str(cc_pair_id),
            outcome="success" if state == "SUCCESS" else "failure",
        ).inc()
        with _indexing_start_times_lock:
            started_at = _indexing_start_times.pop(task_id, None)
        if started_at is not None:
            INDEXING_TASK_DURATION.labels(
                task_name=task_name,
                source=connector.source,
                tenant_id=tenant_id,
            ).observe(time.monotonic() - started_at)
    except Exception:
        logger.debug("Failed to record indexing task postrun metrics", exc_info=True)

View File

@@ -0,0 +1,89 @@
"""Standalone Prometheus metrics HTTP server for non-API processes.
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
Celery workers and other background processes use this module to expose their
own /metrics endpoint on a configurable port.
Usage:
from onyx.server.metrics.metrics_server import start_metrics_server
start_metrics_server("monitoring") # reads port from env or uses default
"""
import os
import threading
from prometheus_client import start_http_server
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Default ports for worker types that serve custom Prometheus metrics.
# Only add entries here when a worker actually registers collectors.
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
# env var can override.
_DEFAULT_PORTS: dict[str, int] = {
    "monitoring": 9096,
    "docfetching": 9092,
    "docprocessing": 9093,
}

# Process-wide guard: start_metrics_server() starts at most one server.
_server_started = False
# Serializes startup checks against concurrent callers.
_server_lock = threading.Lock()
def start_metrics_server(worker_type: str) -> int | None:
"""Start a Prometheus metrics HTTP server in a background thread.
Returns the port if started, None if disabled or already started.
Port resolution order:
1. PROMETHEUS_METRICS_PORT env var (explicit override)
2. Default port for the worker type
3. If worker type is unknown and no env var, skip
Set PROMETHEUS_METRICS_ENABLED=false to disable.
"""
global _server_started
with _server_lock:
if _server_started:
logger.debug(f"Metrics server already started for {worker_type}")
return None
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
if enabled in ("false", "0", "no"):
logger.info(f"Prometheus metrics server disabled for {worker_type}")
return None
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
if port_str:
try:
port = int(port_str)
except ValueError:
logger.warning(
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
"must be a numeric port. Skipping metrics server."
)
return None
elif worker_type in _DEFAULT_PORTS:
port = _DEFAULT_PORTS[worker_type]
else:
logger.info(
f"No default metrics port for worker type '{worker_type}' "
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
)
return None
try:
start_http_server(port)
_server_started = True
logger.info(
f"Prometheus metrics server started on :{port} for {worker_type}"
)
return port
except OSError as e:
logger.warning(
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
)
return None

View File

@@ -736,7 +736,7 @@ if __name__ == "__main__":
llm.config.model_name, llm.config.model_provider
)
persona = get_default_behavior_persona(db_session)
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
if persona is None:
raise ValueError("No default persona found")

View File

@@ -9,6 +9,7 @@ from onyx.chat.emitter import Emitter
from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.context.search.models import BaseFilters
from onyx.context.search.models import PersonaSearchInfo
from onyx.db.enums import MCPAuthenticationPerformer
from onyx.db.enums import MCPAuthenticationType
from onyx.db.mcp import get_all_mcp_tools_for_server
@@ -124,7 +125,12 @@ def construct_tools(
) -> dict[int, list[Tool]]:
"""Constructs tools based on persona configuration and available APIs.
Will simply skip tools that are not allowed/available."""
Will simply skip tools that are not allowed/available.
Callers must supply a persona with ``tools``, ``document_sets``,
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
to avoid lazy SQL queries after the session may have been flushed."""
tool_dict: dict[int, list[Tool]] = {}
# Log which tools are attached to the persona for debugging
@@ -143,6 +149,28 @@ def construct_tools(
# This flow is for search so we do not get all indices.
document_index = get_default_document_index(search_settings, None, db_session)
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
persona_search_info = PersonaSearchInfo(
document_set_names=[ds.name for ds in persona.document_sets],
search_start_date=persona.search_start_date,
attached_document_ids=[doc.id for doc in persona.attached_documents],
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
)
return SearchTool(
tool_id=tool_id,
emitter=emitter,
user=user,
persona_search_info=persona_search_info,
llm=llm,
document_index=document_index,
user_selected_filters=config.user_selected_filters,
project_id_filter=config.project_id_filter,
persona_id_filter=config.persona_id_filter,
bypass_acl=config.bypass_acl,
slack_context=config.slack_context,
enable_slack_search=config.enable_slack_search,
)
added_search_tool = False
for db_tool_model in persona.tools:
# If allowed_tool_ids is specified, skip tools not in the allowed list
@@ -176,22 +204,9 @@ def construct_tools(
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
tool_id=db_tool_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[db_tool_model.id] = [search_tool]
tool_dict[db_tool_model.id] = [
_build_search_tool(db_tool_model.id, search_tool_config)
]
# Handle Image Generation Tool
elif tool_cls.__name__ == ImageGenerationTool.__name__:
@@ -421,26 +436,12 @@ def construct_tools(
# Get the database tool model for SearchTool
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
# Use the passed-in config if available, otherwise create a new one
if not search_tool_config:
search_tool_config = SearchToolConfig()
search_tool = SearchTool(
tool_id=search_tool_db_model.id,
emitter=emitter,
user=user,
persona=persona,
llm=llm,
document_index=document_index,
user_selected_filters=search_tool_config.user_selected_filters,
project_id_filter=search_tool_config.project_id_filter,
persona_id_filter=search_tool_config.persona_id_filter,
bypass_acl=search_tool_config.bypass_acl,
slack_context=search_tool_config.slack_context,
enable_slack_search=search_tool_config.enable_slack_search,
)
tool_dict[search_tool_db_model.id] = [search_tool]
tool_dict[search_tool_db_model.id] = [
_build_search_tool(search_tool_db_model.id, search_tool_config)
]
# Always inject MemoryTool when the user has the memory tool enabled,
# bypassing persona tool associations and allowed_tool_ids filtering

View File

@@ -51,6 +51,7 @@ from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import InferenceSection
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDocsResponse
from onyx.context.search.pipeline import merge_individual_chunks
from onyx.context.search.pipeline import search_pipeline
@@ -65,7 +66,6 @@ from onyx.db.federated import (
get_federated_connector_document_set_mappings_by_document_set_names,
)
from onyx.db.federated import list_federated_connector_oauth_tokens
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.db.search_settings import get_current_search_settings
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
emitter: Emitter,
# Used for ACLs and federated search, anonymous users only see public docs
user: User,
# Used for filter settings
persona: Persona,
# Pre-extracted persona search configuration
persona_search_info: PersonaSearchInfo,
llm: LLM,
document_index: DocumentIndex,
# Respecting user selections
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
super().__init__(emitter=emitter)
self.user = user
self.persona = persona
self.persona_search_info = persona_search_info
self.llm = llm
self.document_index = document_index
self.user_selected_filters = user_selected_filters
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# Case 1: Slack bot context — requires a Slack federated connector
# linked via the persona's document sets
if self.slack_context:
document_set_names = [ds.name for ds in self.persona.document_sets]
document_set_names = self.persona_search_info.document_set_names
if not document_set_names:
logger.debug(
"Skipping Slack federated search: no document sets on persona"
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
persona_id_filter=self.persona_id_filter,
document_index=self.document_index,
user=self.user,
persona=self.persona,
persona_search_info=self.persona_search_info,
acl_filters=acl_filters,
embedding_model=embedding_model,
prefetched_federated_retrieval_infos=federated_retrieval_infos,
@@ -587,15 +587,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
and self.user_selected_filters.source_type
else None
)
persona_document_sets = (
[ds.name for ds in self.persona.document_sets] if self.persona else None
)
federated_retrieval_infos = (
get_federated_retrieval_functions(
db_session=db_session,
user_id=self.user.id if self.user else None,
source_types=prefetch_source_types,
document_set_names=persona_document_sets,
document_set_names=self.persona_search_info.document_set_names,
)
or []
)

View File

@@ -43,5 +43,8 @@ def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
persona_unloaded = tmp.unloaded
assert "tools" not in persona_unloaded
assert "user_files" not in persona_unloaded
assert "document_sets" not in persona_unloaded
assert "attached_documents" not in persona_unloaded
assert "hierarchy_nodes" not in persona_unloaded
finally:
db_session.rollback()

View File

@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
from onyx.configs.constants import DocumentSource
from onyx.context.search.models import ChunkSearchRequest
from onyx.context.search.models import InferenceChunk
from onyx.context.search.models import PersonaSearchInfo
from onyx.context.search.models import SearchDoc
from onyx.db.models import Persona
from onyx.db.models import SearchSettings
from onyx.db.models import User
from onyx.document_index.interfaces import DocumentIndex
@@ -139,7 +139,7 @@ def use_mock_search_pipeline(
chunk_search_request: ChunkSearchRequest,
document_index: DocumentIndex, # noqa: ARG001
user: User | None, # noqa: ARG001
persona: Persona | None, # noqa: ARG001
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
db_session: Session | None = None, # noqa: ARG001
auto_detect_filters: bool = False, # noqa: ARG001
llm: LLM | None = None, # noqa: ARG001

View File

@@ -60,4 +60,4 @@ def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
with pytest.raises(HTTPError):
handled_call()
assert mock_confluence_call.call_count == 1
assert mock_confluence_call.call_count == 5

View File

@@ -0,0 +1,153 @@
"""Tests for generic Celery task lifecycle Prometheus metrics."""
from collections.abc import Iterator
from unittest.mock import MagicMock
import pytest
from onyx.server.metrics.celery_task_metrics import _task_start_times
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
@pytest.fixture(autouse=True)
def reset_metrics() -> Iterator[None]:
    """Clear metric state between tests."""
    # Scrub any start-time bookkeeping a previous test left behind...
    _task_start_times.clear()
    yield
    # ...and leave a clean slate for whatever test runs next.
    _task_start_times.clear()
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
task = MagicMock()
task.name = name
task.request = MagicMock()
task.request.delivery_info = {"routing_key": queue}
return task
class TestCeleryTaskPrerun:
    """Tests for the prerun signal handler (counter increments and bookkeeping)."""

    def test_increments_started_and_active(self) -> None:
        # Snapshot both counters before firing the handler so the assertion
        # is delta-based and independent of other tests' metric state.
        task = _make_task()
        before_started = TASK_STARTED.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        before_active = TASKS_ACTIVE.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        on_celery_task_prerun("task-1", task)
        after_started = TASK_STARTED.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        after_active = TASKS_ACTIVE.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        assert after_started == before_started + 1
        assert after_active == before_active + 1

    def test_records_start_time(self) -> None:
        # Prerun must stash a start timestamp keyed by task id for later duration math.
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        assert "task-1" in _task_start_times

    def test_noop_when_task_is_none(self) -> None:
        # A missing task object means there is nothing to label; no state recorded.
        on_celery_task_prerun("task-1", None)
        assert "task-1" not in _task_start_times

    def test_noop_when_task_id_is_none(self) -> None:
        task = _make_task()
        on_celery_task_prerun(None, task)
        # Should not crash

    def test_handles_missing_delivery_info(self) -> None:
        # delivery_info can be absent (e.g. eager mode); handler should still record.
        task = _make_task()
        task.request.delivery_info = None
        on_celery_task_prerun("task-1", task)
        assert "task-1" in _task_start_times
class TestCeleryTaskPostrun:
    """Tests for the postrun signal handler (completion counters, gauge, duration)."""

    def test_increments_completed_success(self) -> None:
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before = TASK_COMPLETED.labels(
            task_name="test_task", queue="test_queue", outcome="success"
        )._value.get()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        after = TASK_COMPLETED.labels(
            task_name="test_task", queue="test_queue", outcome="success"
        )._value.get()
        assert after == before + 1

    def test_increments_completed_failure(self) -> None:
        # Any non-SUCCESS state should land in the "failure" outcome bucket.
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before = TASK_COMPLETED.labels(
            task_name="test_task", queue="test_queue", outcome="failure"
        )._value.get()
        on_celery_task_postrun("task-1", task, "FAILURE")
        after = TASK_COMPLETED.labels(
            task_name="test_task", queue="test_queue", outcome="failure"
        )._value.get()
        assert after == before + 1

    def test_decrements_active(self) -> None:
        # Prerun bumps the in-flight gauge; postrun must bring it back down.
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        active_before = TASKS_ACTIVE.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        active_after = TASKS_ACTIVE.labels(
            task_name="test_task", queue="test_queue"
        )._value.get()
        assert active_after == active_before - 1

    def test_observes_duration(self) -> None:
        # The histogram sum grows by the elapsed wall time between pre- and postrun.
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        before_count = TASK_DURATION.labels(
            task_name="test_task", queue="test_queue"
        )._sum.get()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        after_count = TASK_DURATION.labels(
            task_name="test_task", queue="test_queue"
        )._sum.get()
        # Duration should have increased (at least slightly)
        assert after_count > before_count

    def test_cleans_up_start_time(self) -> None:
        # Postrun must remove the start-time entry to avoid unbounded growth.
        task = _make_task()
        on_celery_task_prerun("task-1", task)
        assert "task-1" in _task_start_times
        on_celery_task_postrun("task-1", task, "SUCCESS")
        assert "task-1" not in _task_start_times

    def test_noop_when_task_is_none(self) -> None:
        on_celery_task_postrun("task-1", None, "SUCCESS")

    def test_handles_missing_start_time(self) -> None:
        """Postrun without prerun should not crash."""
        task = _make_task()
        on_celery_task_postrun("task-1", task, "SUCCESS")
        # Should not raise

View File

@@ -0,0 +1,359 @@
"""Tests for indexing pipeline Prometheus collectors."""
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
class TestQueueDepthCollector:
    """Tests for the Redis-backed queue-depth Prometheus collector."""

    def test_returns_empty_when_factory_not_set(self) -> None:
        # Without a Redis factory the collector has no data source and emits nothing.
        collector = QueueDepthCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        collector = QueueDepthCollector()
        assert collector.describe() == []

    def test_collects_queue_depths(self) -> None:
        # cache_ttl=0 forces a fresh scrape on every collect() call.
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=5,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value={"task-1", "task-2"},
            ),
        ):
            families = collector.collect()
        # Three metric families: depth, unacked count, oldest-task age.
        assert len(families) == 3
        depth_family = families[0]
        unacked_family = families[1]
        age_family = families[2]
        assert depth_family.name == "onyx_queue_depth"
        assert len(depth_family.samples) > 0
        for sample in depth_family.samples:
            assert sample.value == 5
        assert unacked_family.name == "onyx_queue_unacked"
        unacked_labels = {s.labels["queue"] for s in unacked_family.samples}
        assert "docfetching" in unacked_labels
        assert "docprocessing" in unacked_labels
        assert age_family.name == "onyx_queue_oldest_task_age_seconds"
        # Two unacked task ids mocked above -> every unacked sample reads 2.
        for sample in unacked_family.samples:
            assert sample.value == 2

    def test_handles_redis_error_gracefully(self) -> None:
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("connection lost"),
        ):
            families = collector.collect()
        # Returns stale cache (empty on first call)
        assert families == []

    def test_caching_returns_stale_within_ttl(self) -> None:
        # With a long TTL, the second collect must be served from cache.
        collector = QueueDepthCollector(cache_ttl=60)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=5,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            first = collector.collect()
        # Second call within TTL should return cached result without calling Redis
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("should not be called"),
        ):
            second = collector.collect()
        assert first is second  # Same object, from cache

    def test_factory_called_each_scrape(self) -> None:
        """Verify the Redis factory is called on each fresh collect, not cached."""
        collector = QueueDepthCollector(cache_ttl=0)
        factory = MagicMock(return_value=MagicMock())
        collector.set_redis_factory(factory)
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=0,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            collector.collect()
            collector.collect()
        assert factory.call_count == 2

    def test_error_returns_stale_cache(self) -> None:
        # A successful scrape is cached; a later failure serves that stale data.
        collector = QueueDepthCollector(cache_ttl=0)
        mock_redis = MagicMock()
        collector.set_redis_factory(lambda: mock_redis)
        # First call succeeds
        with (
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
                return_value=10,
            ),
            patch(
                "onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
                return_value=set(),
            ),
        ):
            good_result = collector.collect()
        assert len(good_result) == 3
        assert good_result[0].samples[0].value == 10
        # Second call fails — should return stale cache, not empty
        with patch(
            "onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
            side_effect=Exception("Redis down"),
        ):
            stale_result = collector.collect()
        assert stale_result is good_result
class TestIndexAttemptCollector:
    """Tests for the DB-backed active-index-attempt Prometheus collector."""

    def test_returns_empty_when_not_configured(self) -> None:
        collector = IndexAttemptCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        collector = IndexAttemptCollector()
        assert collector.describe() == []

    # NOTE: with stacked @patch decorators the bottom-most patch maps to the
    # first mock parameter, hence session first, then tenants.
    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_collects_index_attempts(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        # get_session_with_current_tenant is used as a context manager.
        mock_session = MagicMock()
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        from onyx.db.enums import IndexingStatus

        # Row shape mirrors the collector's query:
        # (status, source, cc_pair_id, connector_name, count).
        mock_row = (
            IndexingStatus.IN_PROGRESS,
            MagicMock(value="web"),
            81,
            "Table Tennis Blade Guide",
            2,
        )
        mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
            mock_row
        ]
        families = collector.collect()
        assert len(families) == 1
        assert families[0].name == "onyx_index_attempts_active"
        assert len(families[0].samples) == 1
        sample = families[0].samples[0]
        assert sample.labels == {
            "status": "in_progress",
            "source": "web",
            "tenant_id": "public",
            "connector_name": "Table Tennis Blade Guide",
            "cc_pair_id": "81",
        }
        assert sample.value == 2

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_handles_db_error_gracefully(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.side_effect = Exception("DB down")
        families = collector.collect()
        # No stale cache, so returns empty
        assert families == []

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_skips_none_tenant_ids(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = IndexAttemptCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = [None]
        families = collector.collect()
        assert len(families) == 1  # Returns the gauge family, just with no samples
        assert len(families[0].samples) == 0
class TestConnectorHealthCollector:
    """Tests for the connector health/staleness Prometheus collector."""

    def test_returns_empty_when_not_configured(self) -> None:
        collector = ConnectorHealthCollector()
        assert collector.collect() == []

    def test_returns_empty_describe(self) -> None:
        collector = ConnectorHealthCollector()
        assert collector.describe() == []

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_collects_connector_health(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        mock_session = MagicMock()
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        # Connector last succeeded two hours ago -> staleness sample ~7200s.
        now = datetime.now(tz=timezone.utc)
        last_success = now - timedelta(hours=2)
        mock_status = MagicMock(value="ACTIVE")
        mock_source = MagicMock(value="google_drive")
        # Row: (id, status, in_error, last_success, name, source)
        mock_row = (
            42,
            mock_status,
            True,  # in_repeated_error_state
            last_success,
            "My GDrive Connector",
            mock_source,
        )
        mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
        # Mock the index attempt queries (error counts + docs counts)
        mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
            []
        )
        families = collector.collect()
        assert len(families) == 6
        names = {f.name for f in families}
        assert names == {
            "onyx_connector_last_success_age_seconds",
            "onyx_connector_in_error_state",
            "onyx_connectors_by_status",
            "onyx_connectors_in_error_total",
            "onyx_connector_docs_indexed",
            "onyx_connector_error_count",
        }
        staleness = next(
            f for f in families if f.name == "onyx_connector_last_success_age_seconds"
        )
        assert len(staleness.samples) == 1
        # Wide tolerance (5s) absorbs test wall-clock jitter.
        assert staleness.samples[0].value == pytest.approx(7200, abs=5)
        error_state = next(
            f for f in families if f.name == "onyx_connector_in_error_state"
        )
        assert error_state.samples[0].value == 1.0
        by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
        assert by_status.samples[0].labels == {
            "tenant_id": "public",
            "status": "ACTIVE",
        }
        assert by_status.samples[0].value == 1
        error_total = next(
            f for f in families if f.name == "onyx_connectors_in_error_total"
        )
        assert error_total.samples[0].value == 1

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    @patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
    def test_skips_staleness_when_no_last_success(
        self,
        mock_get_session: MagicMock,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.return_value = ["public"]
        mock_session = MagicMock()
        mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
        mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
        mock_status = MagicMock(value="INITIAL_INDEXING")
        mock_source = MagicMock(value="slack")
        mock_row = (
            10,
            mock_status,
            False,
            None,  # no last_successful_index_time
            0,
            mock_source,
        )
        mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
        families = collector.collect()
        staleness = next(
            f for f in families if f.name == "onyx_connector_last_success_age_seconds"
        )
        assert len(staleness.samples) == 0

    @patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
    def test_handles_db_error_gracefully(
        self,
        mock_get_tenants: MagicMock,
    ) -> None:
        collector = ConnectorHealthCollector(cache_ttl=0)
        collector.configure()
        mock_get_tenants.side_effect = Exception("DB down")
        families = collector.collect()
        assert families == []

View File

@@ -0,0 +1,96 @@
"""Tests for indexing pipeline setup (Redis factory caching)."""
from unittest.mock import MagicMock
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
def _make_mock_app(client: MagicMock) -> MagicMock:
"""Create a mock Celery app whose broker_connection().channel().client
returns the given client."""
mock_app = MagicMock()
mock_conn = MagicMock()
mock_conn.channel.return_value.client = client
mock_app.broker_connection.return_value = mock_conn
return mock_app
class TestMakeBrokerRedisFactory:
    """Tests for the cached broker Redis-client factory."""

    def test_caches_redis_client_across_calls(self) -> None:
        """Factory should reuse the same client on subsequent calls."""
        mock_client = MagicMock()
        # A healthy ping means the cached client is still usable.
        mock_client.ping.return_value = True
        mock_app = _make_mock_app(mock_client)
        factory = _make_broker_redis_factory(mock_app)
        client1 = factory()
        client2 = factory()
        assert client1 is client2
        # broker_connection should only be called once
        assert mock_app.broker_connection.call_count == 1

    def test_reconnects_when_ping_fails(self) -> None:
        """Factory should create a new client if ping fails (stale connection)."""
        mock_client_stale = MagicMock()
        mock_client_stale.ping.side_effect = ConnectionError("disconnected")
        mock_client_fresh = MagicMock()
        mock_client_fresh.ping.return_value = True
        mock_app = _make_mock_app(mock_client_stale)
        factory = _make_broker_redis_factory(mock_app)
        # First call — creates and caches
        client1 = factory()
        assert client1 is mock_client_stale
        assert mock_app.broker_connection.call_count == 1
        # Switch to fresh client for next connection
        mock_conn_fresh = MagicMock()
        mock_conn_fresh.channel.return_value.client = mock_client_fresh
        mock_app.broker_connection.return_value = mock_conn_fresh
        # Second call — ping fails on stale, reconnects
        client2 = factory()
        assert client2 is mock_client_fresh
        assert mock_app.broker_connection.call_count == 2

    def test_reconnect_closes_stale_client(self) -> None:
        """When ping fails, the old client should be closed before reconnecting."""
        mock_client_stale = MagicMock()
        mock_client_stale.ping.side_effect = ConnectionError("disconnected")
        mock_client_fresh = MagicMock()
        mock_client_fresh.ping.return_value = True
        mock_app = _make_mock_app(mock_client_stale)
        factory = _make_broker_redis_factory(mock_app)
        # First call — creates and caches
        factory()
        # Switch to fresh client
        mock_conn_fresh = MagicMock()
        mock_conn_fresh.channel.return_value.client = mock_client_fresh
        mock_app.broker_connection.return_value = mock_conn_fresh
        # Second call — ping fails, should close stale client
        factory()
        mock_client_stale.close.assert_called_once()

    def test_first_call_creates_connection(self) -> None:
        """First call should always create a new connection."""
        mock_client = MagicMock()
        mock_app = _make_mock_app(mock_client)
        factory = _make_broker_redis_factory(mock_app)
        client = factory()
        assert client is mock_client
        mock_app.broker_connection.assert_called_once()

View File

@@ -0,0 +1,335 @@
"""Tests for per-connector indexing task Prometheus metrics."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.indexing_task_metrics import _connector_cache
from onyx.server.metrics.indexing_task_metrics import _indexing_start_times
from onyx.server.metrics.indexing_task_metrics import ConnectorInfo
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_COMPLETED
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_DURATION
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_STARTED
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
@pytest.fixture(autouse=True)
def reset_state() -> Iterator[None]:
    """Clear caches and state between tests.

    Sets CURRENT_TENANT_ID_CONTEXTVAR to a realistic value so cache keys
    are never keyed on an empty string.
    """
    from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR

    # Pin the tenant contextvar for the duration of the test.
    contextvar_token = CURRENT_TENANT_ID_CONTEXTVAR.set("test_tenant")
    _indexing_start_times.clear()
    _connector_cache.clear()
    yield
    _indexing_start_times.clear()
    _connector_cache.clear()
    # Restore the contextvar to whatever it was before the test.
    CURRENT_TENANT_ID_CONTEXTVAR.reset(contextvar_token)
def _make_task(name: str) -> MagicMock:
task = MagicMock()
task.name = name
return task
def _mock_db_lookup(
source: str = "google_drive", name: str = "My Google Drive"
) -> tuple:
"""Return (session_patch, cc_pair_patch) context managers for DB mocking."""
mock_cc_pair = MagicMock()
mock_cc_pair.name = name
mock_cc_pair.connector.source.value = source
session_patch = patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
cc_pair_patch = patch(
"onyx.db.connector_credential_pair.get_connector_credential_pair_from_id",
return_value=mock_cc_pair,
)
return session_patch, cc_pair_patch
class TestIndexingTaskPrerun:
    """Tests for the indexing-specific prerun handler (filtering, caching, DB lookup)."""

    def test_skips_non_indexing_task(self) -> None:
        # Tasks outside _INDEXING_TASK_NAMES must be ignored entirely.
        task = _make_task("some_other_task")
        kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" not in _indexing_start_times

    def test_emits_started_for_docfetching(self) -> None:
        # Pre-populate cache to avoid DB lookup (tenant-scoped key)
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="google_drive", name="My Google Drive"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "tenant-1"}
        before = INDEXING_TASK_STARTED.labels(
            task_name="connector_doc_fetching_task",
            source="google_drive",
            tenant_id="tenant-1",
            cc_pair_id="42",
        )._value.get()
        on_indexing_task_prerun("task-1", task, kwargs)
        after = INDEXING_TASK_STARTED.labels(
            task_name="connector_doc_fetching_task",
            source="google_drive",
            tenant_id="tenant-1",
            cc_pair_id="42",
        )._value.get()
        assert after == before + 1
        assert "task-1" in _indexing_start_times

    def test_emits_started_for_docprocessing(self) -> None:
        _connector_cache[("test_tenant", 10)] = ConnectorInfo(
            source="slack", name="Slack Connector"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 10, "tenant_id": "public"}
        on_indexing_task_prerun("task-2", task, kwargs)
        assert "task-2" in _indexing_start_times

    def test_cache_hit_avoids_db_call(self) -> None:
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="confluence", name="Engineering Confluence"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # No DB patches needed — cache should be used
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" in _indexing_start_times

    def test_db_lookup_on_cache_miss(self) -> None:
        """On first encounter of a cc_pair_id, does a DB lookup and caches."""
        mock_cc_pair = MagicMock()
        mock_cc_pair.name = "Notion Workspace"
        mock_cc_pair.connector.source.value = "notion"
        mock_session = MagicMock()
        mock_session.__enter__ = MagicMock(return_value=MagicMock())
        mock_session.__exit__ = MagicMock(return_value=False)
        with (
            patch(
                "onyx.server.metrics.indexing_task_metrics._resolve_connector"
            ) as mock_resolve,
        ):
            mock_resolve.return_value = ConnectorInfo(
                source="notion", name="Notion Workspace"
            )
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 77, "tenant_id": "public"}
            on_indexing_task_prerun("task-1", task, kwargs)
            # Resolver is keyed on cc_pair_id alone.
            mock_resolve.assert_called_once_with(77)

    def test_missing_cc_pair_returns_unknown(self) -> None:
        """When _resolve_connector can't find the cc_pair, uses 'unknown'."""
        with patch(
            "onyx.server.metrics.indexing_task_metrics._resolve_connector"
        ) as mock_resolve:
            mock_resolve.return_value = ConnectorInfo(source="unknown", name="unknown")
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 999, "tenant_id": "public"}
            on_indexing_task_prerun("task-1", task, kwargs)
            assert "task-1" in _indexing_start_times

    def test_skips_when_cc_pair_id_missing(self) -> None:
        # Without a cc_pair_id there is nothing to attribute the metric to.
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        assert "task-1" not in _indexing_start_times

    def test_db_error_does_not_crash(self) -> None:
        # Metrics are best-effort: resolver failures must never break the task.
        with patch(
            "onyx.server.metrics.indexing_task_metrics._resolve_connector",
            side_effect=Exception("DB down"),
        ):
            task = _make_task("connector_doc_fetching_task")
            kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
            # Should not raise
            on_indexing_task_prerun("task-1", task, kwargs)
class TestIndexingTaskPostrun:
    """Tests for on_indexing_task_postrun: outcome counters and the duration histogram."""
    def test_skips_non_indexing_task(self) -> None:
        """Task names outside the indexing set are ignored by the postrun hook."""
        task = _make_task("some_other_task")
        kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
        # Should not raise
    def test_emits_completed_and_duration(self) -> None:
        """A SUCCESS postrun bumps the completed counter and adds to the duration sum."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="google_drive", name="Marketing Drive"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # Simulate prerun
        on_indexing_task_prerun("task-1", task, kwargs)
        # Snapshot metric values first so assertions check deltas — robust to
        # whatever other tests have already emitted on the shared registry.
        before_completed = INDEXING_TASK_COMPLETED.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
            cc_pair_id="42",
            outcome="success",
        )._value.get()
        before_duration = INDEXING_TASK_DURATION.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
        )._sum.get()
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
        after_completed = INDEXING_TASK_COMPLETED.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
            cc_pair_id="42",
            outcome="success",
        )._value.get()
        after_duration = INDEXING_TASK_DURATION.labels(
            task_name="docprocessing_task",
            source="google_drive",
            tenant_id="public",
        )._sum.get()
        assert after_completed == before_completed + 1
        # Duration sum must grow since the prerun start time was recorded.
        assert after_duration > before_duration
    def test_failure_outcome(self) -> None:
        """A FAILURE postrun increments the counter labeled outcome="failure"."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="slack", name="Slack"
        )
        task = _make_task("connector_doc_fetching_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        on_indexing_task_prerun("task-1", task, kwargs)
        before = INDEXING_TASK_COMPLETED.labels(
            task_name="connector_doc_fetching_task",
            source="slack",
            tenant_id="public",
            cc_pair_id="42",
            outcome="failure",
        )._value.get()
        on_indexing_task_postrun("task-1", task, kwargs, "FAILURE")
        after = INDEXING_TASK_COMPLETED.labels(
            task_name="connector_doc_fetching_task",
            source="slack",
            tenant_id="public",
            cc_pair_id="42",
            outcome="failure",
        )._value.get()
        assert before + 1 == after if False else after == before + 1
    def test_handles_postrun_without_prerun(self) -> None:
        """Postrun for an indexing task without a matching prerun should not crash."""
        _connector_cache[("test_tenant", 42)] = ConnectorInfo(
            source="slack", name="Slack"
        )
        task = _make_task("docprocessing_task")
        kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
        # No prerun — should still emit completed counter, just skip duration
        on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
class TestResolveConnector:
    """Tests for _resolve_connector's caching policy: cache successes only,
    so transient failures can be retried on later calls."""
    def test_failed_lookup_not_cached(self) -> None:
        """When DB lookup returns None, result should NOT be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
        # Pin the tenant contextvar so the expected cache key is deterministic.
        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            with (
                patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
                patch(
                    "onyx.db.connector_credential_pair"
                    ".get_connector_credential_pair_from_id",
                    return_value=None,
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector
                result = _resolve_connector(999)
                assert result.source == "unknown"
                # Should NOT be cached so subsequent calls can retry
                assert ("test-tenant", 999) not in _connector_cache
        finally:
            # Always restore the contextvar so other tests see the default.
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
    def test_exception_not_cached(self) -> None:
        """When DB lookup raises, result should NOT be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            with (
                patch(
                    "onyx.db.engine.sql_engine.get_session_with_current_tenant",
                    side_effect=Exception("DB down"),
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector
                # Resolver must degrade to "unknown" instead of propagating.
                result = _resolve_connector(888)
                assert result.source == "unknown"
                assert ("test-tenant", 888) not in _connector_cache
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
    def test_successful_lookup_is_cached(self) -> None:
        """When DB lookup succeeds, result should be cached."""
        from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
        token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
        try:
            mock_cc_pair = MagicMock()
            mock_cc_pair.name = "My Drive"
            mock_cc_pair.connector.source.value = "google_drive"
            with (
                patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
                patch(
                    "onyx.db.connector_credential_pair"
                    ".get_connector_credential_pair_from_id",
                    return_value=mock_cc_pair,
                ),
            ):
                from onyx.server.metrics.indexing_task_metrics import _resolve_connector
                result = _resolve_connector(777)
                assert result.source == "google_drive"
                assert result.name == "My Drive"
                # Success is cached under (tenant, cc_pair_id).
                assert ("test-tenant", 777) in _connector_cache
        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

View File

@@ -0,0 +1,69 @@
"""Tests for the Prometheus metrics server module."""
from collections.abc import Iterator
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
from onyx.server.metrics.metrics_server import _DEFAULT_PORTS
from onyx.server.metrics.metrics_server import start_metrics_server
@pytest.fixture(autouse=True)
def reset_server_state() -> Iterator[None]:
    """Reset the global _server_started between tests."""
    # Imported lazily so the fixture always touches the live module object,
    # regardless of import order in the test session.
    import onyx.server.metrics.metrics_server as mod
    mod._server_started = False
    yield
    # Also clear on teardown so a test that started the server can't leak
    # state into the next one.
    mod._server_started = False
class TestStartMetricsServer:
    """Tests for start_metrics_server: port selection, env overrides,
    idempotency, and graceful failure handling."""
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_uses_default_port_for_known_worker(self, mock_start: MagicMock) -> None:
        """A known worker type binds its entry in _DEFAULT_PORTS."""
        port = start_metrics_server("monitoring")
        assert port == _DEFAULT_PORTS["monitoring"]
        mock_start.assert_called_once_with(_DEFAULT_PORTS["monitoring"])
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "9999"})
    def test_env_var_overrides_default(self, mock_start: MagicMock) -> None:
        """PROMETHEUS_METRICS_PORT takes precedence over the default port table."""
        port = start_metrics_server("monitoring")
        assert port == 9999
        mock_start.assert_called_once_with(9999)
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_ENABLED": "false"})
    def test_disabled_via_env_var(self, mock_start: MagicMock) -> None:
        """PROMETHEUS_METRICS_ENABLED=false suppresses the server entirely."""
        port = start_metrics_server("monitoring")
        assert port is None
        mock_start.assert_not_called()
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_unknown_worker_type_no_env_var(self, mock_start: MagicMock) -> None:
        """Unknown worker type with no env override: nothing is started."""
        port = start_metrics_server("unknown_worker")
        assert port is None
        mock_start.assert_not_called()
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_idempotent(self, mock_start: MagicMock) -> None:
        """Second call is a no-op (returns None) — only one HTTP server per process."""
        port1 = start_metrics_server("monitoring")
        port2 = start_metrics_server("monitoring")
        assert port1 == _DEFAULT_PORTS["monitoring"]
        assert port2 is None
        mock_start.assert_called_once()
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    def test_handles_os_error(self, mock_start: MagicMock) -> None:
        """A bind failure (e.g. port in use) is reported as None, not raised."""
        mock_start.side_effect = OSError("Address already in use")
        port = start_metrics_server("monitoring")
        assert port is None
    @patch("onyx.server.metrics.metrics_server.start_http_server")
    @patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "not_a_number"})
    def test_invalid_port_env_var_returns_none(self, mock_start: MagicMock) -> None:
        """A non-numeric port override disables the server rather than crashing."""
        port = start_metrics_server("monitoring")
        assert port is None
        mock_start.assert_not_called()

View File

@@ -145,6 +145,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
pageSize,
initialSorting,
initialColumnVisibility,
initialRowSelection,
initialViewSelected,
draggable,
footer,
size = "lg",
@@ -221,6 +223,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
pageSize: effectivePageSize,
initialSorting,
initialColumnVisibility,
initialRowSelection,
initialViewSelected,
getRowId,
onSelectionChange,
searchTerm,

View File

@@ -103,6 +103,10 @@ interface UseDataTableOptions<TData extends RowData> {
initialSorting?: SortingState;
/** Initial column visibility state. @default {} */
initialColumnVisibility?: VisibilityState;
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. @default {} */
initialRowSelection?: RowSelectionState;
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode (filtered to selected rows). @default false */
initialViewSelected?: boolean;
/** Called whenever the set of selected row IDs changes. */
onSelectionChange?: (selectedIds: string[]) => void;
/** Search term for global text filtering. Rows are filtered to those containing
@@ -195,6 +199,8 @@ export default function useDataTable<TData extends RowData>(
columnResizeMode = "onChange",
initialSorting = [],
initialColumnVisibility = {},
initialRowSelection = {},
initialViewSelected = false,
getRowId,
onSelectionChange,
searchTerm,
@@ -206,7 +212,8 @@ export default function useDataTable<TData extends RowData>(
// ---- internal state -----------------------------------------------------
const [sorting, setSorting] = useState<SortingState>(initialSorting);
const [rowSelection, setRowSelection] = useState<RowSelectionState>({});
const [rowSelection, setRowSelection] =
useState<RowSelectionState>(initialRowSelection);
const [columnSizing, setColumnSizing] = useState<ColumnSizingState>({});
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
initialColumnVisibility
@@ -216,8 +223,12 @@ export default function useDataTable<TData extends RowData>(
pageSize: pageSizeOption,
});
/** Combined global filter: view-mode (selected IDs) + text search. */
const initialSelectedIds =
initialViewSelected && Object.keys(initialRowSelection).length > 0
? new Set(Object.keys(initialRowSelection))
: null;
const [globalFilter, setGlobalFilter] = useState<GlobalFilterValue>({
selectedIds: null,
selectedIds: initialSelectedIds,
searchTerm: "",
});
@@ -384,6 +395,31 @@ export default function useDataTable<TData extends RowData>(
: data.length;
const isPaginated = isFinite(pagination.pageSize);
// ---- keep view-mode filter in sync with selection ----------------------
// When in view-selected mode, deselecting a row should remove it from
// the visible set so it disappears immediately.
useEffect(() => {
if (isServerSide) return;
if (globalFilter.selectedIds == null) return;
const currentIds = new Set(Object.keys(rowSelection));
// Remove any ID from the filter that is no longer selected
let changed = false;
const next = new Set<string>();
globalFilter.selectedIds.forEach((id) => {
if (currentIds.has(id)) {
next.add(id);
} else {
changed = true;
}
});
if (changed) {
setGlobalFilter((prev) => ({ ...prev, selectedIds: next }));
}
// eslint-disable-next-line react-hooks/exhaustive-deps -- only react to
// selection changes while in view mode
}, [rowSelection, isServerSide]);
// ---- selection change callback ------------------------------------------
const isFirstRenderRef = useRef(true);
const onSelectionChangeRef = useRef(onSelectionChange);
@@ -392,6 +428,10 @@ export default function useDataTable<TData extends RowData>(
useEffect(() => {
if (isFirstRenderRef.current) {
isFirstRenderRef.current = false;
// Still fire the callback on first render if there's an initial selection
if (selectedRowIds.length > 0) {
onSelectionChangeRef.current?.(selectedRowIds);
}
return;
}
onSelectionChangeRef.current?.(selectedRowIds);

View File

@@ -146,6 +146,10 @@ export interface DataTableProps<TData> {
initialSorting?: SortingState;
/** Initial column visibility state. */
initialColumnVisibility?: VisibilityState;
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. */
initialRowSelection?: Record<string, boolean>;
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode. @default false */
initialViewSelected?: boolean;
/** Enable drag-and-drop row reordering. */
draggable?: DataTableDraggableConfig;
/** Footer configuration. */

View File

@@ -158,25 +158,22 @@ function Main({ ccPairId }: { ccPairId: number }) {
mutate(buildCCPairInfoUrl(ccPairId));
}, [ccPairId]);
const shouldConfirmConnectorDeletion = true;
const finishConnectorDeletion = useCallback(() => {
router.push("/admin/indexing/status");
}, [router]);
const scheduleConnectorDeletion = useCallback(async () => {
const scheduleConnectorDeletion = useCallback(() => {
if (!ccPair) return;
if (isSchedulingConnectorDeletionRef.current) return;
isSchedulingConnectorDeletionRef.current = true;
try {
await deleteCCPair(ccPair.connector.id, ccPair.credential.id, () =>
mutate(buildCCPairInfoUrl(ccPair.id))
deleteCCPair(ccPair.connector.id, ccPair.credential.id).catch((error) => {
toast.error(
"Failed to schedule deletion of connector - " + error.message
);
refresh();
} catch (error) {
console.error("Error deleting connector:", error);
} finally {
setShowDeleteConnectorConfirmModal(false);
isSchedulingConnectorDeletionRef.current = false;
}
}, [ccPair, refresh]);
});
finishConnectorDeletion();
}, [ccPair, finishConnectorDeletion]);
const latestIndexAttempt = indexAttempts?.[0];
const canManageInlineFileConnectorFiles =
@@ -194,10 +191,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
(error) => error.index_attempt_id === latestIndexAttempt?.id
);
const finishConnectorDeletion = useCallback(() => {
router.push("/admin/indexing/status?message=connector-deleted");
}, [router]);
const handleStatusUpdate = async (
newStatus: ConnectorCredentialPairStatus
) => {
@@ -520,13 +513,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
)}
{!isDeleting && (
<DropdownMenuItemWithTooltip
onClick={async () => {
if (shouldConfirmConnectorDeletion) {
setShowDeleteConnectorConfirmModal(true);
return;
}
await scheduleConnectorDeletion();
onClick={() => {
setShowDeleteConnectorConfirmModal(true);
}}
disabled={!statusIsNotCurrentlyActive(ccPair.status)}
className="flex items-center gap-x-2 cursor-pointer px-3 py-2 text-red-600 hover:text-red-700 dark:text-red-400 dark:hover:text-red-300"

View File

@@ -0,0 +1,13 @@
"use client";
import { use } from "react";
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
export default function EditGroupRoute({
params,
}: {
params: Promise<{ id: string }>;
}) {
const { id } = use(params);
return <EditGroupPage groupId={Number(id)} />;
}

View File

@@ -0,0 +1,17 @@
"use client";
import { use } from "react";
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
export default function EditGroupRoute({
params,
}: {
params: Promise<{ id: string }>;
}) {
const { id } = use(params);
const groupId = Number(id);
if (Number.isNaN(groupId)) {
return null;
}
return <EditGroupPage groupId={groupId} />;
}

View File

@@ -0,0 +1 @@
// Thin route file: re-export the create-group page implementation as this
// route's default component.
export { default } from "@/refresh-pages/admin/GroupsPage/CreateGroupPage";

View File

@@ -211,10 +211,6 @@ export default function Status() {
message: "Connector created successfully",
type: "success",
},
"connector-deleted": {
message: "Connector deleted successfully",
type: "success",
},
});
return (

View File

@@ -1,31 +0,0 @@
import { ConnectorStatus } from "@/lib/types";
import { ConnectorMultiSelect } from "@/components/ConnectorMultiSelect";

interface ConnectorEditorProps {
  selectedCCPairIds: number[];
  setSetCCPairIds: (ccPairId: number[]) => void;
  allCCPairs: ConnectorStatus<any, any>[];
}

/**
 * Multi-select of connectors for group membership. Only private connectors
 * are offered — public ones are already visible to everyone, so attaching
 * them to a group is meaningless.
 */
export const ConnectorEditor = ({
  selectedCCPairIds,
  setSetCCPairIds,
  allCCPairs,
}: ConnectorEditorProps) => {
  // Filter out public docs, since they don't make sense as part of a group
  const selectablePairs = allCCPairs.filter(
    ({ access_type }) => access_type === "private"
  );

  return (
    <ConnectorMultiSelect
      name="connectors"
      label="Connectors"
      connectors={selectablePairs}
      selectedIds={selectedCCPairIds}
      onChange={setSetCCPairIds}
      placeholder="Search for connectors..."
      showError={true}
    />
  );
};

View File

@@ -1,87 +0,0 @@
import { User } from "@/lib/types";
import { FiX } from "react-icons/fi";
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
import Button from "@/refresh-components/buttons/Button";
interface UserEditorProps {
  /** IDs of users currently picked in this editor (controlled). */
  selectedUserIds: string[];
  /** Setter for the controlled selection. */
  setSelectedUserIds: (userIds: string[]) => void;
  /** Candidate pool the combobox draws from. */
  allUsers: User[];
  /** Users already in the group — excluded from the combobox options. */
  existingUsers: User[];
  /** Optional: when provided, an "Add Users" button submits the selection. */
  onSubmit?: (users: User[]) => void;
}
// Controlled user picker: chips for current selections (click to remove)
// plus a strict combobox to add from the remaining candidates.
export const UserEditor = ({
  selectedUserIds,
  setSelectedUserIds,
  allUsers,
  existingUsers,
  onSubmit,
}: UserEditorProps) => {
  // Resolve selected IDs back to full User objects for display/submit.
  const selectedUsers = allUsers.filter((user) =>
    selectedUserIds.includes(user.id)
  );
  return (
    <>
      <div className="mb-2 flex flex-wrap gap-x-2">
        {selectedUsers.length > 0 &&
          selectedUsers.map((selectedUser) => (
            <div
              key={selectedUser.id}
              onClick={() => {
                {/* Clicking a chip removes that user from the selection. */}
                setSelectedUserIds(
                  selectedUserIds.filter((userId) => userId !== selectedUser.id)
                );
              }}
              className={`
                  flex
                  rounded-lg
                  px-2
                  py-1
                  border
                  border-border
                  hover:bg-accent-background
                  cursor-pointer`}
            >
              {selectedUser.email} <FiX className="ml-1 my-auto" />
            </div>
          ))}
      </div>
      <div className="flex">
        <InputComboBox
          placeholder="Search..."
          value=""
          onChange={() => {}}
          onValueChange={(selectedValue) => {
            // Set-union guards against duplicate IDs in the selection.
            setSelectedUserIds([
              ...Array.from(new Set([...selectedUserIds, selectedValue])),
            ]);
          }}
          options={allUsers
            .filter(
              (user) =>
                !selectedUserIds.includes(user.id) &&
                !existingUsers.map((user) => user.id).includes(user.id)
            )
            .map((user) => ({
              label: user.email,
              value: user.id,
            }))}
          strict
          leftSearchIcon
        />
        {onSubmit && (
          // TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
          <Button
            className="ml-3 flex-nowrap w-32"
            onClick={() => onSubmit(selectedUsers)}
          >
            Add Users
          </Button>
        )}
      </div>
    </>
  );
};

View File

@@ -1,153 +0,0 @@
import { Form, Formik } from "formik";
import * as Yup from "yup";
import { toast } from "@/hooks/useToast";
import { ConnectorStatus, User, UserGroup } from "@/lib/types";
import { TextFormField } from "@/components/Field";
import { createUserGroup } from "./lib";
import { UserEditor } from "./UserEditor";
import { ConnectorEditor } from "./ConnectorEditor";
import Modal from "@/refresh-components/Modal";
import Button from "@/refresh-components/buttons/Button";
import Separator from "@/refresh-components/Separator";
import Text from "@/refresh-components/texts/Text";
import { SvgUsers } from "@opal/icons";
import { useVectorDbEnabled } from "@/providers/SettingsProvider";

export interface UserGroupCreationFormProps {
  /** Close the modal (also invoked after a successful submit). */
  onClose: () => void;
  /** Candidate users for group membership. */
  users: User[];
  /** All connector/credential pairs; only private ones are selectable. */
  ccPairs: ConnectorStatus<any, any>[];
  /** When provided, the modal renders in "update" mode. */
  existingUserGroup?: UserGroup;
}

/**
 * Modal form for creating (or updating) a user group: a name, a set of
 * private connectors, and a set of member users.
 */
export default function UserGroupCreationForm({
  onClose,
  users,
  ccPairs,
  existingUserGroup,
}: UserGroupCreationFormProps) {
  const isUpdate = existingUserGroup !== undefined;
  const vectorDbEnabled = useVectorDbEnabled();
  // Public connectors don't make sense as part of a group.
  const privateCcPairs = ccPairs.filter(
    (ccPair) => ccPair.access_type === "private"
  );
  return (
    <Modal open onOpenChange={onClose}>
      <Modal.Content>
        <Modal.Header
          icon={SvgUsers}
          title={isUpdate ? "Update a User Group" : "Create a new User Group"}
          onClose={onClose}
        />
        <Modal.Body>
          <Separator />
          <Formik
            initialValues={{
              name: existingUserGroup ? existingUserGroup.name : "",
              user_ids: [] as string[],
              cc_pair_ids: [] as number[],
            }}
            validationSchema={Yup.object().shape({
              name: Yup.string().required("Please enter a name for the group"),
              user_ids: Yup.array().of(Yup.string().required()),
              cc_pair_ids: Yup.array().of(Yup.number().required()),
            })}
            onSubmit={async (values, formikHelpers) => {
              formikHelpers.setSubmitting(true);
              // NOTE(review): even in update mode this calls createUserGroup —
              // looks like the update path was never wired up; confirm intent.
              const response = await createUserGroup(values);
              formikHelpers.setSubmitting(false);
              if (response.ok) {
                toast.success(
                  isUpdate
                    ? "Successfully updated user group!"
                    : "Successfully created user group!"
                );
                onClose();
              } else {
                const responseJson = await response.json();
                const errorMsg = responseJson.detail || responseJson.message;
                toast.error(
                  isUpdate
                    ? `Error updating user group - ${errorMsg}`
                    : `Error creating user group - ${errorMsg}`
                );
              }
            }}
          >
            {({ isSubmitting, values, setFieldValue }) => (
              <Form>
                <TextFormField
                  name="name"
                  label="Name:"
                  placeholder="A name for the User Group"
                  disabled={isUpdate}
                />
                <Separator />
                {vectorDbEnabled ? (
                  <>
                    <Text as="p" className="font-medium">
                      Select which private connectors this group has access to:
                    </Text>
                    <Text as="p" text02>
                      All documents indexed by the selected connectors will be
                      visible to users in this group.
                    </Text>
                    <ConnectorEditor
                      allCCPairs={privateCcPairs}
                      selectedCCPairIds={values.cc_pair_ids}
                      setSetCCPairIds={(ccPairsIds) =>
                        setFieldValue("cc_pair_ids", ccPairsIds)
                      }
                    />
                  </>
                ) : (
                  <Text as="p" text03>
                    Connectors are not available in Onyx Lite. Redeploy Onyx
                    with DISABLE_VECTOR_DB=false to index knowledge via
                    connectors.
                  </Text>
                )}
                <Separator />
                <Text as="p" className="font-medium">
                  Select which Users should be a part of this Group.
                </Text>
                <Text as="p" text02>
                  All selected users will be able to search through all
                  documents indexed by the selected connectors.
                </Text>
                <div className="mb-3 gap-2">
                  <UserEditor
                    selectedUserIds={values.user_ids}
                    setSelectedUserIds={(userIds) =>
                      setFieldValue("user_ids", userIds)
                    }
                    allUsers={users}
                    existingUsers={[]}
                  />
                </div>
                <div className="flex">
                  {/* TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved */}
                  <Button
                    type="submit"
                    disabled={isSubmitting}
                    className="mx-auto w-64"
                  >
                    {isUpdate ? "Update!" : "Create!"}
                  </Button>
                </div>
              </Form>
            )}
          </Formik>
        </Modal.Body>
      </Modal.Content>
    </Modal>
  );
}

View File

@@ -1,177 +0,0 @@
"use client";
import {
Table,
TableHead,
TableRow,
TableBody,
TableCell,
} from "@/components/ui/table";
import { toast } from "@/hooks/useToast";
import { LoadingAnimation } from "@/components/Loading";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
import { deleteUserGroup } from "./lib";
import { useRouter } from "next/navigation";
import { FiEdit2, FiUser } from "react-icons/fi";
import { User, UserGroup } from "@/lib/types";
import Link from "next/link";
import { DeleteButton } from "@/components/DeleteButton";
import { TableHeader } from "@/components/ui/table";
import Button from "@/refresh-components/buttons/Button";
import { SvgEdit } from "@opal/icons";
const MAX_USERS_TO_DISPLAY = 6;
const SimpleUserDisplay = ({ user }: { user: User }) => {
return (
<div className="flex my-0.5">
<FiUser className="mr-2 my-auto" /> {user.email}
</div>
);
};
interface UserGroupsTableProps {
userGroups: UserGroup[];
refresh: () => void;
}
export const UserGroupsTable = ({
userGroups,
refresh,
}: UserGroupsTableProps) => {
const router = useRouter();
// sort by name for consistent ordering
userGroups.sort((a, b) => {
if (a.name < b.name) {
return -1;
} else if (a.name > b.name) {
return 1;
} else {
return 0;
}
});
return (
<div>
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Name</TableHead>
<TableHead>Connectors</TableHead>
<TableHead>Users</TableHead>
<TableHead>Status</TableHead>
<TableHead>Delete</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{userGroups
.filter((userGroup) => !userGroup.is_up_for_deletion)
.map((userGroup) => {
return (
<TableRow key={userGroup.id}>
<TableCell>
{/* TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved */}
<Button
internal
leftIcon={SvgEdit}
href={`/admin/groups/${userGroup.id}`}
className="truncate"
>
{userGroup.name}
</Button>
</TableCell>
<TableCell>
{userGroup.cc_pairs.length > 0 ? (
<div>
{userGroup.cc_pairs.map((ccPairDescriptor, ind) => {
return (
<div
className={
ind !== userGroup.cc_pairs.length - 1
? "mb-3"
: ""
}
key={ccPairDescriptor.id}
>
<ConnectorTitle
connector={ccPairDescriptor.connector}
ccPairId={ccPairDescriptor.id}
ccPairName={ccPairDescriptor.name}
showMetadata={false}
/>
</div>
);
})}
</div>
) : (
"-"
)}
</TableCell>
<TableCell>
{userGroup.users.length > 0 ? (
<div>
{userGroup.users.length <= MAX_USERS_TO_DISPLAY ? (
userGroup.users.map((user) => {
return (
<SimpleUserDisplay key={user.id} user={user} />
);
})
) : (
<div>
{userGroup.users
.slice(0, MAX_USERS_TO_DISPLAY)
.map((user) => {
return (
<SimpleUserDisplay
key={user.id}
user={user}
/>
);
})}
<div>
+ {userGroup.users.length - MAX_USERS_TO_DISPLAY}{" "}
more
</div>
</div>
)}
</div>
) : (
"-"
)}
</TableCell>
<TableCell>
{userGroup.is_up_to_date ? (
<div className="text-success">Up to date!</div>
) : (
<div className="w-10">
<LoadingAnimation text="Syncing" />
</div>
)}
</TableCell>
<TableCell>
<DeleteButton
onClick={async (event) => {
event.stopPropagation();
const response = await deleteUserGroup(userGroup.id);
if (response.ok) {
toast.success(
`User Group "${userGroup.name}" deleted`
);
} else {
const errorMsg = (await response.json()).detail;
toast.error(
`Failed to delete User Group - ${errorMsg}`
);
}
refresh();
}}
/>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</div>
);
};

View File

@@ -1,83 +0,0 @@
import { Button } from "@opal/components";
import Modal from "@/refresh-components/Modal";
import { useState } from "react";
import { updateUserGroup } from "./lib";
import { toast } from "@/hooks/useToast";
import { ConnectorStatus, UserGroup } from "@/lib/types";
import { ConnectorMultiSelect } from "@/components/ConnectorMultiSelect";
import { SvgPlus } from "@opal/icons";
export interface AddConnectorFormProps {
  /** All connector/credential pairs in the org. */
  ccPairs: ConnectorStatus<any, any>[];
  /** The group being edited. */
  userGroup: UserGroup;
  /** Close the modal (also invoked after submit, success or failure). */
  onClose: () => void;
}
// Modal that adds private connectors to an existing user group by PATCHing
// the group's full cc_pair list (existing pairs + newly selected).
export default function AddConnectorForm({
  ccPairs,
  userGroup,
  onClose,
}: AddConnectorFormProps) {
  const [selectedCCPairIds, setSelectedCCPairIds] = useState<number[]>([]);
  // Filter out ccPairs that are already in the user group and are not private
  const availableCCPairs = ccPairs
    .filter(
      (ccPair) =>
        !userGroup.cc_pairs
          .map((userGroupCCPair) => userGroupCCPair.id)
          .includes(ccPair.cc_pair_id)
    )
    .filter((ccPair) => ccPair.access_type === "private");
  return (
    <Modal open onOpenChange={onClose}>
      <Modal.Content width="sm" height="sm">
        <Modal.Header
          icon={SvgPlus}
          title="Add New Connector"
          onClose={onClose}
        />
        <Modal.Body>
          <ConnectorMultiSelect
            name="connectors"
            label="Select Connectors"
            connectors={availableCCPairs}
            selectedIds={selectedCCPairIds}
            onChange={setSelectedCCPairIds}
            placeholder="Search for connectors to add..."
            showError={false}
          />
          <Button
            onClick={async () => {
              {/* Union of existing + selected IDs, de-duplicated via Set. */}
              const newCCPairIds = [
                ...Array.from(
                  new Set(
                    userGroup.cc_pairs
                      .map((ccPair) => ccPair.id)
                      .concat(selectedCCPairIds)
                  )
                ),
              ];
              const response = await updateUserGroup(userGroup.id, {
                user_ids: userGroup.users.map((user) => user.id),
                cc_pair_ids: newCCPairIds,
              });
              if (response.ok) {
                toast.success("Successfully added connectors to group");
                onClose();
              } else {
                const responseJson = await response.json();
                const errorMsg = responseJson.detail || responseJson.message;
                toast.error(`Failed to add connectors to group - ${errorMsg}`);
                onClose();
              }
            }}
          >
            Add Connectors
          </Button>
        </Modal.Body>
      </Modal.Content>
    </Modal>
  );
}

View File

@@ -1,64 +0,0 @@
import Modal from "@/refresh-components/Modal";
import { updateUserGroup } from "./lib";
import { toast } from "@/hooks/useToast";
import { User, UserGroup } from "@/lib/types";
import { UserEditor } from "../UserEditor";
import { useState } from "react";
import { SvgUserPlus } from "@opal/icons";
export interface AddMemberFormProps {
  /** Candidate users that may be added to the group. */
  users: User[];
  /** The group being edited. */
  userGroup: UserGroup;
  /** Close the modal (also invoked after submit, success or failure). */
  onClose: () => void;
}
// Modal that adds users to an existing group by PATCHing the group's full
// member list (existing members + newly selected, de-duplicated).
export default function AddMemberForm({
  users,
  userGroup,
  onClose,
}: AddMemberFormProps) {
  const [selectedUserIds, setSelectedUserIds] = useState<string[]>([]);
  return (
    <Modal open onOpenChange={onClose}>
      <Modal.Content width="sm" height="sm">
        <Modal.Header
          icon={SvgUserPlus}
          title="Add New User"
          onClose={onClose}
        />
        <Modal.Body>
          <UserEditor
            selectedUserIds={selectedUserIds}
            setSelectedUserIds={setSelectedUserIds}
            allUsers={users}
            existingUsers={userGroup.users}
            onSubmit={async (selectedUsers) => {
              {/* Union of existing + selected user IDs via Set. */}
              const newUserIds = [
                ...Array.from(
                  new Set(
                    userGroup.users
                      .map((user) => user.id)
                      .concat(selectedUsers.map((user) => user.id))
                  )
                ),
              ];
              const response = await updateUserGroup(userGroup.id, {
                user_ids: newUserIds,
                cc_pair_ids: userGroup.cc_pairs.map((ccPair) => ccPair.id),
              });
              if (response.ok) {
                toast.success("Successfully added users to group");
                onClose();
              } else {
                const responseJson = await response.json();
                const errorMsg = responseJson.detail || responseJson.message;
                toast.error(`Failed to add users to group - ${errorMsg}`);
                onClose();
              }
            }}
          />
        </Modal.Body>
      </Modal.Content>
    </Modal>
  );
}

View File

@@ -1,57 +0,0 @@
import { toast } from "@/hooks/useToast";
import CreateRateLimitModal from "../../../../admin/token-rate-limits/CreateRateLimitModal";
import { Scope } from "../../../../admin/token-rate-limits/types";
import { insertGroupTokenRateLimit } from "../../../../admin/token-rate-limits/lib";
import { mutate } from "swr";
// NOTE(review): interface name looks copy-pasted from AddMemberForm — it
// describes the token-rate-limit modal's props, not a member form.
interface AddMemberFormProps {
  isOpen: boolean;
  setIsOpen: (isOpen: boolean) => void;
  /** Group whose rate limit is being created (used for SWR revalidation). */
  userGroupId: number;
}
// Creates an enabled token-rate-limit for a group via the admin API.
const handleCreateGroupTokenRateLimit = async (
  period_hours: number,
  token_budget: number,
  group_id: number = -1
) => {
  const tokenRateLimitArgs = {
    enabled: true,
    token_budget: token_budget,
    period_hours: period_hours,
  };
  return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
};
// Wraps CreateRateLimitModal pinned to USER_GROUP scope; on success it
// closes the modal and revalidates the group's rate-limit list via SWR.
export const AddTokenRateLimitForm: React.FC<AddMemberFormProps> = ({
  isOpen,
  setIsOpen,
  userGroupId,
}) => {
  const handleSubmit = (
    _: Scope,
    period_hours: number,
    token_budget: number,
    group_id: number = -1
  ) => {
    handleCreateGroupTokenRateLimit(period_hours, token_budget, group_id)
      .then(() => {
        setIsOpen(false);
        toast.success("Token rate limit created!");
        mutate(`/api/admin/token-rate-limits/user-group/${userGroupId}`);
      })
      .catch((error) => {
        toast.error(error.message);
      });
  };
  return (
    <CreateRateLimitModal
      isOpen={isOpen}
      setIsOpen={setIsOpen}
      onSubmit={handleSubmit}
      forSpecificScope={Scope.USER_GROUP}
      forSpecificUserGroup={userGroupId}
    />
  );
};

View File

@@ -1,483 +0,0 @@
"use client";
import { toast } from "@/hooks/useToast";
import { useState } from "react";
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
import AddMemberForm from "./AddMemberForm";
import { updateUserGroup, updateCuratorStatus } from "./lib";
import { LoadingAnimation } from "@/components/Loading";
import {
User,
UserGroup,
UserRole,
USER_ROLE_LABELS,
ConnectorStatus,
} from "@/lib/types";
import AddConnectorForm from "./AddConnectorForm";
import Separator from "@/refresh-components/Separator";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import Text from "@/components/ui/text";
import {
Table,
TableBody,
TableCell,
TableHead,
TableHeader,
TableRow,
} from "@/components/ui/table";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { DeleteButton } from "@/components/DeleteButton";
import { Bubble } from "@/components/Bubble";
import { BookmarkIcon, RobotIcon } from "@/components/icons/icons";
import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm";
import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables";
import { useUser } from "@/providers/UserProvider";
import GenericConfirmModal from "@/components/modals/GenericConfirmModal";
import Spacer from "@/refresh-components/Spacer";
// Props for the group-detail view.
interface GroupDisplayProps {
  /** All users in the org — presumably the candidate pool for additions; verify against caller. */
  users: User[];
  /** All connector/credential pairs. */
  ccPairs: ConnectorStatus<any, any>[];
  /** The group being displayed/edited. */
  userGroup: UserGroup;
  /** Re-fetches the group after membership/connector changes. */
  refreshUserGroup: () => void;
}
/**
 * Inline role selector for a single group member.
 *
 * Only Basic <-> Curator transitions are editable here; any other role is
 * rendered as read-only text. The selection is held in local state so the
 * dropdown updates optimistically and rolls back if the API call fails.
 * A confirmation modal is shown when users attempt to demote themselves.
 */
const UserRoleDropdown = ({
  user,
  group,
  onSuccess,
  onError,
  isAdmin, // NOTE(review): currently unused inside this component — confirm it can be dropped
}: {
  user: User;
  group: UserGroup;
  onSuccess: () => void;
  onError: (message: string) => void;
  isAdmin: boolean;
}) => {
  // A globally-CURATOR user only counts as Curator *for this group* when
  // listed in the group's curator_ids; otherwise display them as Basic.
  const [localRole, setLocalRole] = useState(() => {
    if (user.role === UserRole.CURATOR) {
      return group.curator_ids.includes(user.id)
        ? UserRole.CURATOR
        : UserRole.BASIC;
    }
    return user.role;
  });
  // True while the set-curator request is in flight (disables the dropdown).
  const [isSettingRole, setIsSettingRole] = useState(false);
  // State for the self-demotion confirmation flow.
  const [showDemoteConfirm, setShowDemoteConfirm] = useState(false);
  const [pendingRoleChange, setPendingRoleChange] = useState<string | null>(
    null
  );
  const { user: currentUser } = useUser();

  // Optimistically applies the role change, then persists it via the
  // set-curator endpoint; rolls local state back and reports on failure.
  const applyRoleChange = async (value: string) => {
    if (value === localRole) return;
    if (value === UserRole.BASIC || value === UserRole.CURATOR) {
      setIsSettingRole(true);
      setLocalRole(value);
      try {
        const response = await updateCuratorStatus(group.id, {
          user_id: user.id,
          is_curator: value === UserRole.CURATOR,
        });
        if (response.ok) {
          onSuccess();
          // NOTE(review): mutates the `user` prop in place so the
          // parent-held row stays consistent without a refetch.
          user.role = value;
        } else {
          const errorData = await response.json();
          throw new Error(errorData.detail || "Failed to update user role");
        }
      } catch (error: any) {
        onError(error.message);
        // Roll the optimistic update back to the (unmutated) prop value.
        setLocalRole(user.role);
      } finally {
        setIsSettingRole(false);
      }
    }
  };

  // Intercept self-demotion with a confirmation modal; apply directly
  // for every other change.
  const handleChange = (value: string) => {
    if (value === UserRole.BASIC && user.id === currentUser?.id) {
      setPendingRoleChange(value);
      setShowDemoteConfirm(true);
    } else {
      applyRoleChange(value);
    }
  };

  // Only Basic/Curator members get an editable dropdown.
  const isEditable =
    user.role === UserRole.BASIC || user.role === UserRole.CURATOR;

  return (
    <>
      {/* Confirmation modal - only shown when users try to demote themselves */}
      {showDemoteConfirm && pendingRoleChange && (
        <GenericConfirmModal
          title="Remove Yourself as a Curator for this Group?"
          message="Are you sure you want to change your role to Basic? This will remove your ability to curate this group."
          confirmText="Yes, set me to Basic"
          onClose={() => {
            // Cancel the role change if user dismisses modal
            setShowDemoteConfirm(false);
            setPendingRoleChange(null);
          }}
          onConfirm={() => {
            // Apply the role change if user confirms
            setShowDemoteConfirm(false);
            applyRoleChange(pendingRoleChange);
            setPendingRoleChange(null);
          }}
        />
      )}
      {isEditable ? (
        <InputSelect
          value={localRole}
          onValueChange={handleChange}
          disabled={isSettingRole}
        >
          <InputSelect.Trigger placeholder="Select role" />
          <InputSelect.Content>
            <InputSelect.Item value={UserRole.BASIC}>Basic</InputSelect.Item>
            <InputSelect.Item value={UserRole.CURATOR}>
              Curator
            </InputSelect.Item>
          </InputSelect.Content>
        </InputSelect>
      ) : (
        <div>{USER_ROLE_LABELS[localRole]}</div>
      )}
    </>
  );
};
export const GroupDisplay = ({
users,
ccPairs,
userGroup,
refreshUserGroup,
}: GroupDisplayProps) => {
const [addMemberFormVisible, setAddMemberFormVisible] = useState(false);
const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false);
const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false);
const { isAdmin } = useUser();
const onRoleChangeSuccess = () =>
toast.success("User role updated successfully!");
const onRoleChangeError = (errorMsg: string) =>
toast.error(`Unable to update user role - ${errorMsg}`);
return (
<div>
<div className="text-sm mb-3 flex">
<Text className="mr-1">Status:</Text>{" "}
{userGroup.is_up_to_date ? (
<div className="text-success font-bold">Up to date</div>
) : (
<div className="text-accent font-bold">
<LoadingAnimation text="Syncing" />
</div>
)}
</div>
<Separator />
<div className="flex w-full">
<h2 className="text-xl font-bold">Users</h2>
</div>
<div className="mt-2">
{userGroup.users.length > 0 ? (
<>
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Email</TableHead>
<TableHead>Role</TableHead>
<TableHead className="flex w-full">
<div className="ml-auto">Remove User</div>
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{userGroup.users.map((groupMember) => {
return (
<TableRow key={groupMember.id}>
<TableCell className="whitespace-normal break-all">
{groupMember.email}
</TableCell>
<TableCell>
<UserRoleDropdown
user={groupMember}
group={userGroup}
onSuccess={onRoleChangeSuccess}
onError={onRoleChangeError}
isAdmin={isAdmin}
/>
</TableCell>
<TableCell>
<div className="flex w-full">
<div className="ml-auto m-2">
{(isAdmin ||
!userGroup.curator_ids.includes(
groupMember.id
)) && (
<DeleteButton
onClick={async () => {
const response = await updateUserGroup(
userGroup.id,
{
user_ids: userGroup.users
.filter(
(userGroupUser) =>
userGroupUser.id !== groupMember.id
)
.map(
(userGroupUser) => userGroupUser.id
),
cc_pair_ids: userGroup.cc_pairs.map(
(ccPair) => ccPair.id
),
}
);
if (response.ok) {
toast.success(
"Successfully removed user from group"
);
} else {
const responseJson = await response.json();
const errorMsg =
responseJson.detail ||
responseJson.message;
toast.error(
`Error removing user from group - ${errorMsg}`
);
}
refreshUserGroup();
}}
/>
)}
</div>
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</>
) : (
<div className="text-sm">No users in this group...</div>
)}
</div>
<SimpleTooltip
tooltip="Cannot update group while sync is occurring"
disabled={userGroup.is_up_to_date}
>
<Disabled disabled={!userGroup.is_up_to_date}>
<Button
onClick={() => {
if (userGroup.is_up_to_date) {
setAddMemberFormVisible(true);
}
}}
>
Add Users
</Button>
</Disabled>
</SimpleTooltip>
{addMemberFormVisible && (
<AddMemberForm
users={users}
userGroup={userGroup}
onClose={() => {
setAddMemberFormVisible(false);
refreshUserGroup();
}}
/>
)}
<Separator />
<h2 className="text-xl font-bold mt-8">Connectors</h2>
<div className="mt-2">
{userGroup.cc_pairs.length > 0 ? (
<>
<Table className="overflow-visible">
<TableHeader>
<TableRow>
<TableHead>Connector</TableHead>
<TableHead className="flex w-full">
<div className="ml-auto">Remove Connector</div>
</TableHead>
</TableRow>
</TableHeader>
<TableBody>
{userGroup.cc_pairs.map((ccPair) => {
return (
<TableRow key={ccPair.id}>
<TableCell className="whitespace-normal break-all">
<ConnectorTitle
connector={ccPair.connector}
ccPairId={ccPair.id}
ccPairName={ccPair.name}
/>
</TableCell>
<TableCell>
<div className="flex w-full">
<div className="ml-auto m-2">
<DeleteButton
onClick={async () => {
const response = await updateUserGroup(
userGroup.id,
{
user_ids: userGroup.users.map(
(userGroupUser) => userGroupUser.id
),
cc_pair_ids: userGroup.cc_pairs
.filter(
(userGroupCCPair) =>
userGroupCCPair.id != ccPair.id
)
.map((ccPair) => ccPair.id),
}
);
if (response.ok) {
toast.success(
"Successfully removed connector from group"
);
} else {
const responseJson = await response.json();
const errorMsg =
responseJson.detail || responseJson.message;
toast.error(
`Error removing connector from group - ${errorMsg}`
);
}
refreshUserGroup();
}}
/>
</div>
</div>
</TableCell>
</TableRow>
);
})}
</TableBody>
</Table>
</>
) : (
<div className="text-sm">No connectors in this group...</div>
)}
</div>
<SimpleTooltip
tooltip="Cannot update group while sync is occurring"
disabled={userGroup.is_up_to_date}
>
<Disabled disabled={!userGroup.is_up_to_date}>
<Button
onClick={() => {
if (userGroup.is_up_to_date) {
setAddConnectorFormVisible(true);
}
}}
>
Add Connectors
</Button>
</Disabled>
</SimpleTooltip>
{addConnectorFormVisible && (
<AddConnectorForm
ccPairs={ccPairs}
userGroup={userGroup}
onClose={() => {
setAddConnectorFormVisible(false);
refreshUserGroup();
}}
/>
)}
<Separator />
<h2 className="text-xl font-bold mt-8 mb-2">Document Sets</h2>
<div>
{userGroup.document_sets.length > 0 ? (
<div className="flex flex-wrap gap-2">
{userGroup.document_sets.map((documentSet) => {
return (
<Bubble isSelected key={documentSet.id}>
<div className="flex">
<BookmarkIcon />
<Text className="ml-1">{documentSet.name}</Text>
</div>
</Bubble>
);
})}
</div>
) : (
<>
<Text>No document sets in this group...</Text>
</>
)}
</div>
<Separator />
<h2 className="text-xl font-bold mt-8 mb-2">Agents</h2>
<div>
{userGroup.document_sets.length > 0 ? (
<div className="flex flex-wrap gap-2">
{userGroup.personas.map((persona) => {
return (
<Bubble isSelected key={persona.id}>
<div className="flex">
<RobotIcon />
<Text className="ml-1">{persona.name}</Text>
</div>
</Bubble>
);
})}
</div>
) : (
<>
<Text>No Agents in this group...</Text>
</>
)}
</div>
<Separator />
<h2 className="text-xl font-bold mt-8 mb-2">Token Rate Limits</h2>
<AddTokenRateLimitForm
isOpen={addRateLimitFormVisible}
setIsOpen={setAddRateLimitFormVisible}
userGroupId={userGroup.id}
/>
<GenericTokenRateLimitTable
fetchUrl={`/api/admin/token-rate-limits/user-group/${userGroup.id}`}
hideHeading
isAdmin={isAdmin}
/>
{isAdmin && (
<>
<Spacer rem={0.75} />
<Button onClick={() => setAddRateLimitFormVisible(true)}>
Create a Token Rate Limit
</Button>
</>
)}
</div>
);
};

View File

@@ -1,12 +0,0 @@
import { useUserGroups } from "@/lib/hooks";
export const useSpecificUserGroup = (groupId: string) => {
const { data, isLoading, error, refreshUserGroups } = useUserGroups();
const userGroup = data?.find((group) => group.id.toString() === groupId);
return {
userGroup,
isLoading,
error,
refreshUserGroup: refreshUserGroups,
};
};

View File

@@ -1,29 +0,0 @@
import { UserGroupUpdate, SetCuratorRequest } from "../types";
/**
 * PATCH the membership and connector lists of an existing user group.
 * Returns the raw Response for the caller to inspect.
 */
export const updateUserGroup = async (
  groupId: number,
  userGroup: UserGroupUpdate
) => {
  return fetch(`/api/manage/admin/user-group/${groupId}`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(userGroup),
  });
};
/**
 * Promote or demote a user's curator status within a group.
 * Returns the raw Response for the caller to inspect.
 */
export const updateCuratorStatus = async (
  groupId: number,
  curatorRequest: SetCuratorRequest
) => {
  return fetch(`/api/manage/admin/user-group/${groupId}/set-curator`, {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(curatorRequest),
  });
};

View File

@@ -1,87 +0,0 @@
"use client";
import { use } from "react";
import { GroupDisplay } from "./GroupDisplay";
import { useSpecificUserGroup } from "./hook";
import { ThreeDotsLoader } from "@/components/Loading";
import { useConnectorStatus } from "@/lib/hooks";
import useUsers from "@/hooks/useUsers";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
const route = ADMIN_ROUTES.GROUPS;
/**
 * Loads the group plus supporting users/connectors data and renders the
 * group detail view once everything is available. Connector data is only
 * required (and only fetched meaningfully) when the vector DB is enabled.
 */
function Main({ groupId }: { groupId: string }) {
  const vectorDbEnabled = useVectorDbEnabled();

  // All hooks are called unconditionally, before any early return.
  const groupState = useSpecificUserGroup(groupId);
  const usersState = useUsers({ includeApiKeys: true });
  const ccPairsState = useConnectorStatus(30000, vectorDbEnabled);

  const stillLoading =
    groupState.isLoading ||
    usersState.isLoading ||
    (vectorDbEnabled && ccPairsState.isLoading);

  if (stillLoading) {
    return (
      <div className="h-full">
        <div className="my-auto">
          <ThreeDotsLoader />
        </div>
      </div>
    );
  }
  if (!groupState.userGroup || groupState.error) {
    return <div>Error loading user group</div>;
  }
  if (!usersState.data || usersState.error) {
    return <div>Error loading users</div>;
  }
  if (vectorDbEnabled && (!ccPairsState.data || ccPairsState.error)) {
    return <div>Error loading connectors</div>;
  }

  return (
    <>
      <SettingsLayouts.Header
        icon={route.icon}
        title={groupState.userGroup.name || "Unknown"}
        separator
        backButton
      />
      <SettingsLayouts.Body>
        <GroupDisplay
          users={usersState.data.accepted}
          ccPairs={ccPairsState.data ?? []}
          userGroup={groupState.userGroup}
          refreshUserGroup={groupState.refreshUserGroup}
        />
      </SettingsLayouts.Body>
    </>
  );
}
/** Route entry point: unwraps the async params and renders the group view. */
export default function Page(props: { params: Promise<{ groupId: string }> }) {
  const { groupId } = use(props.params);
  return (
    <SettingsLayouts.Root>
      <Main groupId={groupId} />
    </SettingsLayouts.Root>
  );
}

View File

@@ -0,0 +1,13 @@
"use client";
import { use } from "react";
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
/** Route entry point: resolves the async `id` param and mounts the editor. */
export default function EditGroupRoute({
  params,
}: {
  params: Promise<{ id: string }>;
}) {
  const resolvedParams = use(params);
  const numericId = Number(resolvedParams.id);
  return <EditGroupPage groupId={numericId} />;
}

View File

@@ -1,17 +0,0 @@
import { UserGroupCreation } from "./types";
/** Create a new user group via the admin API; returns the raw Response. */
export const createUserGroup = async (userGroup: UserGroupCreation) => {
  const response = await fetch("/api/manage/admin/user-group", {
    method: "POST",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(userGroup),
  });
  return response;
};
/** Delete a user group by id; returns the raw Response. */
export const deleteUserGroup = async (userGroupId: number) =>
  fetch(`/api/manage/admin/user-group/${userGroupId}`, {
    method: "DELETE",
  });

View File

@@ -1,89 +1 @@
"use client";
import { UserGroupsTable } from "./UserGroupsTable";
import UserGroupCreationForm from "./UserGroupCreationForm";
import { useState } from "react";
import { ThreeDotsLoader } from "@/components/Loading";
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
import useUsers from "@/hooks/useUsers";
import { useUser } from "@/providers/UserProvider";
import CreateButton from "@/refresh-components/buttons/CreateButton";
import { ADMIN_ROUTES } from "@/lib/admin-routes";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
const route = ADMIN_ROUTES.GROUPS;
/**
 * Groups overview: lists existing user groups and, for admins, exposes a
 * creation form. Connector data is only required when the vector DB is on.
 */
function Main() {
  const [showForm, setShowForm] = useState(false);
  const vectorDbEnabled = useVectorDbEnabled();

  // Hooks run unconditionally before any early return.
  const groupsState = useUserGroups();
  const ccPairsState = useConnectorStatus(30000, vectorDbEnabled);
  const usersState = useUsers({ includeApiKeys: true });
  const { isAdmin } = useUser();

  const stillLoading =
    groupsState.isLoading ||
    (vectorDbEnabled && ccPairsState.isLoading) ||
    usersState.isLoading;
  if (stillLoading) {
    return <ThreeDotsLoader />;
  }
  if (groupsState.error || !groupsState.data) {
    return <div className="text-red-600">Error loading users</div>;
  }
  if (vectorDbEnabled && (ccPairsState.error || !ccPairsState.data)) {
    return <div className="text-red-600">Error loading connectors</div>;
  }
  if (usersState.error || !usersState.data) {
    return <div className="text-red-600">Error loading users</div>;
  }

  return (
    <>
      {isAdmin && (
        <CreateButton onClick={() => setShowForm(true)}>
          Create New User Group
        </CreateButton>
      )}
      {groupsState.data.length > 0 && (
        <div className="mt-2">
          <UserGroupsTable
            userGroups={groupsState.data}
            refresh={groupsState.refreshUserGroups}
          />
        </div>
      )}
      {showForm && (
        <UserGroupCreationForm
          onClose={() => {
            groupsState.refreshUserGroups();
            setShowForm(false);
          }}
          users={usersState.data.accepted}
          ccPairs={ccPairsState.data ?? []}
        />
      )}
    </>
  );
}
// Route shell for the groups overview: settings chrome around <Main />.
export default function Page() {
  return (
    <SettingsLayouts.Root>
      <SettingsLayouts.Header icon={route.icon} title={route.title} separator />
      <SettingsLayouts.Body>
        <Main />
      </SettingsLayouts.Body>
    </SettingsLayouts.Root>
  );
}
export { default } from "@/refresh-pages/admin/GroupsPage";

View File

@@ -1,15 +0,0 @@
// PATCH payload for updating a user group. Both lists are full
// replacements, not deltas (callers send the filtered complete lists).
export interface UserGroupUpdate {
  // Complete list of member user ids.
  user_ids: string[];
  // Complete list of connector/credential pair ids.
  cc_pair_ids: number[];
}

// POST payload for promoting/demoting a user's curator status in a group.
export interface SetCuratorRequest {
  user_id: string;
  is_curator: boolean;
}

// POST payload for creating a new user group.
export interface UserGroupCreation {
  name: string;
  user_ids: string[];
  cc_pair_ids: number[];
}

View File

@@ -1,15 +0,0 @@
"use client";
import useSWR from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { HookResponse } from "@/refresh-pages/admin/HooksPage/interfaces";
/**
 * Fetch the configured admin hooks. Focus revalidation is disabled so the
 * list is not re-fetched every time the window regains focus.
 */
export function useHooks() {
  const swr = useSWR<HookResponse[]>("/api/admin/hooks", errorHandlingFetcher, {
    revalidateOnFocus: false,
  });
  return {
    hooks: swr.data,
    isLoading: swr.isLoading,
    error: swr.error,
    mutate: swr.mutate,
  };
}

View File

@@ -27,20 +27,17 @@ export async function scheduleDeletionJobForConnector(
export async function deleteCCPair(
connectorId: number,
credentialId: number,
onCompletion: () => void
onCompletion?: () => void
) {
const deletionScheduleError = await scheduleDeletionJobForConnector(
connectorId,
credentialId
);
if (deletionScheduleError) {
toast.error(
"Failed to schedule deletion of connector - " + deletionScheduleError
);
} else {
toast.success("Scheduled deletion of connector!");
throw new Error(deletionScheduleError);
}
onCompletion();
toast.success("Scheduled deletion of connector!");
onCompletion?.();
}
export function isCurrentlyDeleting(

View File

@@ -115,9 +115,10 @@ export async function refreshToken(
}
export function getUserDisplayName(user: User | null): string {
// Prioritize custom personal name if set
// Prioritize custom personal name, if set.
if (!!user?.personalization?.name) return user.personalization.name;
// Then, prioritize personal email
// Then, prioritize personal email.
if (!!user?.email) {
const atIndex = user.email.indexOf("@");
if (atIndex > 0) {
@@ -129,6 +130,14 @@ export function getUserDisplayName(user: User | null): string {
return "Anonymous";
}
export function getUserEmail(user: User | null): string {
// Prioritize personal email.
if (!!user?.email) return user.email;
// If nothing works, then fall back to anonymous email.
return "anonymous@email.com";
}
/**
* Derive display initials from a user's name or email.
*

View File

@@ -1,11 +1,11 @@
import { SvgUser } from "@opal/icons";
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
import { getUserInitials } from "@/lib/user";
import { getUserEmail, getUserInitials } from "@/lib/user";
import Text from "@/refresh-components/texts/Text";
import type { User } from "@/lib/types";
export interface UserAvatarProps {
user: User;
user: User | null;
size?: number;
}
@@ -13,16 +13,17 @@ export default function UserAvatar({
user,
size = DEFAULT_AVATAR_SIZE_PX,
}: UserAvatarProps) {
const initials = getUserInitials(
user.personalization?.name ?? null,
user.email
const userEmail = getUserEmail(user);
const userInitials = getUserInitials(
user?.personalization?.name ?? null,
userEmail
);
if (!initials) {
if (!userInitials) {
return (
<div
role="img"
aria-label={`${user.email} avatar`}
aria-label={`${userEmail} avatar`}
className="flex items-center justify-center rounded-full bg-background-tint-01"
style={{ width: size, height: size }}
>
@@ -34,7 +35,7 @@ export default function UserAvatar({
return (
<div
role="img"
aria-label={`${user.email} avatar`}
aria-label={`${userEmail} avatar`}
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
style={{ width: size, height: size }}
>
@@ -45,7 +46,7 @@ export default function UserAvatar({
className="select-none"
style={{ fontSize: size * 0.4 }}
>
{initials}
{userInitials}
</Text>
</div>
);

View File

@@ -1,7 +1,7 @@
"use client";
import React from "react";
import type { ButtonType, IconFunctionComponent, IconProps } from "@opal/types";
import type { ButtonType, IconFunctionComponent } from "@opal/types";
import type { Route } from "next";
import { Interactive } from "@opal/core";
import { ContentAction } from "@opal/layouts";
@@ -19,7 +19,7 @@ export interface SidebarTabProps {
onClick?: React.MouseEventHandler<HTMLElement>;
href?: string;
type?: ButtonType;
icon?: React.FunctionComponent<IconProps>;
icon?: IconFunctionComponent;
children?: React.ReactNode;
rightChildren?: React.ReactNode;
}

View File

@@ -134,6 +134,7 @@ export default function ActionLineItem({
icon={SvgSlash}
onClick={noProp(onToggle)}
internal
aria-label={disabled ? "Enable" : "Disable"}
className={cn(
!disabled && "invisible group-hover/LineItem:visible",
// Hide when showing source count (it has its own hover behavior)
@@ -180,6 +181,11 @@ export default function ActionLineItem({
{isSearchToolAndNotInProject && (
<Button
aria-label={
isSearchToolWithNoConnectors
? "Add Connectors"
: "Configure Connectors"
}
icon={
isSearchToolWithNoConnectors ? SvgSettings : SvgChevronRight
}

View File

@@ -848,18 +848,18 @@ export default function ActionsPopover({
if (toolId === searchToolId) {
if (wasDisabled) {
// Enabling - restore previous sources or enable all (no persistence)
// Enabling - restore previous sources or enable all (persisted to localStorage)
const previous = previouslyEnabledSourcesRef.current;
if (previous.length > 0) {
setSelectedSources(previous);
enableSources(previous);
} else {
setSelectedSources(configuredSources);
baseEnableAllSources();
}
previouslyEnabledSourcesRef.current = [];
} else {
// Disabling - store current sources then disable all (no persistence)
// Disabling - store current sources then disable all (persisted to localStorage)
previouslyEnabledSourcesRef.current = [...selectedSources];
setSelectedSources([]);
baseDisableAllSources();
}
}
};

View File

@@ -17,7 +17,12 @@ import { toast } from "@/hooks/useToast";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useAdminUsers from "@/hooks/useAdminUsers";
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
import { createGroup } from "./svc";
import {
createGroup,
updateAgentGroupSharing,
updateDocSetGroupSharing,
saveTokenLimits,
} from "./svc";
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
import TokenLimitSection from "./TokenLimitSection";
@@ -62,7 +67,14 @@ function CreateGroupPage() {
setIsSubmitting(true);
try {
await createGroup(trimmed, selectedUserIds, selectedCcPairIds);
const groupId = await createGroup(
trimmed,
selectedUserIds,
selectedCcPairIds
);
await updateAgentGroupSharing(groupId, [], selectedAgentIds);
await updateDocSetGroupSharing(groupId, [], selectedDocSetIds);
await saveTokenLimits(groupId, tokenLimits, []);
toast.success(`Group "${trimmed}" created`);
router.push("/admin/groups");
} catch (e) {

View File

@@ -0,0 +1,517 @@
"use client";
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
import { useRouter } from "next/navigation";
import useSWR, { useSWRConfig } from "swr";
import { Table, Button } from "@opal/components";
import { IllustrationContent } from "@opal/layouts";
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
import IconButton from "@/refresh-components/buttons/IconButton";
import Card from "@/refresh-components/cards/Card";
import * as InputLayouts from "@/layouts/input-layouts";
import SvgNoResult from "@opal/illustrations/no-result";
import * as SettingsLayouts from "@/layouts/settings-layouts";
import { Section } from "@/layouts/general-layouts";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import Text from "@/refresh-components/texts/Text";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
import Separator from "@/refresh-components/Separator";
import { toast } from "@/hooks/useToast";
import { errorHandlingFetcher } from "@/lib/fetcher";
import useAdminUsers from "@/hooks/useAdminUsers";
import type { UserGroup } from "@/lib/types";
import type {
ApiKeyDescriptor,
MemberRow,
TokenRateLimitDisplay,
} from "./interfaces";
import {
apiKeyToMemberRow,
baseColumns,
memberTableColumns,
tc,
PAGE_SIZE,
} from "./shared";
import {
USER_GROUP_URL,
renameGroup,
updateGroup,
deleteGroup,
updateAgentGroupSharing,
updateDocSetGroupSharing,
saveTokenLimits,
} from "./svc";
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
import TokenLimitSection from "./TokenLimitSection";
import type { TokenLimit } from "./TokenLimitSection";
// Add-members mode reuses the standard member column set (no remove action).
const addModeColumns = memberTableColumns;
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
// Props for the group editor page.
interface EditGroupPageProps {
  // Numeric id of the group being edited.
  groupId: number;
}
function EditGroupPage({ groupId }: EditGroupPageProps) {
const router = useRouter();
const { mutate } = useSWRConfig();
// Fetch the group data — poll every 5s while syncing so the UI updates
// automatically when the backend finishes processing the previous edit.
const {
data: groups,
isLoading: groupLoading,
error: groupError,
} = useSWR<UserGroup[]>(USER_GROUP_URL, errorHandlingFetcher, {
refreshInterval: (latestData) => {
const g = latestData?.find((g) => g.id === groupId);
return g && !g.is_up_to_date ? 5000 : 0;
},
});
const group = useMemo(
() => groups?.find((g) => g.id === groupId) ?? null,
[groups, groupId]
);
const isSyncing = group != null && !group.is_up_to_date;
// Fetch token rate limits for this group
const { data: tokenRateLimits, isLoading: tokenLimitsLoading } = useSWR<
TokenRateLimitDisplay[]
>(`/api/admin/token-rate-limits/user-group/${groupId}`, errorHandlingFetcher);
// Form state
const [groupName, setGroupName] = useState("");
const [selectedUserIds, setSelectedUserIds] = useState<string[]>([]);
const [searchTerm, setSearchTerm] = useState("");
const [isSubmitting, setIsSubmitting] = useState(false);
const isSubmittingRef = useRef(false);
const [selectedCcPairIds, setSelectedCcPairIds] = useState<number[]>([]);
const [selectedDocSetIds, setSelectedDocSetIds] = useState<number[]>([]);
const [selectedAgentIds, setSelectedAgentIds] = useState<number[]>([]);
const [tokenLimits, setTokenLimits] = useState<TokenLimit[]>([
{ tokenBudget: null, periodHours: null },
]);
const [showDeleteModal, setShowDeleteModal] = useState(false);
const [isDeleting, setIsDeleting] = useState(false);
const [initialized, setInitialized] = useState(false);
const [isAddingMembers, setIsAddingMembers] = useState(false);
const initialAgentIdsRef = useRef<number[]>([]);
const initialDocSetIdsRef = useRef<number[]>([]);
// Users and API keys
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
const {
data: apiKeys,
isLoading: apiKeysLoading,
error: apiKeysError,
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
const isLoading =
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
const error = groupError ?? usersError ?? apiKeysError;
// Pre-populate form when group data loads
useEffect(() => {
if (group && !initialized) {
setGroupName(group.name);
setSelectedUserIds(group.users.map((u) => u.id));
setSelectedCcPairIds(group.cc_pairs.map((cc) => cc.id));
const docSetIds = group.document_sets.map((ds) => ds.id);
setSelectedDocSetIds(docSetIds);
initialDocSetIdsRef.current = docSetIds;
const agentIds = group.personas.map((p) => p.id);
setSelectedAgentIds(agentIds);
initialAgentIdsRef.current = agentIds;
setInitialized(true);
}
}, [group, initialized]);
// Pre-populate token limits when fetched
useEffect(() => {
if (tokenRateLimits && tokenRateLimits.length > 0) {
setTokenLimits(
tokenRateLimits.map((trl) => ({
tokenBudget: trl.token_budget,
periodHours: trl.period_hours,
}))
);
}
}, [tokenRateLimits]);
const allRows = useMemo(() => {
const activeUsers = users.filter((u) => u.is_active);
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
return [...activeUsers, ...serviceAccountRows];
}, [users, apiKeys]);
const memberRows = useMemo(() => {
const selected = new Set(selectedUserIds);
return allRows.filter((r) => selected.has(r.id ?? r.email));
}, [allRows, selectedUserIds]);
const currentRowSelection = useMemo(() => {
const sel: Record<string, boolean> = {};
for (const id of selectedUserIds) sel[id] = true;
return sel;
}, [selectedUserIds]);
const handleRemoveMember = useCallback((userId: string) => {
setSelectedUserIds((prev) => prev.filter((id) => id !== userId));
}, []);
const memberColumns = useMemo(
() => [
...baseColumns,
tc.actions({
showSorting: false,
showColumnVisibility: false,
cell: (row: MemberRow) => (
<IconButton
icon={SvgMinusCircle}
tertiary
onClick={(e) => {
e.stopPropagation();
handleRemoveMember(row.id ?? row.email);
}}
/>
),
}),
],
[handleRemoveMember]
);
// IDs of members not visible in the add-mode table (e.g. inactive users).
// We preserve these so they aren't silently removed when the table fires
// onSelectionChange with only the visible rows.
const hiddenMemberIds = useMemo(() => {
const visibleIds = new Set(allRows.map((r) => r.id ?? r.email));
return selectedUserIds.filter((id) => !visibleIds.has(id));
}, [allRows, selectedUserIds]);
// Guard onSelectionChange: ignore updates until the form is fully initialized.
// Without this, TanStack fires onSelectionChange before all rows are loaded,
// which overwrites selectedUserIds with a partial set.
const handleSelectionChange = useCallback(
(ids: string[]) => {
if (!initialized) return;
setSelectedUserIds([...ids, ...hiddenMemberIds]);
},
[initialized, hiddenMemberIds]
);
async function handleSave() {
if (isSubmittingRef.current) return;
const trimmed = groupName.trim();
if (!trimmed) {
toast.error("Group name is required");
return;
}
// Re-fetch group to check sync status before saving
const freshGroups = await fetch(USER_GROUP_URL).then((r) => r.json());
const freshGroup = freshGroups.find((g: UserGroup) => g.id === groupId);
if (freshGroup && !freshGroup.is_up_to_date) {
toast.error(
"This group is currently syncing. Please wait a moment and try again."
);
return;
}
isSubmittingRef.current = true;
setIsSubmitting(true);
try {
// Rename if name changed
if (group && trimmed !== group.name) {
await renameGroup(group.id, trimmed);
}
// Update members and cc_pairs
await updateGroup(groupId, selectedUserIds, selectedCcPairIds);
// Update agent sharing (add/remove this group from changed agents)
await updateAgentGroupSharing(
groupId,
initialAgentIdsRef.current,
selectedAgentIds
);
// Update document set sharing (add/remove this group from changed doc sets)
await updateDocSetGroupSharing(
groupId,
initialDocSetIdsRef.current,
selectedDocSetIds
);
// Save token rate limits (create/update/delete)
await saveTokenLimits(groupId, tokenLimits, tokenRateLimits ?? []);
// Update refs so subsequent saves diff correctly
initialAgentIdsRef.current = selectedAgentIds;
initialDocSetIdsRef.current = selectedDocSetIds;
mutate(USER_GROUP_URL);
mutate(`/api/admin/token-rate-limits/user-group/${groupId}`);
toast.success(`Group "${trimmed}" updated`);
router.push("/admin/groups");
} catch (e) {
toast.error(e instanceof Error ? e.message : "Failed to update group");
} finally {
isSubmittingRef.current = false;
setIsSubmitting(false);
}
}
async function handleDelete() {
setIsDeleting(true);
try {
await deleteGroup(groupId);
mutate(USER_GROUP_URL);
toast.success(`Group "${group?.name}" deleted`);
router.push("/admin/groups");
} catch (e) {
toast.error(e instanceof Error ? e.message : "Failed to delete group");
} finally {
setIsDeleting(false);
setShowDeleteModal(false);
}
}
// 404 state
if (!isLoading && !error && !group) {
return (
<SettingsLayouts.Root width="sm">
<SettingsLayouts.Header
icon={SvgUsers}
title="Group Not Found"
separator
/>
<SettingsLayouts.Body>
<IllustrationContent
illustration={SvgNoResult}
title="Group not found"
description="This group doesn't exist or may have been deleted."
/>
</SettingsLayouts.Body>
</SettingsLayouts.Root>
);
}
const headerActions = (
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
<Button
prominence="tertiary"
onClick={() => router.push("/admin/groups")}
>
Cancel
</Button>
<Button
onClick={handleSave}
disabled={!groupName.trim() || isSubmitting || isSyncing}
tooltip={
isSyncing
? "Document embeddings are being updated due to recent changes to this group."
: undefined
}
>
{isSubmitting ? "Saving..." : isSyncing ? "Syncing..." : "Save Changes"}
</Button>
</Section>
);
return (
<>
<SettingsLayouts.Root width="sm">
<SettingsLayouts.Header
icon={SvgUsers}
title="Edit Group"
separator
rightChildren={headerActions}
/>
<SettingsLayouts.Body>
{isLoading && <SimpleLoader />}
{error && (
<Text as="p" secondaryBody text03>
Failed to load group data.
</Text>
)}
{!isLoading && !error && group && (
<>
{/* Group Name */}
<Section
gap={0.5}
height="auto"
alignItems="stretch"
justifyContent="start"
>
<Text mainUiBody text04>
Group Name
</Text>
<InputTypeIn
placeholder="Name your group"
value={groupName}
onChange={(e) => setGroupName(e.target.value)}
/>
</Section>
<Separator noPadding />
{/* Members table */}
<Section
gap={0.75}
height="auto"
alignItems="stretch"
justifyContent="start"
>
<Section
flexDirection="row"
gap={0.5}
height="auto"
alignItems="center"
justifyContent="start"
>
<InputTypeIn
value={searchTerm}
onChange={(e) => setSearchTerm(e.target.value)}
placeholder={
isAddingMembers
? "Search users and accounts..."
: "Search members..."
}
leftSearchIcon
className="flex-1"
/>
{isAddingMembers ? (
<Button
prominence="secondary"
onClick={() => setIsAddingMembers(false)}
>
Done
</Button>
) : (
<Button
prominence="tertiary"
icon={SvgPlusCircle}
onClick={() => setIsAddingMembers(true)}
>
Add
</Button>
)}
</Section>
{isAddingMembers ? (
<Table
key="add-members"
data={allRows as MemberRow[]}
columns={addModeColumns}
getRowId={(row) => row.id ?? row.email}
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
selectionBehavior="multi-select"
initialRowSelection={currentRowSelection}
onSelectionChange={handleSelectionChange}
footer={{}}
emptyState={
<IllustrationContent
illustration={SvgNoResult}
title="No users found"
description="No users match your search."
/>
}
/>
) : (
<Table
data={memberRows}
columns={memberColumns}
getRowId={(row) => row.id ?? row.email}
pageSize={PAGE_SIZE}
searchTerm={searchTerm}
footer={{}}
emptyState={
<IllustrationContent
illustration={SvgNoResult}
title="No members"
description="Add members to this group."
/>
}
/>
)}
</Section>
<SharedGroupResources
selectedCcPairIds={selectedCcPairIds}
onCcPairIdsChange={setSelectedCcPairIds}
selectedDocSetIds={selectedDocSetIds}
onDocSetIdsChange={setSelectedDocSetIds}
selectedAgentIds={selectedAgentIds}
onAgentIdsChange={setSelectedAgentIds}
/>
<TokenLimitSection
limits={tokenLimits}
onLimitsChange={setTokenLimits}
/>
{/* Delete This Group */}
<Card>
<InputLayouts.Horizontal
title="Delete This Group"
description="Members will lose access to any resources shared with this group."
center
nonInteractive
>
<Button
variant="danger"
prominence="secondary"
icon={SvgTrash}
onClick={() => setShowDeleteModal(true)}
>
Delete Group
</Button>
</InputLayouts.Horizontal>
</Card>
</>
)}
</SettingsLayouts.Body>
</SettingsLayouts.Root>
{showDeleteModal && (
<ConfirmationModalLayout
icon={SvgTrash}
title="Delete Group"
onClose={() => setShowDeleteModal(false)}
submit={
<Button
variant="danger"
onClick={handleDelete}
disabled={isDeleting}
>
{isDeleting ? "Deleting..." : "Delete"}
</Button>
}
>
<Text as="p" text03>
Members of group{" "}
<Text as="span" text05>
{group?.name}
</Text>{" "}
will lose access to any resources shared with this group, unless
they have been granted access directly. Deletion cannot be undone.
</Text>
</ConfirmationModalLayout>
)}
</>
);
}
export default EditGroupPage;

View File

@@ -1,5 +1,6 @@
"use client";
import { useRouter } from "next/navigation";
import type { UserGroup } from "@/lib/types";
import { SvgChevronRight, SvgUserManage, SvgUsers } from "@opal/icons";
import { ContentAction } from "@opal/layouts";
@@ -21,6 +22,7 @@ interface GroupCardProps {
}
function GroupCard({ group }: GroupCardProps) {
const router = useRouter();
const { mutate } = useSWRConfig();
const builtIn = isBuiltInGroup(group);
const isAdmin = group.name === "Admin";
@@ -39,7 +41,7 @@ function GroupCard({ group }: GroupCardProps) {
}
return (
<Card padding={0.5}>
<Card padding={0.5} data-card>
<ContentAction
icon={isAdmin ? SvgUserManage : SvgUsers}
title={group.name}
@@ -53,13 +55,17 @@ function GroupCard({ group }: GroupCardProps) {
<Section flexDirection="row" alignItems="start" gap={0}>
<div className="py-1">
<Text mainUiBody text03>
{formatMemberCount(group.users.length)}
{formatMemberCount(
group.users.filter((u) => u.is_active).length
)}
</Text>
</div>
<Button
icon={SvgChevronRight}
prominence="tertiary"
tooltip="View group"
aria-label="View group"
onClick={() => router.push(`/admin/groups/${group.id}`)}
/>
</Section>
}

View File

@@ -28,28 +28,32 @@ function ResourceContent({
}: ResourceContentProps) {
return (
<div className="flex flex-1 gap-0.5 items-start p-1.5 rounded-08 bg-background-tint-01 min-w-[240px] max-w-[302px]">
<div className="flex flex-1 gap-1 p-0.5 items-start min-w-0">
<div className="flex flex-1 gap-1 p-0.5 items-center min-w-0">
{leftContent ? (
<>
{leftContent}
<div className="flex-1 min-w-0">
<Content
title={title}
description={description}
sizePreset="main-ui"
variant="section"
/>
</div>
</>
) : (
<div className="flex-1 min-w-0">
<Content
icon={icon}
title={title}
description={description}
sizePreset="main-ui"
variant="section"
/>
</>
) : (
<Content
icon={icon}
title={title}
description={description}
sizePreset="main-ui"
variant="section"
/>
</div>
)}
{infoContent}
</div>
{infoContent}
<IconButton small icon={SvgX} onClick={onRemove} className="shrink-0" />
</div>
);

View File

@@ -23,11 +23,7 @@ function ResourcePopover({
return (
<Popover open={open} onOpenChange={setOpen}>
<Popover.Trigger
onClick={(e) => {
e.preventDefault();
}}
>
<Popover.Anchor>
<InputTypeIn
placeholder={placeholder}
value={searchValue}
@@ -37,7 +33,7 @@ function ResourcePopover({
}}
onFocus={() => setOpen(true)}
/>
</Popover.Trigger>
</Popover.Anchor>
<Popover.Content
width="trigger"
align="start"

View File

@@ -40,7 +40,11 @@ function SharedBadge() {
);
}
function SourceIconStack({ sources }: { sources: { source: ValidSources }[] }) {
interface SourceIconStackProps {
sources: { source: ValidSources }[];
}
function SourceIconStack({ sources }: SourceIconStackProps) {
if (sources.length === 0) return null;
const unique = Array.from(
@@ -51,16 +55,17 @@ function SourceIconStack({ sources }: { sources: { source: ValidSources }[] }) {
<Section
flexDirection="row"
alignItems="center"
width="auto"
height="auto"
gap={0}
className="shrink-0 px-0.5"
className="shrink-0 p-0.5"
>
{unique.map((s, i) => {
const Icon = getSourceMetadata(s.source).icon;
return (
<div
key={s.source}
className="flex items-center justify-center size-4 rounded-04 bg-background-tint-00 border border-border-01 overflow-hidden"
className="flex items-center justify-center size-4 rounded-04 bg-background-tint-00 border border-border-01 overflow-hidden [&_img]:!size-4 [&_img]:!m-0 [&_svg]:size-4"
style={{ zIndex: unique.length - i, marginLeft: i > 0 ? -6 : 0 }}
>
<Icon />
@@ -316,7 +321,7 @@ function SharedGroupResources({
key={`d-${ds.id}`}
icon={SvgFiles}
title={ds.name}
description={`Document Set - ${ds.cc_pair_summaries.length} Sources`}
description="Document Set"
infoContent={
<SourceIconStack sources={ds.cc_pair_summaries} />
}

View File

@@ -31,7 +31,10 @@ function GroupsPage() {
{/* This is the sticky header for the groups page. It is used to display
* the groups page title and search input when scrolling down.
*/}
<div className="sticky top-0 z-settings-header bg-background-tint-01">
<div
className="sticky top-0 z-settings-header bg-background-tint-01"
data-testid="groups-page-heading"
>
<SettingsLayouts.Header icon={SvgUsers} title="Groups" separator />
<Section flexDirection="row" padding={1}>

View File

@@ -13,3 +13,10 @@ export interface ApiKeyDescriptor {
export interface MemberRow extends UserRow {
api_key_display?: string;
}
/**
 * Display-layer shape for a token rate limit row on the group edit page.
 * NOTE(review): semantics inferred from field names — presumably a budget of
 * tokens allowed per rolling window; confirm against the token rate limit API.
 */
export interface TokenRateLimitDisplay {
// Server-side id of the rate-limit record.
token_id: number;
// Whether this limit is currently enforced.
enabled: boolean;
// Token budget for the window — units not shown here; verify against backend model.
token_budget: number;
// Length of the rate-limit window, in hours.
period_hours: number;
}

View File

@@ -73,7 +73,7 @@ function renderAccountTypeColumn(_value: unknown, row: MemberRow) {
// Columns
// ---------------------------------------------------------------------------
const tc = createTableColumns<MemberRow>();
export const tc = createTableColumns<MemberRow>();
export const baseColumns = [
tc.qualifier(),

View File

@@ -40,6 +40,27 @@ async function createGroup(
return group.id;
}
/**
 * PATCH an existing user group's membership and connector assignments.
 *
 * @param groupId - Id of the group to update.
 * @param userIds - Full replacement list of member user ids.
 * @param ccPairIds - Full replacement list of connector-credential pair ids.
 * @throws Error with the server-provided `detail` (or the HTTP status text)
 *   when the request fails.
 */
async function updateGroup(
  groupId: number,
  userIds: string[],
  ccPairIds: number[]
): Promise<void> {
  const payload = { user_ids: userIds, cc_pair_ids: ccPairIds };
  const response = await fetch(`${USER_GROUP_URL}/${groupId}`, {
    method: "PATCH",
    headers: { "Content-Type": "application/json" },
    body: JSON.stringify(payload),
  });
  if (response.ok) return;
  // Prefer the API's structured error detail; fall back to the status text.
  const body = await response.json().catch(() => null);
  throw new Error(
    body?.detail ?? `Failed to update group: ${response.statusText}`
  );
}
async function deleteGroup(groupId: number): Promise<void> {
const res = await fetch(`${USER_GROUP_URL}/${groupId}`, {
method: "DELETE",
@@ -262,6 +283,7 @@ export {
USER_GROUP_URL,
renameGroup,
createGroup,
updateGroup,
deleteGroup,
updateAgentGroupSharing,
updateDocSetGroupSharing,

View File

@@ -1,334 +0,0 @@
"use client";
import { useState } from "react";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import { SvgHookNodes } from "@opal/icons";
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import InputSelect from "@/refresh-components/inputs/InputSelect";
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
import Text from "@/refresh-components/texts/Text";
import { toast } from "@/hooks/useToast";
import { createHook, updateHook } from "@/refresh-pages/admin/HooksPage/svc";
import type {
HookFailStrategy,
HookPointMeta,
HookResponse,
} from "@/refresh-pages/admin/HooksPage/interfaces";
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface HookFormModalProps {
open: boolean;
onOpenChange: (open: boolean) => void;
/** When provided, the modal is in edit mode for this hook. */
hook?: HookResponse;
/** When provided (create mode), the hook point is pre-selected and locked. */
spec?: HookPointMeta;
onSuccess: (hook: HookResponse) => void;
}
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
interface FormState {
name: string;
endpoint_url: string;
api_key: string;
fail_strategy: HookFailStrategy;
timeout_seconds: string;
}
function buildInitialState(
hook: HookResponse | undefined,
spec: HookPointMeta | undefined
): FormState {
if (hook) {
return {
name: hook.name,
endpoint_url: hook.endpoint_url ?? "",
api_key: hook.api_key_masked ?? "",
fail_strategy: hook.fail_strategy,
timeout_seconds: String(hook.timeout_seconds),
};
}
return {
name: "",
endpoint_url: "",
api_key: "",
fail_strategy: spec?.default_fail_strategy ?? "soft",
timeout_seconds: spec ? String(spec.default_timeout_seconds) : "5",
};
}
const SOFT_DESCRIPTION =
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
// ---------------------------------------------------------------------------
// Sub-components
// ---------------------------------------------------------------------------
interface FieldProps {
label: string;
required?: boolean;
description?: string;
children: React.ReactNode;
}
function Field({ label, required, description, children }: FieldProps) {
return (
<div className="flex flex-col gap-1 w-full">
<span className="font-main-ui-action text-text-04 px-[0.125rem]">
{label}
{required && <span className="text-status-error-05 ml-0.5">*</span>}
</span>
{children}
{description && (
<Text secondaryBody text03>
{description}
</Text>
)}
</div>
);
}
// ---------------------------------------------------------------------------
// Component
// ---------------------------------------------------------------------------
/**
 * Create/edit modal for a hook extension.
 *
 * Modes:
 * - Edit: when `hook` is provided, fields are seeded from the hook and only
 *   fields the user actually changed are sent in the PATCH request.
 * - Create: when `spec` is provided, the hook point is pre-selected and its
 *   defaults seed the fail strategy and timeout.
 *
 * On success the created/updated hook is passed to `onSuccess` and the modal
 * closes; failures surface as error toasts and keep the modal open.
 */
export default function HookFormModal({
  open,
  onOpenChange,
  hook,
  spec,
  onSuccess,
}: HookFormModalProps) {
  const isEdit = !!hook;
  const [form, setForm] = useState<FormState>(() =>
    buildInitialState(hook, spec)
  );
  const [isSubmitting, setIsSubmitting] = useState(false);

  // Reset the form after the close animation (~200ms) so the fields don't
  // visibly flash back to their initial values while the modal is fading out.
  function handleOpenChange(next: boolean) {
    if (!next) {
      setTimeout(() => {
        setForm(buildInitialState(hook, spec));
        setIsSubmitting(false);
      }, 200);
    }
    onOpenChange(next);
  }

  // Update a single form field by key.
  function set<K extends keyof FormState>(key: K, value: FormState[K]) {
    setForm((prev) => ({ ...prev, [key]: value }));
  }

  const timeoutNum = parseFloat(form.timeout_seconds);
  // Name, endpoint URL, and a positive numeric timeout are required.
  const isValid =
    form.name.trim().length > 0 &&
    form.endpoint_url.trim().length > 0 &&
    !isNaN(timeoutNum) &&
    timeoutNum > 0;

  async function handleSubmit() {
    if (!isValid) return;
    setIsSubmitting(true);
    try {
      let result: HookResponse;
      if (isEdit && hook) {
        // Build a minimal PATCH payload containing only changed fields.
        const req: Record<string, unknown> = {};
        if (form.name !== hook.name) req.name = form.name;
        if (form.endpoint_url !== (hook.endpoint_url ?? ""))
          req.endpoint_url = form.endpoint_url;
        if (form.fail_strategy !== hook.fail_strategy)
          req.fail_strategy = form.fail_strategy;
        if (timeoutNum !== hook.timeout_seconds)
          req.timeout_seconds = timeoutNum;
        // The input is pre-filled with the masked key; only send api_key when
        // the user typed something different. An empty value clears the key.
        const maskedPlaceholder = hook.api_key_masked ?? "";
        if (form.api_key !== maskedPlaceholder) {
          req.api_key = form.api_key || null;
        }
        if (Object.keys(req).length === 0) {
          // Nothing changed — close without a network round trip.
          handleOpenChange(false);
          return;
        }
        result = await updateHook(hook.id, req);
      } else {
        // Create mode: callers only open the modal with a spec, so the
        // non-null assertion is safe here.
        const hookPoint = spec!.hook_point;
        result = await createHook({
          name: form.name,
          hook_point: hookPoint,
          endpoint_url: form.endpoint_url,
          ...(form.api_key ? { api_key: form.api_key } : {}),
          fail_strategy: form.fail_strategy,
          timeout_seconds: timeoutNum,
        });
      }
      toast.success(isEdit ? "Hook updated." : "Hook created.");
      onSuccess(result);
      handleOpenChange(false);
    } catch (err) {
      toast.error(err instanceof Error ? err.message : "Something went wrong.");
    } finally {
      setIsSubmitting(false);
    }
  }

  // In edit mode the spec may be unavailable, so fall back to the raw
  // hook_point identifier for display.
  const hookPointDisplayName = isEdit
    ? hook!.hook_point
    : spec?.display_name ?? spec?.hook_point ?? "";
  const hookPointDescription = isEdit ? undefined : spec?.description;
  const docsUrl = isEdit ? undefined : spec?.docs_url;
  const failStrategyDescription =
    form.fail_strategy === "soft"
      ? SOFT_DESCRIPTION
      : spec?.fail_hard_description ?? undefined;

  return (
    <Modal open={open} onOpenChange={handleOpenChange}>
      <Modal.Content width="md" height="fit">
        <Modal.Header
          icon={SvgHookNodes}
          title="Set Up Hook Extension"
          description="Connect an external API endpoint to extend the hook point."
          onClose={() => handleOpenChange(false)}
        />
        <Modal.Body>
          {/* Hook point section header */}
          <div className="flex flex-row items-start justify-between gap-1">
            <div className="flex flex-col flex-1 min-w-0">
              <span className="font-main-ui-action text-text-04 px-[0.125rem]">
                {hookPointDisplayName}
              </span>
              {hookPointDescription && (
                <span className="font-secondary-body text-text-03 px-[0.125rem]">
                  {hookPointDescription}
                </span>
              )}
            </div>
            <div className="flex flex-col items-end shrink-0 gap-1">
              <div className="flex items-center gap-1">
                <SvgHookNodes
                  style={{ width: "1rem", height: "1rem" }}
                  className="text-text-03 shrink-0 p-0.5"
                />
                <span className="font-secondary-body text-text-03">
                  Hook Point
                </span>
              </div>
              {docsUrl && (
                <a
                  href={docsUrl}
                  target="_blank"
                  rel="noopener noreferrer"
                  className="font-secondary-body text-text-03 underline"
                >
                  Documentation
                </a>
              )}
            </div>
          </div>
          <Field label="Display Name" required>
            <InputTypeIn
              value={form.name}
              onChange={(e) => set("name", e.target.value)}
              placeholder="Name your extension at this hook point"
              variant={isSubmitting ? "disabled" : undefined}
            />
          </Field>
          <Field label="Fail Strategy" description={failStrategyDescription}>
            <InputSelect
              value={form.fail_strategy}
              onValueChange={(v) => set("fail_strategy", v as HookFailStrategy)}
              disabled={isSubmitting}
            >
              <InputSelect.Trigger placeholder="Select strategy" />
              <InputSelect.Content>
                <InputSelect.Item value="soft">
                  Log Error and Continue (Default)
                </InputSelect.Item>
                <InputSelect.Item value="hard">
                  Block Pipeline on Failure
                </InputSelect.Item>
              </InputSelect.Content>
            </InputSelect>
          </Field>
          <Field
            label="Timeout (seconds)"
            required
            description="Maximum time Onyx will wait for the endpoint to respond before applying the fail strategy."
          >
            <InputTypeIn
              type="number"
              value={form.timeout_seconds}
              onChange={(e) => set("timeout_seconds", e.target.value)}
              placeholder="5"
              variant={isSubmitting ? "disabled" : undefined}
            />
          </Field>
          <Field
            label="External API Endpoint URL"
            required
            description="Only connect to servers you trust. You are responsible for actions taken and data shared with this connection."
          >
            <InputTypeIn
              value={form.endpoint_url}
              onChange={(e) => set("endpoint_url", e.target.value)}
              placeholder="https://your-api-endpoint.com"
              variant={isSubmitting ? "disabled" : undefined}
            />
          </Field>
          <Field
            label="API Key"
            description="Onyx will use this key to authenticate with your API endpoint."
          >
            <PasswordInputTypeIn
              value={form.api_key}
              onChange={(e) => set("api_key", e.target.value)}
              placeholder={
                isEdit && hook?.api_key_masked
                  ? "Leave blank to keep current key"
                  : undefined
              }
              disabled={isSubmitting}
            />
          </Field>
        </Modal.Body>
        <Modal.Footer>
          <BasicModalFooter
            cancel={
              <Disabled disabled={isSubmitting}>
                <Button
                  prominence="tertiary"
                  onClick={() => handleOpenChange(false)}
                >
                  Cancel
                </Button>
              </Disabled>
            }
            submit={
              <Disabled disabled={isSubmitting || !isValid}>
                <Button onClick={handleSubmit}>
                  {isEdit ? "Save" : "Connect"}
                </Button>
              </Disabled>
            }
          />
        </Modal.Footer>
      </Modal.Content>
    </Modal>
  );
}

View File

@@ -3,34 +3,20 @@
import { useState, useEffect } from "react";
import { toast } from "@/hooks/useToast";
import { useHookSpecs } from "@/hooks/useHookSpecs";
import { useHooks } from "@/hooks/useHooks";
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
import { ContentAction } from "@opal/layouts";
import { Button } from "@opal/components";
import { Disabled } from "@opal/core";
import InputSearch from "@/refresh-components/inputs/InputSearch";
import Card from "@/refresh-components/cards/Card";
import Text from "@/refresh-components/texts/Text";
import {
SvgArrowExchange,
SvgBubbleText,
SvgCheckCircle,
SvgExternalLink,
SvgFileBroadcast,
SvgHookNodes,
SvgRefreshCw,
SvgSettings,
SvgXCircle,
} from "@opal/icons";
import { IconFunctionComponent } from "@opal/types";
import HookFormModal from "@/refresh-pages/admin/HooksPage/HookFormModal";
import type {
HookPointMeta,
HookResponse,
} from "@/refresh-pages/admin/HooksPage/interfaces";
import {
activateHook,
deactivateHook,
} from "@/refresh-pages/admin/HooksPage/svc";
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
document_ingestion: SvgFileBroadcast,
@@ -41,152 +27,22 @@ function getHookPointIcon(hookPoint: string): IconFunctionComponent {
return HOOK_POINT_ICONS[hookPoint] ?? SvgHookNodes;
}
// ---------------------------------------------------------------------------
// Sub-component: connected hook card
// ---------------------------------------------------------------------------
interface ConnectedHookCardProps {
hook: HookResponse;
spec: HookPointMeta | undefined;
onEdit: () => void;
onDeleted: () => void;
onToggled: (updated: HookResponse) => void;
}
function ConnectedHookCard({
hook,
spec,
onEdit,
onDeleted: _onDeleted,
onToggled,
}: ConnectedHookCardProps) {
const [isBusy, setIsBusy] = useState(false);
async function handleToggle() {
setIsBusy(true);
try {
const updated = hook.is_active
? await deactivateHook(hook.id)
: await activateHook(hook.id);
onToggled(updated);
} catch (err) {
toast.error(
err instanceof Error ? err.message : "Failed to update hook status."
);
} finally {
setIsBusy(false);
}
}
const HookIcon = getHookPointIcon(hook.hook_point);
return (
<Card variant="primary" padding={0.5} gap={0}>
<div className="flex flex-row items-start w-full">
{/* Left: manually replicate ContentMd main-content layout so the docs
link sits as a natural third line with zero artificial gap. */}
<div className="flex flex-row flex-1 min-w-0 items-start gap-1 p-2">
<div className="shrink-0 p-0.5 text-text-04">
<HookIcon style={{ width: "1rem", height: "1rem" }} />
</div>
<div className="flex flex-col items-start min-w-0 flex-1">
<span
className="font-main-ui-action text-text-04"
style={{ height: "1.25rem" }}
>
{hook.name}
</span>
{/* matches opal-content-md-description: font-secondary-body px-[0.125rem] */}
<div className="font-secondary-body text-text-03">
{`Hook Point: ${spec?.display_name ?? hook.hook_point}`}
</div>
{spec?.docs_url && (
<div className="font-secondary-body text-text-03">
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="flex items-center gap-1 w-fit"
>
<span className="underline">Documentation</span>
<SvgExternalLink size={12} className="shrink-0" />
</a>
</div>
)}
</div>
</div>
{/* Right: status + actions, top-aligned */}
<div className="flex flex-col items-end shrink-0">
<div className="flex items-center gap-1 p-2">
<span className="font-main-ui-action text-text-03">
{hook.is_active ? "Connected" : "Inactive"}
</span>
{hook.is_active ? (
<SvgCheckCircle size={16} className="text-status-success-05" />
) : (
<SvgXCircle size={16} className="text-text-03" />
)}
</div>
<div className="flex items-center gap-0.5 pl-2 pr-0.5">
<Disabled disabled={isBusy}>
<Button
prominence="tertiary"
size="md"
icon={SvgRefreshCw}
onClick={handleToggle}
tooltip={hook.is_active ? "Deactivate" : "Activate"}
aria-label={
hook.is_active ? "Deactivate hook" : "Activate hook"
}
/>
</Disabled>
<Disabled disabled={isBusy}>
<Button
prominence="tertiary"
size="md"
icon={SvgSettings}
onClick={onEdit}
aria-label="Configure hook"
/>
</Disabled>
</div>
</div>
</div>
</Card>
);
}
// ---------------------------------------------------------------------------
// Main component
// ---------------------------------------------------------------------------
export default function HooksContent() {
const [search, setSearch] = useState("");
const [connectSpec, setConnectSpec] = useState<HookPointMeta | null>(null);
const [editHook, setEditHook] = useState<HookResponse | null>(null);
const { specs, isLoading: specsLoading, error: specsError } = useHookSpecs();
const {
hooks,
isLoading: hooksLoading,
error: hooksError,
mutate,
} = useHooks();
const { specs, isLoading, error } = useHookSpecs();
useEffect(() => {
if (specsError) toast.error("Failed to load hook specifications.");
}, [specsError]);
if (error) {
toast.error("Failed to load hook specifications.");
}
}, [error]);
useEffect(() => {
if (hooksError) toast.error("Failed to load hooks.");
}, [hooksError]);
if (specsLoading || hooksLoading) {
if (isLoading) {
return <SimpleLoader />;
}
if (specsError) {
if (error) {
return (
<Text text03 secondaryBody>
Failed to load hook specifications. Please refresh the page.
@@ -194,179 +50,68 @@ export default function HooksContent() {
);
}
const hooksByPoint: Record<string, HookResponse[]> = {};
for (const hook of hooks ?? []) {
(hooksByPoint[hook.hook_point] ??= []).push(hook);
}
const searchLower = search.toLowerCase();
// Connected hooks sorted alphabetically by hook name
const connectedHooks = (hooks ?? [])
.filter(
(hook) =>
!searchLower ||
hook.name.toLowerCase().includes(searchLower) ||
(hooksByPoint[hook.hook_point] &&
specs
?.find((s) => s.hook_point === hook.hook_point)
?.display_name.toLowerCase()
.includes(searchLower))
)
.sort((a, b) => a.name.localeCompare(b.name));
// Unconnected hook point specs sorted alphabetically
const unconnectedSpecs = (specs ?? [])
.filter(
(spec) =>
(hooksByPoint[spec.hook_point]?.length ?? 0) === 0 &&
(!searchLower ||
spec.display_name.toLowerCase().includes(searchLower) ||
spec.description.toLowerCase().includes(searchLower))
)
.sort((a, b) => a.display_name.localeCompare(b.display_name));
function handleHookSuccess(updated: HookResponse) {
mutate((prev) => {
if (!prev) return [updated];
const idx = prev.findIndex((h) => h.id === updated.id);
if (idx >= 0) {
const next = [...prev];
next[idx] = updated;
return next;
}
return [...prev, updated];
});
}
function handleHookDeleted(id: number) {
mutate((prev) => prev?.filter((h) => h.id !== id));
}
const connectSpec_ =
connectSpec ??
(editHook
? specs?.find((s) => s.hook_point === editHook.hook_point)
: undefined);
const filtered = (specs ?? []).filter(
(spec) =>
spec.display_name.toLowerCase().includes(search.toLowerCase()) ||
spec.description.toLowerCase().includes(search.toLowerCase())
);
return (
<>
<div className="flex flex-col gap-6">
<InputSearch
placeholder="Search hooks..."
value={search}
onChange={(e) => setSearch(e.target.value)}
/>
<div className="flex flex-col gap-6">
<InputSearch
placeholder="Search hooks..."
value={search}
onChange={(e) => setSearch(e.target.value)}
/>
<div className="flex flex-col gap-2">
{connectedHooks.length === 0 && unconnectedSpecs.length === 0 ? (
<Text text03 secondaryBody>
{search
? "No hooks match your search."
: "No hook points are available."}
</Text>
) : (
<>
{connectedHooks.map((hook) => {
const spec = specs?.find(
(s) => s.hook_point === hook.hook_point
);
return (
<ConnectedHookCard
key={hook.id}
hook={hook}
spec={spec}
onEdit={() => setEditHook(hook)}
onDeleted={() => handleHookDeleted(hook.id)}
onToggled={handleHookSuccess}
/>
);
})}
{unconnectedSpecs.map((spec) => {
const UnconnectedIcon = getHookPointIcon(spec.hook_point);
return (
<Card
key={spec.hook_point}
variant="secondary"
padding={0.5}
gap={0}
<div className="flex flex-col gap-2">
{filtered.length === 0 ? (
<Text text03 secondaryBody>
{search
? "No hooks match your search."
: "No hook points are available."}
</Text>
) : (
filtered.map((spec) => (
<Card
key={spec.hook_point}
variant="secondary"
padding={0.5}
gap={0}
>
<ContentAction
icon={getHookPointIcon(spec.hook_point)}
title={spec.display_name}
description={spec.description}
sizePreset="main-content"
variant="section"
paddingVariant="fit"
rightChildren={
// TODO(Bo-Onyx): wire up Connect — open modal to create/edit hook
<Button prominence="tertiary" rightIcon={SvgArrowExchange}>
Connect
</Button>
}
/>
{spec.docs_url && (
<div className="pl-7 pt-1">
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="flex items-center gap-1 w-fit text-text-03"
>
<div className="flex flex-row items-start w-full">
<div className="flex flex-row flex-1 min-w-0 items-start gap-1 p-2">
<div className="shrink-0 p-0.5 text-text-04">
<UnconnectedIcon
style={{ width: "1rem", height: "1rem" }}
/>
</div>
<div className="flex flex-col items-start min-w-0 flex-1">
<span
className="font-main-ui-action text-text-04"
style={{ height: "1.25rem" }}
>
{spec.display_name}
</span>
<div className="font-secondary-body text-text-03">
{spec.description}
</div>
{spec.docs_url && (
<div className="font-secondary-body text-text-03">
<a
href={spec.docs_url}
target="_blank"
rel="noopener noreferrer"
className="flex items-center gap-1 w-fit"
>
<span className="underline">Documentation</span>
<SvgExternalLink
size={12}
className="shrink-0"
/>
</a>
</div>
)}
</div>
</div>
<div
className="flex items-center gap-1 p-2 cursor-pointer"
onClick={() => setConnectSpec(spec)}
>
<span className="font-main-ui-action text-text-03">
Connect
</span>
<SvgArrowExchange
size={16}
className="text-text-03 shrink-0"
/>
</div>
</div>
</Card>
);
})}
</>
)}
</div>
<Text as="span" secondaryBody text03 className="underline">
Documentation
</Text>
<SvgExternalLink size={16} className="text-text-02" />
</a>
</div>
)}
</Card>
))
)}
</div>
{/* Create modal */}
<HookFormModal
open={!!connectSpec}
onOpenChange={(open) => {
if (!open) setConnectSpec(null);
}}
spec={connectSpec ?? undefined}
onSuccess={handleHookSuccess}
/>
{/* Edit modal */}
<HookFormModal
open={!!editHook}
onOpenChange={(open) => {
if (!open) setEditHook(null);
}}
hook={editHook ?? undefined}
spec={connectSpec_ ?? undefined}
onSuccess={handleHookSuccess}
/>
</>
</div>
);
}

View File

@@ -18,8 +18,6 @@ export interface HookResponse {
name: string;
hook_point: HookPoint;
endpoint_url: string | null;
/** Partially-masked API key (e.g. "abcd••••••••wxyz"), or null if no key is set. */
api_key_masked: string | null;
fail_strategy: HookFailStrategy;
timeout_seconds: number;
is_active: boolean;

View File

@@ -1,81 +0,0 @@
import {
HookCreateRequest,
HookResponse,
HookUpdateRequest,
} from "@/refresh-pages/admin/HooksPage/interfaces";
/**
 * Extract a human-readable error message from a failed API response.
 *
 * @param res - The (non-ok) fetch Response to inspect.
 * @param fallback - Message used when the body is not JSON or has no `detail`.
 * @returns The server's `detail` string, or `fallback`.
 */
async function parseErrorDetail(
  res: Response,
  fallback: string
): Promise<string> {
  let detail: string | null | undefined;
  try {
    detail = (await res.json())?.detail;
  } catch {
    // Non-JSON body (e.g. an HTML error page) — use the fallback below.
  }
  return detail ?? fallback;
}
// ---------------------------------------------------------------------------
// Thin fetch wrappers around the admin hooks REST API. All of them throw an
// Error carrying the server-provided `detail` message (or a fallback) on any
// non-2xx response. The shared boilerplate lives in `hookFetch` below.
// ---------------------------------------------------------------------------

/** JSON content-type header shared by the body-carrying requests. */
const JSON_HEADERS = { "Content-Type": "application/json" };

/**
 * Perform a fetch and enforce the shared error contract.
 *
 * @param url - Endpoint to call.
 * @param fallback - Message used when the error body has no `detail`.
 * @param init - Optional fetch options (method, headers, body).
 * @returns The successful Response, ready for `.json()` when applicable.
 * @throws Error with the parsed `detail` (or `fallback`) on non-ok status.
 */
async function hookFetch(
  url: string,
  fallback: string,
  init?: RequestInit
): Promise<Response> {
  const res = await fetch(url, init);
  if (!res.ok) {
    throw new Error(await parseErrorDetail(res, fallback));
  }
  return res;
}

/** List all configured hooks. */
export async function listHooks(): Promise<HookResponse[]> {
  const res = await hookFetch("/api/admin/hooks", "Failed to load hooks");
  return res.json();
}

/** Create a new hook from the given request payload. */
export async function createHook(
  req: HookCreateRequest
): Promise<HookResponse> {
  const res = await hookFetch("/api/admin/hooks", "Failed to create hook", {
    method: "POST",
    headers: JSON_HEADERS,
    body: JSON.stringify(req),
  });
  return res.json();
}

/** Partially update an existing hook; only provided fields are changed. */
export async function updateHook(
  id: number,
  req: HookUpdateRequest
): Promise<HookResponse> {
  const res = await hookFetch(
    `/api/admin/hooks/${id}`,
    "Failed to update hook",
    {
      method: "PATCH",
      headers: JSON_HEADERS,
      body: JSON.stringify(req),
    }
  );
  return res.json();
}

/** Delete a hook. Resolves with no value on success. */
export async function deleteHook(id: number): Promise<void> {
  await hookFetch(`/api/admin/hooks/${id}`, "Failed to delete hook", {
    method: "DELETE",
  });
}

/** Activate a hook and return its updated representation. */
export async function activateHook(id: number): Promise<HookResponse> {
  const res = await hookFetch(
    `/api/admin/hooks/${id}/activate`,
    "Failed to activate hook",
    { method: "POST" }
  );
  return res.json();
}

/** Deactivate a hook and return its updated representation. */
export async function deactivateHook(id: number): Promise<HookResponse> {
  const res = await hookFetch(
    `/api/admin/hooks/${id}/deactivate`,
    "Failed to deactivate hook",
    { method: "POST" }
  );
  return res.json();
}

View File

@@ -256,6 +256,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
title="Featured"
sizePreset="main-ui"
variant="body"
widthVariant="fit"
/>
)}
<Content
@@ -264,6 +265,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
sizePreset="main-ui"
variant="body"
prominence="muted"
widthVariant="fit"
/>
{agent.is_public && (
<Content
@@ -272,6 +274,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
sizePreset="main-ui"
variant="body"
prominence="muted"
widthVariant="fit"
/>
)}
</Section>

View File

@@ -7,6 +7,7 @@ import {
SvgOrganization,
SvgShare,
SvgTag,
SvgUser,
SvgUsers,
SvgX,
} from "@opal/icons";
@@ -18,7 +19,6 @@ import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboB
import * as InputLayouts from "@/layouts/input-layouts";
import SwitchField from "@/refresh-components/form/SwitchField";
import LineItem from "@/refresh-components/buttons/LineItem";
import { SvgUser } from "@opal/icons";
import { Section } from "@/layouts/general-layouts";
import Text from "@/refresh-components/texts/Text";
import useShareableUsers from "@/hooks/useShareableUsers";

View File

@@ -7,12 +7,9 @@ import useSWR, { preload } from "swr";
import { errorHandlingFetcher } from "@/lib/fetcher";
import { checkUserIsNoAuthUser, getUserDisplayName, logout } from "@/lib/user";
import { useUser } from "@/providers/UserProvider";
import InputAvatar from "@/refresh-components/inputs/InputAvatar";
import Text from "@/refresh-components/texts/Text";
import LineItem from "@/refresh-components/buttons/LineItem";
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
import { usePathname, useRouter, useSearchParams } from "next/navigation";
import { cn } from "@/lib/utils";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import NotificationsPopover from "@/sections/sidebar/NotificationsPopover";
import {
@@ -26,6 +23,7 @@ import { Section } from "@/layouts/general-layouts";
import { toast } from "@/hooks/useToast";
import useAppFocus from "@/hooks/useAppFocus";
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
interface SettingsPopoverProps {
onUserSettingsClick: () => void;
@@ -150,7 +148,6 @@ export default function UserAvatarPopover({
"Settings" | "Notifications" | undefined
>(undefined);
const { user } = useUser();
const router = useRouter();
const appFocus = useAppFocus();
const vectorDbEnabled = useVectorDbEnabled();
@@ -186,18 +183,10 @@ export default function UserAvatarPopover({
<Popover.Trigger asChild>
<div id="onyx-user-dropdown">
<SidebarTab
icon={({ className }) => (
<InputAvatar
className={cn(
"flex items-center justify-center bg-background-neutral-inverted-00",
className,
"w-5 h-5"
)}
>
<Text as="p" inverted secondaryBody>
{userDisplayName[0]?.toUpperCase()}
</Text>
</InputAvatar>
icon={() => (
<div className="w-[16px] flex flex-col justify-center items-center">
<UserAvatar user={user} size={18} />
</div>
)}
rightChildren={
hasNotifications ? (

View File

@@ -84,7 +84,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
{
name: "User Management - Groups",
path: "groups",
pageTitle: "Manage User Groups",
pageTitle: "Groups",
},
{
name: "Appearance & Theming",

View File

@@ -0,0 +1,252 @@
/**
* Page Object Model for the Admin Groups page (/admin/groups).
*
* Covers the list page, create page, and edit page interactions.
*/
import { type Page, type Locator, expect } from "@playwright/test";
/** URL pattern that matches the groups data fetch. */
const GROUPS_API = /\/api\/manage\/admin\/user-group/;
export class GroupsAdminPage {
readonly page: Page;
constructor(page: Page) {
this.page = page;
}
// ---------------------------------------------------------------------------
// Navigation
// ---------------------------------------------------------------------------
/**
 * Navigate to the groups list page.
 * Uses visibility of the "New Group" button as the readiness signal.
 */
async goto() {
await this.page.goto("/admin/groups");
await expect(this.newGroupButton).toBeVisible({ timeout: 15000 });
}
/**
 * Navigate directly to the group creation page.
 * Waits until the "Create Group" text is visible before returning.
 */
async gotoCreate() {
await this.page.goto("/admin/groups/create");
await expect(this.page.getByText("Create Group")).toBeVisible({
timeout: 15000,
});
}
/**
 * Navigate directly to a group's edit page.
 * @param groupId - Numeric id of the group to open.
 */
async gotoEdit(groupId: number) {
await this.page.goto(`/admin/groups/${groupId}`);
// Wait for the form to be ready — avoids networkidle hanging due to SWR polling.
await expect(this.groupNameInput).toBeVisible({ timeout: 15000 });
}
// ---------------------------------------------------------------------------
// List page
// ---------------------------------------------------------------------------
/** The Groups page heading container (unique to the list page). */
get pageHeading(): Locator {
return this.page.getByTestId("groups-page-heading");
}
/** The search input on the list page. */
get listSearchInput(): Locator {
return this.page.getByPlaceholder("Search groups...");
}
/** The "New Group" button on the list page header. */
get newGroupButton(): Locator {
return this.page.getByRole("button", { name: "New Group" });
}
/** Returns all group cards on the list page. */
get groupCards(): Locator {
return this.page.locator("[data-card]");
}
/**
* Returns a group card by name.
* Cards use ContentAction which renders the title as text — match by content.
*/
getGroupCard(name: string): Locator {
return this.page.locator("[data-card]").filter({ hasText: name });
}
/** Open a group's edit page by clicking its card's "View group" button. */
async openGroup(name: string) {
  const viewButton = this.getGroupCard(name).getByRole("button", {
    name: "View group",
  });
  await viewButton.click();
  // The edit page is considered loaded once the name input is visible.
  await expect(this.groupNameInput).toBeVisible({ timeout: 15000 });
}
/** Search groups on the list page. */
async searchGroups(term: string) {
await this.listSearchInput.fill(term);
}
/** Click "New Group" to navigate to the create page. */
async clickNewGroup() {
await this.newGroupButton.click();
await expect(this.page.getByText("Create Group")).toBeVisible({
timeout: 15000,
});
}
// ---------------------------------------------------------------------------
// Create page
// ---------------------------------------------------------------------------
/** The group name input on create/edit pages. */
get groupNameInput(): Locator {
return this.page.getByPlaceholder("Name your group");
}
/** The member search input on create/edit pages. */
get memberSearchInput(): Locator {
return this.page.getByPlaceholder("Search users and accounts...");
}
/** The "Create" button on the create page. */
get createButton(): Locator {
return this.page.getByRole("button", { name: "Create", exact: true });
}
/** The "Cancel" button on create/edit pages. */
get cancelButton(): Locator {
return this.page.getByRole("button", { name: "Cancel" });
}
/** Fill in the group name on create/edit pages. */
async setGroupName(name: string) {
await this.groupNameInput.fill(name);
}
/** Search for members in the members table. */
async searchMembers(term: string) {
await this.memberSearchInput.fill(term);
}
/** Select a member row by checking their checkbox (create page / add mode). */
async selectMember(emailOrName: string) {
const row = this.page.getByRole("row").filter({ hasText: emailOrName });
const checkbox = row.getByRole("checkbox");
await checkbox.click();
}
/** Submit the create form. */
async submitCreate() {
await this.createButton.click();
}
// ---------------------------------------------------------------------------
// Edit page
// ---------------------------------------------------------------------------
/** The "Save Changes" button on the edit page. */
get saveButton(): Locator {
return this.page.getByRole("button", { name: "Save Changes" });
}
/** The "Add" button to enter add-members mode. */
get addMembersButton(): Locator {
return this.page.getByRole("button", { name: "Add", exact: true });
}
/** The "Done" button to exit add-members mode. */
get doneAddingButton(): Locator {
return this.page.getByRole("button", { name: "Done" });
}
/** The "Delete Group" button in the danger zone card. */
get deleteGroupButton(): Locator {
return this.page.getByRole("button", { name: "Delete Group" });
}
/** Enter add-members mode on the edit page. */
async startAddingMembers() {
await this.addMembersButton.click();
await expect(this.doneAddingButton).toBeVisible();
}
/** Exit add-members mode. */
async finishAddingMembers() {
await this.doneAddingButton.click();
await expect(this.addMembersButton).toBeVisible();
}
/**
* Remove a member from the member view via the minus button.
* Only works in member view (not add mode).
*/
async removeMember(emailOrName: string) {
const row = this.page.getByRole("row").filter({ hasText: emailOrName });
// The remove button is an IconButton with SvgMinusCircle in the actions column
await row.getByRole("button").last().click();
}
/** Save the edit form. */
async submitEdit() {
await this.saveButton.click();
}
// ---------------------------------------------------------------------------
// Delete flow
// ---------------------------------------------------------------------------
/** Click "Delete Group" to open the confirmation modal. */
async clickDeleteGroup() {
await this.deleteGroupButton.click();
}
/** The delete confirmation modal. */
get deleteModal(): Locator {
return this.page.getByRole("dialog");
}
/** Confirm deletion in the modal. */
async confirmDelete() {
await this.deleteModal.getByRole("button", { name: "Delete" }).click();
}
/** Cancel deletion in the modal. */
async cancelDelete() {
// The modal close button (X icon) or clicking outside
await this.deleteModal
.getByRole("button")
.filter({ hasText: /close|cancel/i })
.first()
.click();
}
// ---------------------------------------------------------------------------
// Assertions
// ---------------------------------------------------------------------------
async expectToast(message: string | RegExp) {
await expect(this.page.getByText(message)).toBeVisible({ timeout: 10000 });
}
/** Assert a group card exists on the list page. */
async expectGroupVisible(name: string) {
await expect(this.getGroupCard(name)).toBeVisible({ timeout: 10000 });
}
/** Assert a group card does NOT exist on the list page. */
async expectGroupNotVisible(name: string) {
await expect(this.getGroupCard(name)).not.toBeVisible({ timeout: 10000 });
}
/** Assert we navigated back to the groups list. */
async expectOnListPage() {
await expect(this.page).toHaveURL(/\/admin\/groups\/?$/);
await expect(this.newGroupButton).toBeVisible();
}
/** Assert we are on the edit page for a specific group. */
async expectOnEditPage(groupId: number) {
await expect(this.page).toHaveURL(`/admin/groups/${groupId}`);
}
/** Wait for the groups API response after a mutation. */
async waitForGroupsRefresh() {
await this.page.waitForResponse(GROUPS_API);
}
}

View File

@@ -0,0 +1,37 @@
/**
* Playwright fixtures for Admin Groups page tests.
*
* Provides:
* - Authenticated admin page
* - OnyxApiClient for API-level setup/teardown
* - GroupsAdminPage page object
*/
import { test as base, expect, type Page } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import { GroupsAdminPage } from "./GroupsAdminPage";
export const test = base.extend<{
  adminPage: Page;
  api: OnyxApiClient;
  groupsPage: GroupsAdminPage;
}>({
  // Fresh admin session: drop any stale cookies before logging in.
  adminPage: async ({ page }, use) => {
    await page.context().clearCookies();
    await loginAs(page, "admin");
    await use(page);
  },
  // API client bound to the authenticated admin request context.
  api: async ({ adminPage }, use) => {
    await use(new OnyxApiClient(adminPage.request));
  },
  // Page object wrapping the groups admin UI.
  groupsPage: async ({ adminPage }, use) => {
    await use(new GroupsAdminPage(adminPage));
  },
});
export { expect };

View File

@@ -0,0 +1,279 @@
/**
* E2E Tests: Admin Groups Page
*
* Tests the full groups management page — list, create, edit, delete.
*
* Uses the GroupsAdminPage POM for all interactions. Groups are created via
* OnyxApiClient for setup and cleaned up in afterAll/afterEach.
*/
import { test, expect } from "./fixtures";
import type { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
import type { Browser } from "@playwright/test";
// ---------------------------------------------------------------------------
// Helpers
// ---------------------------------------------------------------------------
/**
 * Builds a unique group name for a test run.
 *
 * Combines the millisecond timestamp with a short random suffix so names
 * cannot collide even when two workers (or two calls within one test)
 * execute in the same millisecond — `Date.now()` alone is not unique.
 */
function uniqueGroupName(prefix: string): string {
  const suffix = Math.random().toString(36).slice(2, 8);
  return `e2e-${prefix}-${Date.now()}-${suffix}`;
}
/** Best-effort cleanup — failures are logged rather than silently dropped. */
async function softCleanup(fn: () => Promise<unknown>): Promise<void> {
  try {
    await fn();
  } catch (err) {
    console.warn("cleanup:", err);
  }
}
/**
 * Runs `fn` with an OnyxApiClient authenticated as admin.
 *
 * Intended for beforeAll/afterAll hooks, which only receive the Browser
 * fixture. The temporary browser context is always closed afterwards.
 */
async function withApiContext(
  browser: Browser,
  fn: (api: OnyxApiClient) => Promise<void>
): Promise<void> {
  const ctx = await browser.newContext({ storageState: "admin_auth.json" });
  try {
    // Loaded lazily so the top-of-file import can remain type-only.
    const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
    await fn(new OnyxApiClient(ctx.request));
  } finally {
    await ctx.close();
  }
}
// ---------------------------------------------------------------------------
// List page
// ---------------------------------------------------------------------------
// Verifies static layout and client-side filtering on the groups list page.
test.describe("Groups page — layout", () => {
  // IDs of the two groups provisioned in beforeAll, kept for teardown.
  let adminGroupId: number;
  let basicGroupId: number;
  test.beforeAll(async ({ browser }) => {
    // Provision "Admin" and "Basic" via API and wait until both report
    // synced, so the list renders them reliably before any test runs.
    await withApiContext(browser, async (api) => {
      adminGroupId = await api.createUserGroup("Admin");
      basicGroupId = await api.createUserGroup("Basic");
      await api.waitForGroupSync(adminGroupId);
      await api.waitForGroupSync(basicGroupId);
    });
  });
  test.afterAll(async ({ browser }) => {
    // Best-effort removal of the suite's groups; failures are logged only.
    await withApiContext(browser, async (api) => {
      await softCleanup(() => api.deleteUserGroup(adminGroupId));
      await softCleanup(() => api.deleteUserGroup(basicGroupId));
    });
  });
  test("renders page title, search, and new group button", async ({
    groupsPage,
  }) => {
    await groupsPage.goto();
    await expect(groupsPage.pageHeading).toBeVisible();
    await expect(groupsPage.listSearchInput).toBeVisible();
    await expect(groupsPage.newGroupButton).toBeVisible();
  });
  test("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
    // "Admin" and "Basic" are the groups created in this suite's beforeAll.
    await groupsPage.goto();
    await groupsPage.expectGroupVisible("Admin");
    await groupsPage.expectGroupVisible("Basic");
  });
  test("search filters groups by name", async ({ groupsPage, api }) => {
    const name = uniqueGroupName("search");
    const groupId = await api.createUserGroup(name);
    await api.waitForGroupSync(groupId);
    try {
      // The group is visible initially, hidden by a non-matching query,
      // then visible again when searching for its exact name.
      await groupsPage.goto();
      await groupsPage.expectGroupVisible(name);
      await groupsPage.searchGroups("zzz-nonexistent-zzz");
      await groupsPage.expectGroupNotVisible(name);
      await groupsPage.searchGroups(name);
      await groupsPage.expectGroupVisible(name);
    } finally {
      await softCleanup(() => api.deleteUserGroup(groupId));
    }
  });
});
// ---------------------------------------------------------------------------
// Create flow
// ---------------------------------------------------------------------------
// Covers the group-creation flow: navigation, submit, and cancel.
test.describe("Groups page — create", () => {
  test("navigates to create page via New Group button", async ({
    groupsPage,
  }) => {
    await groupsPage.goto();
    await groupsPage.clickNewGroup();
    await expect(groupsPage.page).toHaveURL(/\/admin\/groups\/create/);
    await expect(groupsPage.groupNameInput).toBeVisible();
  });
  test("creates a group and redirects to list", async ({ groupsPage, api }) => {
    const name = uniqueGroupName("create");
    let groupId: number | undefined;
    try {
      await groupsPage.gotoCreate();
      await groupsPage.setGroupName(name);
      await groupsPage.submitCreate();
      await groupsPage.expectToast(`Group "${name}" created`);
      await groupsPage.expectOnListPage();
      // Find the group ID for cleanup via the authenticated page context
      // (the UI flow does not expose the newly created group's ID).
      const res = await groupsPage.page.request.get(
        "/api/manage/admin/user-group"
      );
      const groups = await res.json();
      const group = groups.find(
        (g: { name: string; id: number }) => g.name === name
      );
      groupId = group?.id;
    } finally {
      // groupId stays undefined when creation failed or the lookup missed;
      // there is nothing to clean up in that case.
      if (groupId !== undefined) {
        await softCleanup(() => api.deleteUserGroup(groupId!));
      }
    }
  });
  test("cancel returns to list without creating", async ({ groupsPage }) => {
    await groupsPage.gotoCreate();
    await groupsPage.setGroupName("should-not-be-created");
    await groupsPage.cancelButton.click();
    await groupsPage.expectOnListPage();
  });
});
// ---------------------------------------------------------------------------
// Edit flow
// ---------------------------------------------------------------------------
// Edit-page flows. Tagged @exclusive because all tests share one group
// created in beforeAll and torn down in afterAll.
test.describe("Groups page — edit @exclusive", () => {
  let groupId: number;
  // Evaluated once at module load; every test in this suite reuses it.
  const groupName = uniqueGroupName("edit");
  test.beforeAll(async ({ browser }) => {
    await withApiContext(browser, async (api) => {
      groupId = await api.createUserGroup(groupName);
      await api.waitForGroupSync(groupId);
    });
  });
  test.afterAll(async ({ browser }) => {
    await withApiContext(browser, async (api) => {
      await softCleanup(() => api.deleteUserGroup(groupId));
    });
  });
  test("navigates to edit page from list", async ({ groupsPage }) => {
    await groupsPage.goto();
    await groupsPage.openGroup(groupName);
    await groupsPage.expectOnEditPage(groupId);
    await expect(groupsPage.saveButton).toBeVisible();
  });
  test("edit page shows group name and save/cancel buttons", async ({
    groupsPage,
  }) => {
    await groupsPage.gotoEdit(groupId);
    await expect(groupsPage.groupNameInput).toHaveValue(groupName);
    await expect(groupsPage.saveButton).toBeVisible();
    await expect(groupsPage.cancelButton).toBeVisible();
  });
  test("can toggle add-members mode", async ({ groupsPage }) => {
    // Round-trip: member view -> add mode -> back to member view.
    await groupsPage.gotoEdit(groupId);
    await expect(groupsPage.addMembersButton).toBeVisible();
    await groupsPage.startAddingMembers();
    await expect(groupsPage.doneAddingButton).toBeVisible();
    await groupsPage.finishAddingMembers();
    await expect(groupsPage.addMembersButton).toBeVisible();
  });
  test("cancel returns to list without saving", async ({ groupsPage }) => {
    await groupsPage.gotoEdit(groupId);
    await groupsPage.cancelButton.click();
    await groupsPage.expectOnListPage();
  });
});
// ---------------------------------------------------------------------------
// Delete flow
// ---------------------------------------------------------------------------
test.describe("Groups page — delete", () => {
test("delete group via edit page", async ({ groupsPage, api }) => {
const name = uniqueGroupName("delete");
const groupId = await api.createUserGroup(name);
await api.waitForGroupSync(groupId);
await groupsPage.gotoEdit(groupId);
await groupsPage.clickDeleteGroup();
// Modal should show the group name
await expect(groupsPage.deleteModal).toBeVisible();
await expect(groupsPage.deleteModal.getByText(name)).toBeVisible();
await groupsPage.confirmDelete();
await groupsPage.expectToast(`Group "${name}" deleted`);
await groupsPage.expectOnListPage();
});
});
// ---------------------------------------------------------------------------
// Sync status (No Vector DB)
// ---------------------------------------------------------------------------
test.describe("Groups page — sync @lite", () => {
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
try {
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
const client = new OnyxApiClient(context.request);
const vectorDbEnabled = await client.isVectorDbEnabled();
test.skip(
vectorDbEnabled,
"Skipped: vector DB is enabled in this deployment"
);
} finally {
await context.close();
}
});
test("newly created group syncs immediately", async ({ groupsPage, api }) => {
const name = uniqueGroupName("sync");
let groupId: number | undefined;
try {
// Create via API and verify sync completes
groupId = await api.createUserGroup(name);
await api.waitForGroupSync(groupId);
// Navigate to edit page and verify it loads without error
await groupsPage.gotoEdit(groupId);
await expect(groupsPage.groupNameInput).toHaveValue(name);
} finally {
if (groupId !== undefined) {
await softCleanup(() => api.deleteUserGroup(groupId!));
}
}
});
});

View File

@@ -1,148 +0,0 @@
import { test, expect, BrowserContext } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
test.use({ storageState: "admin_auth.json" });
// Lite deployments have no vector DB, so groups should report synced
// immediately on creation instead of waiting on an indexing pass.
test.describe("User Groups - No Vector DB @lite", () => {
  test.beforeAll(async ({ browser }) => {
    // Probe the deployment once; skip the whole suite if a vector DB exists.
    const context = await browser.newContext({
      storageState: "admin_auth.json",
    });
    try {
      const client = new OnyxApiClient(context.request);
      const vectorDbEnabled = await client.isVectorDbEnabled();
      test.skip(
        vectorDbEnabled,
        "Skipped: vector DB is enabled in this deployment"
      );
    } finally {
      await context.close();
    }
  });
  test("should show user group as synced immediately on creation", async ({
    page,
  }) => {
    const groupName = `E2E-NoVectorDB-Group-${Date.now()}`;
    // Captured from the group link's href so the finally block can clean up.
    let groupId: number | undefined;
    try {
      await page.goto("/admin/groups");
      await page.waitForLoadState("networkidle");
      await page.getByRole("button", { name: "Create New User Group" }).click();
      const dialog = page.getByRole("dialog");
      await expect(dialog).toBeVisible();
      await dialog.locator('input[name="name"]').fill(groupName);
      // Lite deployments show a notice instead of the connector picker.
      await expect(
        dialog.getByText("Connectors are not available in Onyx Lite")
      ).toBeVisible();
      await dialog.getByRole("button", { name: "Create!" }).click();
      await expect(dialog).not.toBeVisible({ timeout: 10000 });
      const groupRow = page.getByRole("row").filter({ hasText: groupName });
      await expect(groupRow).toBeVisible({ timeout: 10000 });
      // No vector DB means no sync delay — status is immediately current.
      await expect(groupRow.getByText("Up to date!")).toBeVisible();
      const groupLink = groupRow.getByRole("link", { name: groupName });
      const href = await groupLink.getAttribute("href");
      const match = href?.match(/\/admin\/groups\/(\d+)/);
      if (match) {
        groupId = parseInt(match[1] ?? "", 10);
      }
      await groupLink.click();
      await page.waitForLoadState("networkidle");
      await expect(page.getByText("Up to date")).toBeVisible({ timeout: 5000 });
      // A synced group should allow membership changes right away.
      const addUsersButton = page.getByRole("button", {
        name: "Add Users",
      });
      await expect(addUsersButton).toBeEnabled();
    } finally {
      if (groupId !== undefined) {
        const apiClient = new OnyxApiClient(page.request);
        await apiClient.deleteUserGroup(groupId);
      }
    }
  });
});
// With a vector DB present, group sync is asynchronous; this suite polls
// the list page until the connector-backed group reports synced.
test.describe("User Groups - Standard Deployment @exclusive", () => {
  // Context kept open across the suite so afterAll can reuse `client`.
  let cleanupContext: BrowserContext | undefined;
  let client: OnyxApiClient;
  let ccPairId: number | undefined;
  let groupId: number | undefined;
  test.beforeAll(async ({ browser }) => {
    const context = await browser.newContext({
      storageState: "admin_auth.json",
    });
    try {
      client = new OnyxApiClient(context.request);
      const vectorDbEnabled = await client.isVectorDbEnabled();
      if (!vectorDbEnabled) {
        // Close before skipping — skip throws and would bypass cleanup.
        await context.close();
        test.skip(true, "Skipped: vector DB is disabled in this deployment");
        return;
      }
      cleanupContext = context;
    } catch (e) {
      await context.close();
      throw e;
    }
  });
  test.afterAll(async () => {
    // Best-effort teardown; individual failures are swallowed so the
    // context is still closed in the finally block.
    try {
      if (groupId !== undefined) {
        await client.deleteUserGroup(groupId).catch(() => {});
      }
      if (ccPairId !== undefined) {
        await client.deleteCCPair(ccPairId).catch(() => {});
      }
    } finally {
      await cleanupContext?.close();
    }
  });
  test("should sync user group with connector", async ({ page }) => {
    const apiClient = new OnyxApiClient(page.request);
    const groupName = `E2E-Sync-Group-${Date.now()}`;
    // A private connector attached to the group forces a real sync cycle.
    ccPairId = await apiClient.createFileConnector(
      `E2E-Group-Connector-${Date.now()}`,
      "private"
    );
    groupId = await apiClient.createUserGroup(groupName, [], [ccPairId]);
    await page.goto("/admin/groups");
    await page.waitForLoadState("networkidle");
    const groupRow = page.getByRole("row").filter({ hasText: groupName });
    await expect(groupRow).toBeVisible({ timeout: 10000 });
    const upToDate = groupRow.getByText("Up to date!");
    // Poll by reloading every 3s for up to 2 minutes until sync completes.
    const deadline = Date.now() + 120_000;
    while (Date.now() < deadline) {
      if (await upToDate.isVisible().catch(() => false)) break;
      await page.waitForTimeout(3000);
      await page.reload();
      await page.waitForLoadState("networkidle");
    }
    await expect(upToDate).toBeVisible({ timeout: 5000 });
    const groupLink = groupRow.getByRole("link", { name: groupName });
    await groupLink.click();
    await page.waitForLoadState("networkidle");
    await expect(page.getByText("Up to date")).toBeVisible({ timeout: 10000 });
  });
});

View File

@@ -0,0 +1,303 @@
import { test, expect } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import {
TOOL_IDS,
openActionManagement,
openSourceManagement,
toggleToolDisabled,
getSourceToggle,
} from "@tests/e2e/utils/tools";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
const LOCAL_STORAGE_KEY = "selectedInternalSearchSources";
test.describe("ActionsPopover Tool Toggles", () => {
test.describe.configure({ mode: "serial" });
let ccPairId: number | null = null;
let webSearchProviderId: number | null = null;
let imageGenConfigId: string | null = null;
test.beforeAll(async ({ browser }) => {
const ctx = await browser.newContext({ storageState: "admin_auth.json" });
const page = await ctx.newPage();
await page.goto("http://localhost:3000/app");
await page.waitForLoadState("networkidle");
const apiClient = new OnyxApiClient(page.request);
// Create a file connector so internal search tool is available
ccPairId = await apiClient.createFileConnector(
`actions-popover-test-${Date.now()}`
);
// Create providers for web search and image generation (best-effort)
try {
webSearchProviderId = await apiClient.createWebSearchProvider(
"exa",
`actions-popover-web-search-${Date.now()}`
);
} catch (error) {
console.warn(`Failed to create web search provider: ${error}`);
}
try {
imageGenConfigId = await apiClient.createImageGenerationConfig(
`actions-popover-image-gen-${Date.now()}`
);
} catch (error) {
console.warn(`Failed to create image gen config: ${error}`);
}
// Ensure all tools are enabled on the default agent
const toolsResp = await page.request.get("/api/tool");
const allTools = await toolsResp.json();
const toolIdsByCodeId: Record<string, number> = {};
allTools.forEach((t: any) => {
if (t.in_code_tool_id) toolIdsByCodeId[t.in_code_tool_id] = t.id;
});
const configResp = await page.request.get(
"/api/admin/default-assistant/configuration"
);
const currentConfig = await configResp.json();
const desiredToolIds = [
toolIdsByCodeId["SearchTool"],
toolIdsByCodeId["WebSearchTool"],
toolIdsByCodeId["ImageGenerationTool"],
].filter(Boolean);
const uniqueToolIds = Array.from(
new Set([...(currentConfig.tool_ids || []), ...desiredToolIds])
);
await page.request.patch("/api/admin/default-assistant", {
data: { tool_ids: uniqueToolIds },
});
await ctx.close();
});
test.afterAll(async ({ browser }) => {
const ctx = await browser.newContext({ storageState: "admin_auth.json" });
const page = await ctx.newPage();
await page.goto("http://localhost:3000/app");
await page.waitForLoadState("networkidle");
const apiClient = new OnyxApiClient(page.request);
if (ccPairId !== null) {
try {
await apiClient.deleteCCPair(ccPairId);
} catch (error) {
console.warn(`Cleanup: failed to delete connector: ${error}`);
}
}
if (webSearchProviderId !== null) {
try {
await apiClient.deleteWebSearchProvider(webSearchProviderId);
} catch (error) {
console.warn(`Cleanup: failed to delete web search provider: ${error}`);
}
}
if (imageGenConfigId !== null) {
try {
await apiClient.deleteImageGenerationConfig(imageGenConfigId);
} catch (error) {
console.warn(`Cleanup: failed to delete image gen config: ${error}`);
}
}
await ctx.close();
});
  // Each test starts from a fresh admin session on /app with no persisted
  // source preferences.
  test.beforeEach(async ({ page }) => {
    await page.context().clearCookies();
    await loginAs(page, "admin");
    await page.goto("/app");
    await page.waitForLoadState("networkidle");
    // Clear source preferences for a clean slate
    await page.evaluate(
      (key) => localStorage.removeItem(key),
      LOCAL_STORAGE_KEY
    );
  });
  test("should show internal search and other tools in popover", async ({
    page,
  }) => {
    await openActionManagement(page);
    // Internal search must be visible (connector was created in beforeAll)
    await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
      timeout: 10000,
    });
    // Soft-check other tools (depend on provider setup success) — these
    // are logged rather than asserted so provider failures don't fail the test.
    const webVisible = await page
      .locator(TOOL_IDS.webSearchOption)
      .isVisible()
      .catch(() => false);
    const imgVisible = await page
      .locator(TOOL_IDS.imageGenerationOption)
      .isVisible()
      .catch(() => false);
    console.log(`[tools] web_search=${webVisible}, image_gen=${imgVisible}`);
  });
  test("source preferences should persist to localStorage and survive reload", async ({
    page,
  }) => {
    await openActionManagement(page);
    await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
      timeout: 10000,
    });
    await openSourceManagement(page);
    // Find the first source switch
    const switches = page.locator('[role="switch"]');
    await expect(switches.first()).toBeVisible({ timeout: 5000 });
    const firstSwitch = switches.first();
    // Recover the source's display name from the switch's aria-label
    // ("Toggle <name>") so it can be re-located after the reload.
    const ariaLabel = await firstSwitch.getAttribute("aria-label");
    const sourceName = ariaLabel?.replace("Toggle ", "") || "";
    expect(sourceName).toBeTruthy();
    // Ensure it's enabled, then disable it
    if ((await firstSwitch.getAttribute("aria-checked")) === "false") {
      await firstSwitch.click();
      await expect(firstSwitch).toHaveAttribute("aria-checked", "true");
    }
    await firstSwitch.click();
    await expect(firstSwitch).toHaveAttribute("aria-checked", "false");
    // Verify localStorage was updated
    const stored = await page.evaluate(
      (key) => localStorage.getItem(key),
      LOCAL_STORAGE_KEY
    );
    expect(stored).toBeTruthy();
    expect(JSON.parse(stored!).sourcePreferences).toBeDefined();
    // Reload and verify persistence
    await page.reload();
    await page.waitForLoadState("networkidle");
    await openActionManagement(page);
    await openSourceManagement(page);
    // The same source must still be off after a full page reload.
    const sourceToggle = getSourceToggle(page, sourceName);
    await expect(sourceToggle).toHaveAttribute("aria-checked", "false", {
      timeout: 10000,
    });
  });
test("disabling search tool clears sources, re-enabling restores them", async ({
page,
}) => {
await openActionManagement(page);
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
timeout: 10000,
});
// Open source management and count enabled sources
await openSourceManagement(page);
const switches = page.locator('[role="switch"]');
await expect(switches.first()).toBeVisible({ timeout: 5000 });
const totalSources = await switches.count();
let enabledBefore = 0;
for (let i = 0; i < totalSources; i++) {
if ((await switches.nth(i).getAttribute("aria-checked")) === "true") {
enabledBefore++;
}
}
expect(enabledBefore).toBeGreaterThan(0);
// Go back to primary view
await page.locator('button[aria-label="Back"]').click();
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible();
// Disable the search tool
await toggleToolDisabled(page, TOOL_IDS.searchOption);
// Verify localStorage was written (the fix being tested)
const stored = await page.evaluate(
(key) => localStorage.getItem(key),
LOCAL_STORAGE_KEY
);
expect(stored).toBeTruthy();
// Re-enable the search tool
await toggleToolDisabled(page, TOOL_IDS.searchOption);
// Verify sources were restored
await openSourceManagement(page);
const switchesAfter = page.locator('[role="switch"]');
const totalAfter = await switchesAfter.count();
let enabledAfter = 0;
for (let i = 0; i < totalAfter; i++) {
if (
(await switchesAfter.nth(i).getAttribute("aria-checked")) === "true"
) {
enabledAfter++;
}
}
expect(enabledAfter).toBe(enabledBefore);
});
test("tool enabled and disabled states both persist across reload", async ({
page,
}) => {
await openActionManagement(page);
const searchOption = page.locator(TOOL_IDS.searchOption);
await expect(searchOption).toBeVisible({ timeout: 10000 });
// The slash button says "Disable" when the tool is enabled
await searchOption.hover();
const slashButton = searchOption.locator(
'button[aria-label="Disable"], button[aria-label="Enable"]'
);
await expect(slashButton.first()).toHaveAttribute("aria-label", "Disable");
// Reload — enabled state should persist
await page.reload();
await page.waitForLoadState("networkidle");
await openActionManagement(page);
await page.locator(TOOL_IDS.searchOption).hover();
await expect(
page
.locator(TOOL_IDS.searchOption)
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
.first()
).toHaveAttribute("aria-label", "Disable");
// Disable the search tool
await toggleToolDisabled(page, TOOL_IDS.searchOption);
// Verify it's now disabled (slash button says "Enable")
await page.locator(TOOL_IDS.searchOption).hover();
await expect(
page
.locator(TOOL_IDS.searchOption)
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
.first()
).toHaveAttribute("aria-label", "Enable");
// Reload — disabled state should also persist (saved to DB)
await page.reload();
await page.waitForLoadState("networkidle");
await openActionManagement(page);
await page.locator(TOOL_IDS.searchOption).hover();
await expect(
page
.locator(TOOL_IDS.searchOption)
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
.first()
).toHaveAttribute("aria-label", "Enable");
// Re-enable the tool for cleanup (serial tests follow)
await toggleToolDisabled(page, TOOL_IDS.searchOption);
});
});

View File

@@ -37,3 +37,40 @@ export async function isActionTogglePresent(page: Page): Promise<boolean> {
const el = await page.$(TOOL_IDS.actionToggle);
return !!el;
}
/**
 * Toggle a tool's enabled/disabled state via its slash button.
 * The button only renders while the tool row is hovered, and its
 * aria-label reflects the current state ("Disable" or "Enable").
 */
export async function toggleToolDisabled(
  page: Page,
  toolSelector: string
): Promise<void> {
  const option = page.locator(toolSelector);
  // Reveal the hidden slash button before interacting with it.
  await option.hover();
  const stateButton = option.locator(
    'button[aria-label="Disable"], button[aria-label="Enable"]'
  );
  // force: the button may be partially obscured even while hovered.
  await stateButton.first().click({ force: true });
}
/**
 * Enter the internal-search source management view.
 * The ActionsPopover must already be open before calling this.
 */
export async function openSourceManagement(page: Page): Promise<void> {
  const configureButton = page
    .locator(TOOL_IDS.searchOption)
    .locator('button[aria-label="Configure Connectors"]');
  await configureButton.click();
  // The Back button only exists in the secondary view, so its presence
  // confirms the view actually switched.
  await page.locator('button[aria-label="Back"]').waitFor({ timeout: 5000 });
}
/**
 * Locate a source's on/off Switch in the source management view
 * by the source's display name.
 */
export function getSourceToggle(page: Page, sourceName: string) {
  const toggleSelector = `[aria-label="Toggle ${sourceName}"]`;
  return page.locator(toggleSelector);
}