mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-25 17:42:41 +00:00
Compare commits
13 Commits
bo/hook_ui
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55b24d72b4 | ||
|
|
3321a84c7d | ||
|
|
54bf32a5f8 | ||
|
|
4bb6b76be6 | ||
|
|
db94562474 | ||
|
|
582d4642c1 | ||
|
|
3caaecdb0e | ||
|
|
039b69806b | ||
|
|
63971d4958 | ||
|
|
ffd897f380 | ||
|
|
4745069232 | ||
|
|
386782f188 | ||
|
|
ff009c4129 |
12
.vscode/launch.json
vendored
12
.vscode/launch.json
vendored
@@ -117,7 +117,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "API Server Console"
|
||||
"consoleTitle": "API Server Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Slack Bot",
|
||||
@@ -268,7 +269,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery heavy Console"
|
||||
"consoleTitle": "Celery heavy Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery kg_processing",
|
||||
@@ -355,7 +357,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery user_file_processing Console"
|
||||
"consoleTitle": "Celery user_file_processing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery docfetching",
|
||||
@@ -413,7 +416,8 @@
|
||||
"presentation": {
|
||||
"group": "2"
|
||||
},
|
||||
"consoleTitle": "Celery docprocessing Console"
|
||||
"consoleTitle": "Celery docprocessing Console",
|
||||
"justMyCode": false
|
||||
},
|
||||
{
|
||||
"name": "Celery beat",
|
||||
|
||||
@@ -0,0 +1,109 @@
|
||||
"""group_permissions_phase1
|
||||
|
||||
Revision ID: 25a5501dc766
|
||||
Revises: b728689f45b1
|
||||
Create Date: 2026-03-23 11:41:25.557442
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import fastapi_users_db_sqlalchemy
|
||||
import sqlalchemy as sa
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "25a5501dc766"
|
||||
down_revision = "b728689f45b1"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# 1. Add account_type column to user table (nullable for now).
|
||||
# TODO(subash): backfill account_type for existing rows and add NOT NULL.
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"account_type",
|
||||
sa.Enum(AccountType, native_enum=False),
|
||||
nullable=True,
|
||||
),
|
||||
)
|
||||
|
||||
# 2. Add is_default column to user_group table
|
||||
op.add_column(
|
||||
"user_group",
|
||||
sa.Column(
|
||||
"is_default",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
|
||||
# 3. Create permission_grant table
|
||||
op.create_table(
|
||||
"permission_grant",
|
||||
sa.Column("id", sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column("group_id", sa.Integer(), nullable=False),
|
||||
sa.Column(
|
||||
"permission",
|
||||
sa.Enum(Permission, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"grant_source",
|
||||
sa.Enum(GrantSource, native_enum=False),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_by",
|
||||
fastapi_users_db_sqlalchemy.generics.GUID(),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"granted_at",
|
||||
sa.DateTime(timezone=True),
|
||||
server_default=sa.func.now(),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"is_deleted",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["group_id"],
|
||||
["user_group.id"],
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["granted_by"],
|
||||
["user.id"],
|
||||
ondelete="SET NULL",
|
||||
),
|
||||
sa.UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
# 4. Index on user__user_group(user_id) — existing composite PK
|
||||
# has user_group_id as leading column; user-filtered queries need this
|
||||
op.create_index(
|
||||
"ix_user__user_group_user_id",
|
||||
"user__user_group",
|
||||
["user_id"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_user__user_group_user_id", table_name="user__user_group")
|
||||
op.drop_table("permission_grant")
|
||||
op.drop_column("user_group", "is_default")
|
||||
op.drop_column("user", "account_type")
|
||||
@@ -115,8 +115,14 @@ def fetch_user_group_token_rate_limits_for_user(
|
||||
ordered: bool = True,
|
||||
get_editable: bool = True,
|
||||
) -> Sequence[TokenRateLimit]:
|
||||
stmt = select(TokenRateLimit)
|
||||
stmt = stmt.where(User__UserGroup.user_group_id == group_id)
|
||||
stmt = (
|
||||
select(TokenRateLimit)
|
||||
.join(
|
||||
TokenRateLimit__UserGroup,
|
||||
TokenRateLimit.id == TokenRateLimit__UserGroup.rate_limit_id,
|
||||
)
|
||||
.where(TokenRateLimit__UserGroup.user_group_id == group_id)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable)
|
||||
|
||||
if enabled_only:
|
||||
|
||||
@@ -56,7 +56,7 @@ def _run_single_search(
|
||||
chunk_search_request=chunk_search_request,
|
||||
document_index=document_index,
|
||||
user=user,
|
||||
persona=None, # No persona for direct search
|
||||
persona_search_info=None,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -13,6 +13,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCFETCHING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -34,6 +42,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -48,6 +58,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -76,6 +116,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docfetching")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -14,6 +14,14 @@ from celery.signals import worker_shutdown
|
||||
import onyx.background.celery.apps.app_base as app_base
|
||||
from onyx.configs.constants import POSTGRES_CELERY_WORKER_DOCPROCESSING_APP_NAME
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_rejected
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_retry
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_revoked
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
@@ -35,6 +43,8 @@ def on_task_prerun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_prerun(sender, task_id, task, args, kwargs, **kwds)
|
||||
on_celery_task_prerun(task_id, task)
|
||||
on_indexing_task_prerun(task_id, task, kwargs)
|
||||
|
||||
|
||||
@signals.task_postrun.connect
|
||||
@@ -49,6 +59,36 @@ def on_task_postrun(
|
||||
**kwds: Any,
|
||||
) -> None:
|
||||
app_base.on_task_postrun(sender, task_id, task, args, kwargs, retval, state, **kwds)
|
||||
on_celery_task_postrun(task_id, task, state)
|
||||
on_indexing_task_postrun(task_id, task, kwargs, state)
|
||||
|
||||
|
||||
@signals.task_retry.connect
|
||||
def on_task_retry(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_retry signal doesn't pass task_id in kwargs; get it from
|
||||
# the sender (the task instance) via sender.request.id.
|
||||
task_id = getattr(getattr(sender, "request", None), "id", None)
|
||||
on_celery_task_retry(task_id, sender)
|
||||
|
||||
|
||||
@signals.task_revoked.connect
|
||||
def on_task_revoked(sender: Any | None = None, **kwargs: Any) -> None:
|
||||
task_name = getattr(sender, "name", None) or str(sender)
|
||||
on_celery_task_revoked(kwargs.get("task_id"), task_name)
|
||||
|
||||
|
||||
@signals.task_rejected.connect
|
||||
def on_task_rejected(sender: Any | None = None, **kwargs: Any) -> None: # noqa: ARG001
|
||||
# task_rejected sends the Consumer as sender, not the task instance.
|
||||
# The task name must be extracted from the Celery message headers.
|
||||
message = kwargs.get("message")
|
||||
task_name: str | None = None
|
||||
if message is not None:
|
||||
headers = getattr(message, "headers", None) or {}
|
||||
task_name = headers.get("task")
|
||||
if task_name is None:
|
||||
task_name = "unknown"
|
||||
on_celery_task_rejected(None, task_name)
|
||||
|
||||
|
||||
@celeryd_init.connect
|
||||
@@ -82,6 +122,7 @@ def on_worker_init(sender: Worker, **kwargs: Any) -> None:
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
start_metrics_server("docprocessing")
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
@@ -90,6 +131,12 @@ def on_worker_shutdown(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_worker_shutdown(sender, **kwargs)
|
||||
|
||||
|
||||
# Note: worker_process_init only fires in prefork pool mode. Docprocessing uses
|
||||
# worker_pool="threads" (see configs/docprocessing.py), so this handler is
|
||||
# effectively a no-op in normal operation. It remains as a safety net in case
|
||||
# the pool type is ever changed to prefork. Prometheus metrics are safe in
|
||||
# thread-pool mode since all threads share the same process memory and can
|
||||
# update the same Counter/Gauge/Histogram objects directly.
|
||||
@worker_process_init.connect
|
||||
def init_worker(**kwargs: Any) -> None: # noqa: ARG001
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
@@ -54,8 +54,14 @@ def on_celeryd_init(sender: Any = None, conf: Any = None, **kwargs: Any) -> None
|
||||
app_base.on_celeryd_init(sender, conf, **kwargs)
|
||||
|
||||
|
||||
# Set by on_worker_init so on_worker_ready knows whether to start the server.
|
||||
_prometheus_collectors_ok: bool = False
|
||||
|
||||
|
||||
@worker_init.connect
|
||||
def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
global _prometheus_collectors_ok
|
||||
|
||||
logger.info("worker_init signal received.")
|
||||
logger.info(f"Multiprocessing start method: {multiprocessing.get_start_method()}")
|
||||
|
||||
@@ -65,6 +71,8 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.wait_for_redis(sender, **kwargs)
|
||||
app_base.wait_for_db(sender, **kwargs)
|
||||
|
||||
_prometheus_collectors_ok = _setup_prometheus_collectors(sender)
|
||||
|
||||
# Less startup checks in multi-tenant case
|
||||
if MULTI_TENANT:
|
||||
return
|
||||
@@ -72,8 +80,37 @@ def on_worker_init(sender: Any, **kwargs: Any) -> None:
|
||||
app_base.on_secondary_worker_init(sender, **kwargs)
|
||||
|
||||
|
||||
def _setup_prometheus_collectors(sender: Any) -> bool:
|
||||
"""Register Prometheus collectors that need Redis/DB access.
|
||||
|
||||
Passes the Celery app so the queue depth collector can obtain a fresh
|
||||
broker Redis client on each scrape (rather than holding a stale reference).
|
||||
|
||||
Returns True if registration succeeded, False otherwise.
|
||||
"""
|
||||
try:
|
||||
from onyx.server.metrics.indexing_pipeline_setup import (
|
||||
setup_indexing_pipeline_metrics,
|
||||
)
|
||||
|
||||
setup_indexing_pipeline_metrics(sender.app)
|
||||
logger.info("Prometheus indexing pipeline collectors registered")
|
||||
return True
|
||||
except Exception:
|
||||
logger.exception("Failed to register Prometheus indexing pipeline collectors")
|
||||
return False
|
||||
|
||||
|
||||
@worker_ready.connect
|
||||
def on_worker_ready(sender: Any, **kwargs: Any) -> None:
|
||||
if _prometheus_collectors_ok:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
start_metrics_server("monitoring")
|
||||
else:
|
||||
logger.warning(
|
||||
"Skipping Prometheus metrics server — collector registration failed"
|
||||
)
|
||||
app_base.on_worker_ready(sender, **kwargs)
|
||||
|
||||
|
||||
|
||||
@@ -123,7 +123,7 @@ class OnyxConfluence:
|
||||
|
||||
self.shared_base_kwargs: dict[str, str | int | bool] = {
|
||||
"api_version": "cloud" if is_cloud else "latest",
|
||||
"backoff_and_retry": True,
|
||||
"backoff_and_retry": False,
|
||||
"cloud": is_cloud,
|
||||
}
|
||||
if timeout:
|
||||
@@ -456,7 +456,7 @@ class OnyxConfluence:
|
||||
return attr(*args, **kwargs)
|
||||
|
||||
except HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
|
||||
@@ -363,7 +363,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
# and applying our own retries in a more specific set of circumstances
|
||||
return confluence_call(*args, **kwargs)
|
||||
except requests.HTTPError as e:
|
||||
delay_until = _handle_http_error(e, attempt)
|
||||
delay_until = _handle_http_error(e, attempt, MAX_RETRIES)
|
||||
logger.warning(
|
||||
f"HTTPError in confluence call. Retrying in {delay_until} seconds..."
|
||||
)
|
||||
@@ -384,7 +384,7 @@ def handle_confluence_rate_limit(confluence_call: F) -> F:
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
def _handle_http_error(e: requests.HTTPError, attempt: int, max_retries: int) -> int:
|
||||
MIN_DELAY = 2
|
||||
MAX_DELAY = 60
|
||||
STARTING_DELAY = 5
|
||||
@@ -408,6 +408,17 @@ def _handle_http_error(e: requests.HTTPError, attempt: int) -> int:
|
||||
|
||||
raise e
|
||||
|
||||
if e.response.status_code >= 500:
|
||||
if attempt >= max_retries - 1:
|
||||
raise e
|
||||
|
||||
delay = min(STARTING_DELAY * (BACKOFF**attempt), MAX_DELAY)
|
||||
logger.warning(
|
||||
f"Server error {e.response.status_code}. "
|
||||
f"Retrying in {delay} seconds (attempt {attempt + 1})..."
|
||||
)
|
||||
return math.ceil(time.monotonic() + delay)
|
||||
|
||||
if (
|
||||
e.response.status_code != 429
|
||||
and RATE_LIMIT_MESSAGE_LOWERCASE not in e.response.text.lower()
|
||||
|
||||
@@ -401,3 +401,16 @@ class SavedSearchDocWithContent(SavedSearchDoc):
|
||||
section in addition to the match_highlights."""
|
||||
|
||||
content: str
|
||||
|
||||
|
||||
class PersonaSearchInfo(BaseModel):
|
||||
"""Snapshot of persona data needed by the search pipeline.
|
||||
|
||||
Extracted from the ORM Persona before the DB session is released so that
|
||||
SearchTool and search_pipeline never lazy-load relationships post-commit.
|
||||
"""
|
||||
|
||||
document_set_names: list[str]
|
||||
search_start_date: datetime | None
|
||||
attached_document_ids: list[str]
|
||||
hierarchy_node_ids: list[int]
|
||||
|
||||
@@ -9,12 +9,12 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.preprocessing.access_filters import (
|
||||
build_access_filters_for_user,
|
||||
)
|
||||
from onyx.context.search.retrieval.search_runner import search_chunks
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
@@ -247,8 +247,8 @@ def search_pipeline(
|
||||
document_index: DocumentIndex,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
# Pre-extracted persona search configuration (None when no persona)
|
||||
persona_search_info: PersonaSearchInfo | None,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
@@ -263,24 +263,18 @@ def search_pipeline(
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
persona_document_sets: list[str] | None = (
|
||||
[persona_document_set.name for persona_document_set in persona.document_sets]
|
||||
if persona
|
||||
else None
|
||||
persona_search_info.document_set_names if persona_search_info else None
|
||||
)
|
||||
persona_time_cutoff: datetime | None = (
|
||||
persona.search_start_date if persona else None
|
||||
persona_search_info.search_start_date if persona_search_info else None
|
||||
)
|
||||
|
||||
# Extract assistant knowledge filters from persona
|
||||
attached_document_ids: list[str] | None = (
|
||||
[doc.id for doc in persona.attached_documents]
|
||||
if persona and persona.attached_documents
|
||||
persona_search_info.attached_document_ids or None
|
||||
if persona_search_info
|
||||
else None
|
||||
)
|
||||
hierarchy_node_ids: list[int] | None = (
|
||||
[node.id for node in persona.hierarchy_nodes]
|
||||
if persona and persona.hierarchy_nodes
|
||||
else None
|
||||
persona_search_info.hierarchy_node_ids or None if persona_search_info else None
|
||||
)
|
||||
|
||||
filters = _build_index_filters(
|
||||
|
||||
@@ -64,6 +64,9 @@ def get_chat_session_by_id(
|
||||
joinedload(ChatSession.persona).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
),
|
||||
joinedload(ChatSession.project),
|
||||
)
|
||||
|
||||
@@ -750,3 +750,31 @@ def resync_cc_pair(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
def get_connector_health_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list: # Returns list of Row tuples
|
||||
"""Return connector health data for Prometheus metrics.
|
||||
|
||||
Each row is (cc_pair_id, status, in_repeated_error_state,
|
||||
last_successful_index_time, name, source).
|
||||
"""
|
||||
return (
|
||||
db_session.query(
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.status,
|
||||
ConnectorCredentialPair.in_repeated_error_state,
|
||||
ConnectorCredentialPair.last_successful_index_time,
|
||||
ConnectorCredentialPair.name,
|
||||
Connector.source,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
@@ -1,4 +1,31 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from enum import Enum as PyEnum
|
||||
from typing import ClassVar
|
||||
|
||||
|
||||
class AccountType(str, PyEnum):
|
||||
"""
|
||||
What kind of account this is — determines whether the user
|
||||
enters the group-based permission system.
|
||||
|
||||
STANDARD + SERVICE_ACCOUNT → participate in group system
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
@@ -314,3 +341,54 @@ class HookPoint(str, PyEnum):
|
||||
class HookFailStrategy(str, PyEnum):
|
||||
HARD = "hard" # exception propagates, pipeline aborts
|
||||
SOFT = "soft" # log error, return original input, pipeline continues
|
||||
|
||||
|
||||
class Permission(str, PyEnum):
|
||||
"""
|
||||
Permission tokens for group-based authorization.
|
||||
19 tokens total. full_admin_panel_access is an override —
|
||||
if present, any permission check passes.
|
||||
"""
|
||||
|
||||
# Basic (auto-granted to every new group)
|
||||
BASIC_ACCESS = "basic"
|
||||
|
||||
# Read tokens — implied only, never granted directly
|
||||
READ_CONNECTORS = "read:connectors"
|
||||
READ_DOCUMENT_SETS = "read:document_sets"
|
||||
READ_AGENTS = "read:agents"
|
||||
READ_USERS = "read:users"
|
||||
|
||||
# Add / Manage pairs
|
||||
ADD_AGENTS = "add:agents"
|
||||
MANAGE_AGENTS = "manage:agents"
|
||||
MANAGE_DOCUMENT_SETS = "manage:document_sets"
|
||||
ADD_CONNECTORS = "add:connectors"
|
||||
MANAGE_CONNECTORS = "manage:connectors"
|
||||
MANAGE_LLMS = "manage:llms"
|
||||
|
||||
# Toggle tokens
|
||||
READ_AGENT_ANALYTICS = "read:agent_analytics"
|
||||
MANAGE_ACTIONS = "manage:actions"
|
||||
READ_QUERY_HISTORY = "read:query_history"
|
||||
MANAGE_USER_GROUPS = "manage:user_groups"
|
||||
CREATE_USER_API_KEYS = "create:user_api_keys"
|
||||
CREATE_SERVICE_ACCOUNT_API_KEYS = "create:service_account_api_keys"
|
||||
CREATE_SLACK_DISCORD_BOTS = "create:slack_discord_bots"
|
||||
|
||||
# Override — any permission check passes
|
||||
FULL_ADMIN_PANEL_ACCESS = "admin"
|
||||
|
||||
# Permissions that are implied by other grants and must never be stored
|
||||
# directly in the permission_grant table.
|
||||
IMPLIED: ClassVar[frozenset[Permission]]
|
||||
|
||||
|
||||
Permission.IMPLIED = frozenset(
|
||||
{
|
||||
Permission.READ_CONNECTORS,
|
||||
Permission.READ_DOCUMENT_SETS,
|
||||
Permission.READ_AGENTS,
|
||||
Permission.READ_USERS,
|
||||
}
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import NamedTuple
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TypeVarTuple
|
||||
|
||||
from sqlalchemy import and_
|
||||
@@ -28,6 +30,9 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from onyx.configs.constants import DocumentSource
|
||||
|
||||
# from sqlalchemy.sql.selectable import Select
|
||||
|
||||
# Comment out unused imports that cause mypy errors
|
||||
@@ -972,3 +977,106 @@ def get_index_attempt_errors_for_cc_pair(
|
||||
stmt = stmt.offset(page * page_size).limit(page_size)
|
||||
|
||||
return list(db_session.scalars(stmt).all())
|
||||
|
||||
|
||||
# ── Metrics query helpers ──────────────────────────────────────────────
|
||||
|
||||
|
||||
class ActiveIndexAttemptMetric(NamedTuple):
|
||||
"""Row returned by get_active_index_attempts_for_metrics."""
|
||||
|
||||
status: IndexingStatus
|
||||
source: "DocumentSource"
|
||||
cc_pair_id: int
|
||||
cc_pair_name: str | None
|
||||
attempt_count: int
|
||||
|
||||
|
||||
def get_active_index_attempts_for_metrics(
|
||||
db_session: Session,
|
||||
) -> list[ActiveIndexAttemptMetric]:
|
||||
"""Return non-terminal index attempts grouped by status, source, and connector.
|
||||
|
||||
Each row is (status, source, cc_pair_id, cc_pair_name, attempt_count).
|
||||
"""
|
||||
from onyx.db.models import Connector
|
||||
|
||||
terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
func.count(),
|
||||
)
|
||||
.join(
|
||||
ConnectorCredentialPair,
|
||||
IndexAttempt.connector_credential_pair_id == ConnectorCredentialPair.id,
|
||||
)
|
||||
.join(
|
||||
Connector,
|
||||
ConnectorCredentialPair.connector_id == Connector.id,
|
||||
)
|
||||
.filter(IndexAttempt.status.notin_(terminal_statuses))
|
||||
.group_by(
|
||||
IndexAttempt.status,
|
||||
Connector.source,
|
||||
ConnectorCredentialPair.id,
|
||||
ConnectorCredentialPair.name,
|
||||
)
|
||||
.all()
|
||||
)
|
||||
return [ActiveIndexAttemptMetric(*row) for row in rows]
|
||||
|
||||
|
||||
def get_failed_attempt_counts_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: failed_attempt_count} for all connectors.
|
||||
|
||||
When ``since`` is provided, only attempts created after that timestamp
|
||||
are counted. Defaults to the last 90 days to avoid unbounded historical
|
||||
aggregation.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
rows = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.count(),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.FAILED)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
.all()
|
||||
)
|
||||
return {cc_id: count for cc_id, count in rows}
|
||||
|
||||
|
||||
def get_docs_indexed_by_cc_pair(
|
||||
db_session: Session,
|
||||
since: datetime | None = None,
|
||||
) -> dict[int, int]:
|
||||
"""Return {cc_pair_id: total_new_docs_indexed} across successful attempts.
|
||||
|
||||
Only counts attempts with status SUCCESS to avoid inflating counts with
|
||||
partial results from failed attempts. When ``since`` is provided, only
|
||||
attempts created after that timestamp are included.
|
||||
"""
|
||||
if since is None:
|
||||
since = datetime.now(timezone.utc) - timedelta(days=90)
|
||||
|
||||
query = (
|
||||
db_session.query(
|
||||
IndexAttempt.connector_credential_pair_id,
|
||||
func.sum(func.coalesce(IndexAttempt.new_docs_indexed, 0)),
|
||||
)
|
||||
.filter(IndexAttempt.status == IndexingStatus.SUCCESS)
|
||||
.filter(IndexAttempt.time_created >= since)
|
||||
.group_by(IndexAttempt.connector_credential_pair_id)
|
||||
)
|
||||
rows = query.all()
|
||||
return {cc_id: int(total or 0) for cc_id, total in rows}
|
||||
|
||||
@@ -48,6 +48,7 @@ from sqlalchemy.types import LargeBinary
|
||||
from sqlalchemy.types import TypeDecorator
|
||||
from sqlalchemy import PrimaryKeyConstraint
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import (
|
||||
ANONYMOUS_USER_UUID,
|
||||
@@ -78,6 +79,8 @@ from onyx.db.enums import (
|
||||
MCPAuthenticationPerformer,
|
||||
MCPTransport,
|
||||
MCPServerStatus,
|
||||
Permission,
|
||||
GrantSource,
|
||||
LLMModelFlowType,
|
||||
ThemePreference,
|
||||
DefaultAppMode,
|
||||
@@ -302,6 +305,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
Preferences probably should be in a separate table at some point, but for now
|
||||
@@ -3971,6 +3977,8 @@ class SamlAccount(Base):
|
||||
class User__UserGroup(Base):
|
||||
__tablename__ = "user__user_group"
|
||||
|
||||
__table_args__ = (Index("ix_user__user_group_user_id", "user_id"),)
|
||||
|
||||
is_curator: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
user_group_id: Mapped[int] = mapped_column(
|
||||
@@ -3981,6 +3989,48 @@ class User__UserGroup(Base):
|
||||
)
|
||||
|
||||
|
||||
class PermissionGrant(Base):
|
||||
__tablename__ = "permission_grant"
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint(
|
||||
"group_id", "permission", name="uq_permission_grant_group_permission"
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
group_id: Mapped[int] = mapped_column(
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
)
|
||||
granted_by: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="SET NULL"), nullable=True
|
||||
)
|
||||
granted_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
is_deleted: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False, server_default=text("false")
|
||||
)
|
||||
|
||||
group: Mapped["UserGroup"] = relationship(
|
||||
"UserGroup", back_populates="permission_grants"
|
||||
)
|
||||
|
||||
@validates("permission")
|
||||
def _validate_permission(self, _key: str, value: Permission) -> Permission:
|
||||
if value in Permission.IMPLIED:
|
||||
raise ValueError(
|
||||
f"{value!r} is an implied permission and cannot be granted directly"
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class UserGroup__ConnectorCredentialPair(Base):
|
||||
__tablename__ = "user_group__connector_credential_pair"
|
||||
|
||||
@@ -4075,6 +4125,8 @@ class UserGroup(Base):
|
||||
is_up_for_deletion: Mapped[bool] = mapped_column(
|
||||
Boolean, nullable=False, default=False
|
||||
)
|
||||
# whether this is a default group (e.g. "Basic", "Admins") that cannot be deleted
|
||||
is_default: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
|
||||
# Last time a user updated this user group
|
||||
time_last_modified_by_user: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -4118,6 +4170,9 @@ class UserGroup(Base):
|
||||
accessible_mcp_servers: Mapped[list["MCPServer"]] = relationship(
|
||||
"MCPServer", secondary="mcp_server__user_group", back_populates="user_groups"
|
||||
)
|
||||
permission_grants: Mapped[list["PermissionGrant"]] = relationship(
|
||||
"PermissionGrant", back_populates="group", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
|
||||
"""Tables related to Token Rate Limiting
|
||||
|
||||
@@ -50,8 +50,18 @@ from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_default_behavior_persona(db_session: Session) -> Persona | None:
|
||||
def get_default_behavior_persona(
|
||||
db_session: Session,
|
||||
eager_load_for_tools: bool = False,
|
||||
) -> Persona | None:
|
||||
stmt = select(Persona).where(Persona.id == DEFAULT_PERSONA_ID)
|
||||
if eager_load_for_tools:
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.attached_documents),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
)
|
||||
return db_session.scalars(stmt).first()
|
||||
|
||||
|
||||
|
||||
@@ -91,8 +91,6 @@ class HookResponse(BaseModel):
|
||||
# Nullable to match the DB column — endpoint_url is required on creation but
|
||||
# future hook point types may not use an external endpoint (e.g. built-in handlers).
|
||||
endpoint_url: str | None
|
||||
# Partially-masked API key (e.g. "abcd••••••••wxyz"), or None if no key is set.
|
||||
api_key_masked: str | None
|
||||
fail_strategy: HookFailStrategy
|
||||
timeout_seconds: float # always resolved — None from request is replaced with spec default before DB write
|
||||
is_active: bool
|
||||
|
||||
@@ -62,9 +62,6 @@ def _hook_to_response(hook: Hook, creator_email: str | None = None) -> HookRespo
|
||||
name=hook.name,
|
||||
hook_point=hook.hook_point,
|
||||
endpoint_url=hook.endpoint_url,
|
||||
api_key_masked=(
|
||||
hook.api_key.get_value(apply_mask=True) if hook.api_key else None
|
||||
),
|
||||
fail_strategy=hook.fail_strategy,
|
||||
timeout_seconds=hook.timeout_seconds,
|
||||
is_active=hook.is_active,
|
||||
|
||||
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
207
backend/onyx/server/metrics/celery_task_metrics.py
Normal file
@@ -0,0 +1,207 @@
|
||||
"""Generic Celery task lifecycle Prometheus metrics.
|
||||
|
||||
Provides signal handlers that track task started/completed/failed counts,
|
||||
active task gauge, task duration histograms, and retry/reject/revoke counts.
|
||||
These fire for ALL tasks on the worker — no per-connector enrichment
|
||||
(see indexing_task_metrics.py for that).
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.celery_task_metrics import (
|
||||
on_celery_task_prerun,
|
||||
on_celery_task_postrun,
|
||||
on_celery_task_retry,
|
||||
on_celery_task_revoked,
|
||||
on_celery_task_rejected,
|
||||
)
|
||||
# Call from the worker's existing signal handlers
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Gauge
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
TASK_STARTED = Counter(
|
||||
"onyx_celery_task_started_total",
|
||||
"Total Celery tasks started",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_COMPLETED = Counter(
|
||||
"onyx_celery_task_completed_total",
|
||||
"Total Celery tasks completed",
|
||||
["task_name", "queue", "outcome"],
|
||||
)
|
||||
|
||||
TASK_DURATION = Histogram(
|
||||
"onyx_celery_task_duration_seconds",
|
||||
"Celery task execution duration in seconds",
|
||||
["task_name", "queue"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
TASKS_ACTIVE = Gauge(
|
||||
"onyx_celery_tasks_active",
|
||||
"Currently executing Celery tasks",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_RETRIED = Counter(
|
||||
"onyx_celery_task_retried_total",
|
||||
"Total Celery tasks retried",
|
||||
["task_name", "queue"],
|
||||
)
|
||||
|
||||
TASK_REVOKED = Counter(
|
||||
"onyx_celery_task_revoked_total",
|
||||
"Total Celery tasks revoked (cancelled)",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
TASK_REJECTED = Counter(
|
||||
"onyx_celery_task_rejected_total",
|
||||
"Total Celery tasks rejected by worker",
|
||||
["task_name"],
|
||||
)
|
||||
|
||||
# task_id → (monotonic start time, metric labels)
|
||||
_task_start_times: dict[str, tuple[float, dict[str, str]]] = {}
|
||||
|
||||
# Lock protecting _task_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_task_start_times_lock = threading.Lock()
|
||||
|
||||
# Entries older than this are evicted on each prerun to prevent unbounded
|
||||
# growth when tasks are killed (SIGTERM, OOM) and postrun never fires.
|
||||
_MAX_START_TIME_AGE_SECONDS = 3600 # 1 hour
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _task_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _task_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, (start, _labels) in _task_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
entry = _task_start_times.pop(tid, None)
|
||||
if entry is not None:
|
||||
_labels = entry[1]
|
||||
# Decrement active gauge for evicted tasks — these tasks were
|
||||
# started but never completed (killed, OOM, etc.).
|
||||
active_gauge = TASKS_ACTIVE.labels(**_labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
|
||||
def _get_task_labels(task: Task) -> dict[str, str]:
|
||||
"""Extract task_name and queue labels from a Celery Task instance."""
|
||||
task_name = task.name or "unknown"
|
||||
queue = "unknown"
|
||||
try:
|
||||
delivery_info = task.request.delivery_info
|
||||
if delivery_info:
|
||||
queue = delivery_info.get("routing_key") or "unknown"
|
||||
except AttributeError:
|
||||
pass
|
||||
return {"task_name": task_name, "queue": queue}
|
||||
|
||||
|
||||
def on_celery_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task start. Call from the worker's task_prerun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_STARTED.labels(**labels).inc()
|
||||
TASKS_ACTIVE.labels(**labels).inc()
|
||||
with _task_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_task_start_times[task_id] = (time.monotonic(), labels)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record task completion. Call from the worker's task_postrun signal handler."""
|
||||
if task is None or task_id is None:
|
||||
return
|
||||
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
TASK_COMPLETED.labels(**labels, outcome=outcome).inc()
|
||||
|
||||
# Guard against going below 0 if postrun fires without a matching
|
||||
# prerun (e.g. after a worker restart or stale entry eviction).
|
||||
active_gauge = TASKS_ACTIVE.labels(**labels)
|
||||
if active_gauge._value.get() > 0:
|
||||
active_gauge.dec()
|
||||
|
||||
with _task_start_times_lock:
|
||||
entry = _task_start_times.pop(task_id, None)
|
||||
if entry is not None:
|
||||
start_time, _stored_labels = entry
|
||||
TASK_DURATION.labels(**labels).observe(time.monotonic() - start_time)
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task postrun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_retry(
|
||||
_task_id: str | None,
|
||||
task: Task | None,
|
||||
) -> None:
|
||||
"""Record task retry. Call from the worker's task_retry signal handler."""
|
||||
if task is None:
|
||||
return
|
||||
try:
|
||||
labels = _get_task_labels(task)
|
||||
TASK_RETRIED.labels(**labels).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task retry metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_revoked(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task revocation. The revoked signal doesn't provide a Task
|
||||
instance, only the task name via sender."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REVOKED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task revoked metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_celery_task_rejected(
|
||||
_task_id: str | None,
|
||||
task_name: str | None = None,
|
||||
) -> None:
|
||||
"""Record task rejection."""
|
||||
if task_name is None:
|
||||
return
|
||||
try:
|
||||
TASK_REJECTED.labels(task_name=task_name).inc()
|
||||
except Exception:
|
||||
logger.debug("Failed to record celery task rejected metrics", exc_info=True)
|
||||
528
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
528
backend/onyx/server/metrics/indexing_pipeline.py
Normal file
@@ -0,0 +1,528 @@
|
||||
"""Prometheus collectors for Celery queue depths and indexing pipeline state.
|
||||
|
||||
These collectors query Redis and Postgres at scrape time (the Collector pattern),
|
||||
so metrics are always fresh when Prometheus scrapes /metrics. They run inside the
|
||||
monitoring celery worker which already has Redis and DB access.
|
||||
|
||||
To avoid hammering Redis/Postgres on every 15s scrape, results are cached with
|
||||
a configurable TTL (default 30s). This means metrics may be up to TTL seconds
|
||||
stale, which is fine for monitoring dashboards.
|
||||
"""
|
||||
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from prometheus_client.core import GaugeMetricFamily
|
||||
from prometheus_client.registry import Collector
|
||||
from redis import Redis
|
||||
|
||||
from onyx.background.celery.celery_redis import celery_get_queue_length
|
||||
from onyx.background.celery.celery_redis import celery_get_unacked_task_ids
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default cache TTL in seconds. Scrapes hitting within this window return
|
||||
# the previous result without re-querying Redis/Postgres.
|
||||
_DEFAULT_CACHE_TTL = 30.0
|
||||
|
||||
_QUEUE_LABEL_MAP: dict[str, str] = {
|
||||
OnyxCeleryQueues.PRIMARY: "primary",
|
||||
OnyxCeleryQueues.DOCPROCESSING: "docprocessing",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING: "docfetching",
|
||||
OnyxCeleryQueues.VESPA_METADATA_SYNC: "vespa_metadata_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_DELETION: "connector_deletion",
|
||||
OnyxCeleryQueues.CONNECTOR_PRUNING: "connector_pruning",
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_PERMISSIONS_SYNC: "permissions_sync",
|
||||
OnyxCeleryQueues.CONNECTOR_EXTERNAL_GROUP_SYNC: "external_group_sync",
|
||||
OnyxCeleryQueues.DOC_PERMISSIONS_UPSERT: "permissions_upsert",
|
||||
OnyxCeleryQueues.CONNECTOR_HIERARCHY_FETCHING: "hierarchy_fetching",
|
||||
OnyxCeleryQueues.LLM_MODEL_UPDATE: "llm_model_update",
|
||||
OnyxCeleryQueues.CHECKPOINT_CLEANUP: "checkpoint_cleanup",
|
||||
OnyxCeleryQueues.INDEX_ATTEMPT_CLEANUP: "index_attempt_cleanup",
|
||||
OnyxCeleryQueues.CSV_GENERATION: "csv_generation",
|
||||
OnyxCeleryQueues.USER_FILE_PROCESSING: "user_file_processing",
|
||||
OnyxCeleryQueues.USER_FILE_PROJECT_SYNC: "user_file_project_sync",
|
||||
OnyxCeleryQueues.USER_FILE_DELETE: "user_file_delete",
|
||||
OnyxCeleryQueues.MONITORING: "monitoring",
|
||||
OnyxCeleryQueues.SANDBOX: "sandbox",
|
||||
OnyxCeleryQueues.OPENSEARCH_MIGRATION: "opensearch_migration",
|
||||
}
|
||||
|
||||
# Queues where prefetched (unacked) task counts are meaningful
|
||||
_UNACKED_QUEUES: list[str] = [
|
||||
OnyxCeleryQueues.CONNECTOR_DOC_FETCHING,
|
||||
OnyxCeleryQueues.DOCPROCESSING,
|
||||
]
|
||||
|
||||
|
||||
class _CachedCollector(Collector):
|
||||
"""Base collector with TTL-based caching.
|
||||
|
||||
Subclasses implement ``_collect_fresh()`` to query the actual data source.
|
||||
The base ``collect()`` returns cached results if the TTL hasn't expired,
|
||||
avoiding repeated queries when Prometheus scrapes frequently.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
self._cache_ttl = cache_ttl
|
||||
self._cached_result: list[GaugeMetricFamily] | None = None
|
||||
self._last_collect_time: float = 0.0
|
||||
self._lock = threading.Lock()
|
||||
|
||||
def collect(self) -> list[GaugeMetricFamily]:
|
||||
with self._lock:
|
||||
now = time.monotonic()
|
||||
if (
|
||||
now - self._last_collect_time < self._cache_ttl
|
||||
and self._cached_result is not None
|
||||
):
|
||||
return self._cached_result
|
||||
|
||||
try:
|
||||
result = self._collect_fresh()
|
||||
self._cached_result = result
|
||||
self._last_collect_time = now
|
||||
return result
|
||||
except Exception:
|
||||
logger.exception(f"Error in {type(self).__name__}.collect()")
|
||||
# Return stale cache on error rather than nothing — avoids
|
||||
# metrics disappearing during transient failures.
|
||||
return self._cached_result if self._cached_result is not None else []
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
raise NotImplementedError
|
||||
|
||||
def describe(self) -> list[GaugeMetricFamily]:
|
||||
return []
|
||||
|
||||
|
||||
class QueueDepthCollector(_CachedCollector):
|
||||
"""Reads Celery queue lengths from the broker Redis on each scrape.
|
||||
|
||||
Uses a Redis client factory (callable) rather than a stored client
|
||||
reference so the connection is always fresh from Celery's pool.
|
||||
"""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
|
||||
depth = GaugeMetricFamily(
|
||||
"onyx_queue_depth",
|
||||
"Number of tasks waiting in Celery queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
unacked = GaugeMetricFamily(
|
||||
"onyx_queue_unacked",
|
||||
"Number of prefetched (unacked) tasks for queue",
|
||||
labels=["queue"],
|
||||
)
|
||||
queue_age = GaugeMetricFamily(
|
||||
"onyx_queue_oldest_task_age_seconds",
|
||||
"Age of the oldest task in the queue (seconds since enqueue)",
|
||||
labels=["queue"],
|
||||
)
|
||||
|
||||
now = time.time()
|
||||
|
||||
for queue_name, label in _QUEUE_LABEL_MAP.items():
|
||||
length = celery_get_queue_length(queue_name, redis_client)
|
||||
depth.add_metric([label], length)
|
||||
|
||||
# Peek at the oldest message to get its age
|
||||
if length > 0:
|
||||
age = self._get_oldest_message_age(redis_client, queue_name, now)
|
||||
if age is not None:
|
||||
queue_age.add_metric([label], age)
|
||||
|
||||
for queue_name in _UNACKED_QUEUES:
|
||||
label = _QUEUE_LABEL_MAP[queue_name]
|
||||
task_ids = celery_get_unacked_task_ids(queue_name, redis_client)
|
||||
unacked.add_metric([label], len(task_ids))
|
||||
|
||||
return [depth, unacked, queue_age]
|
||||
|
||||
@staticmethod
|
||||
def _get_oldest_message_age(
|
||||
redis_client: Redis, queue_name: str, now: float
|
||||
) -> float | None:
|
||||
"""Peek at the oldest (tail) message in a Redis list queue
|
||||
and extract its timestamp to compute age.
|
||||
|
||||
Note: If the Celery message contains neither ``properties.timestamp``
|
||||
nor ``headers.timestamp``, no age metric is emitted for this queue.
|
||||
This can happen with custom task producers or non-standard Celery
|
||||
protocol versions. The metric will simply be absent rather than
|
||||
inaccurate, which is the safest behavior for alerting.
|
||||
"""
|
||||
try:
|
||||
raw: bytes | str | None = redis_client.lindex(queue_name, -1) # type: ignore[assignment]
|
||||
if raw is None:
|
||||
return None
|
||||
msg = json.loads(raw)
|
||||
# Check for ETA tasks first — they are intentionally delayed,
|
||||
# so reporting their queue age would be misleading.
|
||||
headers = msg.get("headers", {})
|
||||
if headers.get("eta") is not None:
|
||||
return None
|
||||
# Celery v2 protocol: timestamp in properties
|
||||
props = msg.get("properties", {})
|
||||
ts = props.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
# Fallback: some Celery configurations place the timestamp in
|
||||
# headers instead of properties.
|
||||
ts = headers.get("timestamp")
|
||||
if ts is not None:
|
||||
return now - float(ts)
|
||||
except Exception:
|
||||
pass
|
||||
return None
|
||||
|
||||
|
||||
class IndexAttemptCollector(_CachedCollector):
|
||||
"""Queries Postgres for index attempt state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
self._terminal_statuses: list = []
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
self._terminal_statuses = [s for s in IndexingStatus if s.is_terminal()]
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_active_index_attempts_for_metrics
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
attempts_gauge = GaugeMetricFamily(
|
||||
"onyx_index_attempts_active",
|
||||
"Number of non-terminal index attempts",
|
||||
labels=[
|
||||
"status",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"connector_name",
|
||||
"cc_pair_id",
|
||||
],
|
||||
)
|
||||
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
rows = get_active_index_attempts_for_metrics(session)
|
||||
|
||||
for status, source, cc_id, cc_name, count in rows:
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
attempts_gauge.add_metric(
|
||||
[
|
||||
status.value,
|
||||
source.value,
|
||||
tid,
|
||||
name_val,
|
||||
str(cc_id),
|
||||
],
|
||||
count,
|
||||
)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [attempts_gauge]
|
||||
|
||||
|
||||
class ConnectorHealthCollector(_CachedCollector):
|
||||
"""Queries Postgres for connector health state on each scrape."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._configured: bool = False
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Call once DB engine is initialized."""
|
||||
self._configured = True
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if not self._configured:
|
||||
return []
|
||||
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_health_for_metrics,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.index_attempt import get_docs_indexed_by_cc_pair
|
||||
from onyx.db.index_attempt import get_failed_attempt_counts_by_cc_pair
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
staleness_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"Seconds since last successful index for this connector",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
error_state_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_in_error_state",
|
||||
"Whether the connector is in a repeated error state (1=yes, 0=no)",
|
||||
labels=["tenant_id", "source", "cc_pair_id", "connector_name"],
|
||||
)
|
||||
by_status_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_by_status",
|
||||
"Number of connectors grouped by status",
|
||||
labels=["tenant_id", "status"],
|
||||
)
|
||||
error_total_gauge = GaugeMetricFamily(
|
||||
"onyx_connectors_in_error_total",
|
||||
"Total number of connectors in repeated error state",
|
||||
labels=["tenant_id"],
|
||||
)
|
||||
per_connector_labels = [
|
||||
"tenant_id",
|
||||
"source",
|
||||
"cc_pair_id",
|
||||
"connector_name",
|
||||
]
|
||||
docs_success_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_docs_indexed",
|
||||
"Total new documents indexed (90-day rolling sum) per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
docs_error_gauge = GaugeMetricFamily(
|
||||
"onyx_connector_error_count",
|
||||
"Total number of failed index attempts per connector",
|
||||
labels=per_connector_labels,
|
||||
)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
tenant_ids = get_all_tenant_ids()
|
||||
|
||||
for tid in tenant_ids:
|
||||
# Defensive guard — get_all_tenant_ids() should never yield None,
|
||||
# but we guard here for API stability in case the contract changes.
|
||||
if tid is None:
|
||||
continue
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tid)
|
||||
try:
|
||||
with get_session_with_current_tenant() as session:
|
||||
pairs = get_connector_health_for_metrics(session)
|
||||
error_counts_by_cc = get_failed_attempt_counts_by_cc_pair(session)
|
||||
docs_by_cc = get_docs_indexed_by_cc_pair(session)
|
||||
|
||||
status_counts: dict[str, int] = {}
|
||||
error_count = 0
|
||||
|
||||
for (
|
||||
cc_id,
|
||||
status,
|
||||
in_error,
|
||||
last_success,
|
||||
cc_name,
|
||||
source,
|
||||
) in pairs:
|
||||
cc_id_str = str(cc_id)
|
||||
source_val = source.value
|
||||
name_val = cc_name or f"cc_pair_{cc_id}"
|
||||
label_vals = [tid, source_val, cc_id_str, name_val]
|
||||
|
||||
if last_success is not None:
|
||||
# Both `now` and `last_success` are timezone-aware
|
||||
# (the DB column uses DateTime(timezone=True)),
|
||||
# so subtraction is safe.
|
||||
age = (now - last_success).total_seconds()
|
||||
staleness_gauge.add_metric(label_vals, age)
|
||||
|
||||
error_state_gauge.add_metric(
|
||||
label_vals,
|
||||
1.0 if in_error else 0.0,
|
||||
)
|
||||
if in_error:
|
||||
error_count += 1
|
||||
|
||||
docs_success_gauge.add_metric(
|
||||
label_vals,
|
||||
docs_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
docs_error_gauge.add_metric(
|
||||
label_vals,
|
||||
error_counts_by_cc.get(cc_id, 0),
|
||||
)
|
||||
|
||||
status_val = status.value
|
||||
status_counts[status_val] = status_counts.get(status_val, 0) + 1
|
||||
|
||||
for status_val, count in status_counts.items():
|
||||
by_status_gauge.add_metric([tid, status_val], count)
|
||||
|
||||
error_total_gauge.add_metric([tid], error_count)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
return [
|
||||
staleness_gauge,
|
||||
error_state_gauge,
|
||||
by_status_gauge,
|
||||
error_total_gauge,
|
||||
docs_success_gauge,
|
||||
docs_error_gauge,
|
||||
]
|
||||
|
||||
|
||||
class RedisHealthCollector(_CachedCollector):
|
||||
"""Collects Redis server health metrics (memory, clients, etc.)."""
|
||||
|
||||
def __init__(self, cache_ttl: float = _DEFAULT_CACHE_TTL) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._get_redis: Callable[[], Redis] | None = None
|
||||
|
||||
def set_redis_factory(self, factory: Callable[[], Redis]) -> None:
|
||||
"""Set a callable that returns a broker Redis client on demand."""
|
||||
self._get_redis = factory
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._get_redis is None:
|
||||
return []
|
||||
|
||||
redis_client = self._get_redis()
|
||||
|
||||
memory_used = GaugeMetricFamily(
|
||||
"onyx_redis_memory_used_bytes",
|
||||
"Redis used memory in bytes",
|
||||
)
|
||||
memory_peak = GaugeMetricFamily(
|
||||
"onyx_redis_memory_peak_bytes",
|
||||
"Redis peak used memory in bytes",
|
||||
)
|
||||
memory_frag = GaugeMetricFamily(
|
||||
"onyx_redis_memory_fragmentation_ratio",
|
||||
"Redis memory fragmentation ratio (>1.5 indicates fragmentation)",
|
||||
)
|
||||
connected_clients = GaugeMetricFamily(
|
||||
"onyx_redis_connected_clients",
|
||||
"Number of connected Redis clients",
|
||||
)
|
||||
|
||||
try:
|
||||
mem_info: dict = redis_client.info("memory") # type: ignore[assignment]
|
||||
memory_used.add_metric([], mem_info.get("used_memory", 0))
|
||||
memory_peak.add_metric([], mem_info.get("used_memory_peak", 0))
|
||||
frag = mem_info.get("mem_fragmentation_ratio")
|
||||
if frag is not None:
|
||||
memory_frag.add_metric([], frag)
|
||||
|
||||
client_info: dict = redis_client.info("clients") # type: ignore[assignment]
|
||||
connected_clients.add_metric([], client_info.get("connected_clients", 0))
|
||||
except Exception:
|
||||
logger.debug("Failed to collect Redis health metrics", exc_info=True)
|
||||
|
||||
return [memory_used, memory_peak, memory_frag, connected_clients]
|
||||
|
||||
|
||||
class WorkerHealthCollector(_CachedCollector):
|
||||
"""Collects Celery worker count and process count via inspect ping.
|
||||
|
||||
Uses a longer cache TTL (60s) since inspect.ping() is a broadcast
|
||||
command that takes a couple seconds to complete.
|
||||
|
||||
Maintains a set of known worker short-names so that when a worker
|
||||
stops responding, we emit ``up=0`` instead of silently dropping the
|
||||
metric (which would make ``absent()``-style alerts impossible).
|
||||
"""
|
||||
|
||||
# Remove a worker from _known_workers after this many consecutive
|
||||
# missed pings (at 60s TTL ≈ 10 minutes of being unreachable).
|
||||
_MAX_CONSECUTIVE_MISSES = 10
|
||||
|
||||
def __init__(self, cache_ttl: float = 60.0) -> None:
|
||||
super().__init__(cache_ttl)
|
||||
self._celery_app: Any | None = None
|
||||
# worker short-name → consecutive miss count.
|
||||
# Workers start at 0 and reset to 0 each time they respond.
|
||||
# Removed after _MAX_CONSECUTIVE_MISSES missed collects.
|
||||
self._known_workers: dict[str, int] = {}
|
||||
|
||||
def set_celery_app(self, app: Any) -> None:
|
||||
"""Set the Celery app instance for inspect commands."""
|
||||
self._celery_app = app
|
||||
|
||||
def _collect_fresh(self) -> list[GaugeMetricFamily]:
|
||||
if self._celery_app is None:
|
||||
return []
|
||||
|
||||
active_workers = GaugeMetricFamily(
|
||||
"onyx_celery_active_worker_count",
|
||||
"Number of active Celery workers responding to ping",
|
||||
)
|
||||
worker_up = GaugeMetricFamily(
|
||||
"onyx_celery_worker_up",
|
||||
"Whether a specific Celery worker is alive (1=up, 0=down)",
|
||||
labels=["worker"],
|
||||
)
|
||||
|
||||
try:
|
||||
inspector = self._celery_app.control.inspect(timeout=3.0)
|
||||
ping_result = inspector.ping()
|
||||
|
||||
responding: set[str] = set()
|
||||
if ping_result:
|
||||
active_workers.add_metric([], len(ping_result))
|
||||
for worker_name in ping_result:
|
||||
# Strip hostname suffix for cleaner labels
|
||||
short_name = worker_name.split("@")[0]
|
||||
responding.add(short_name)
|
||||
else:
|
||||
active_workers.add_metric([], 0)
|
||||
|
||||
# Register newly-seen workers and reset miss count for
|
||||
# workers that responded.
|
||||
for short_name in responding:
|
||||
self._known_workers[short_name] = 0
|
||||
|
||||
# Increment miss count for non-responding workers and evict
|
||||
# those that have been missing too long.
|
||||
stale = []
|
||||
for short_name in list(self._known_workers):
|
||||
if short_name not in responding:
|
||||
self._known_workers[short_name] += 1
|
||||
if self._known_workers[short_name] >= self._MAX_CONSECUTIVE_MISSES:
|
||||
stale.append(short_name)
|
||||
for short_name in stale:
|
||||
del self._known_workers[short_name]
|
||||
|
||||
for short_name in sorted(self._known_workers):
|
||||
worker_up.add_metric([short_name], 1 if short_name in responding else 0)
|
||||
except Exception:
|
||||
logger.debug("Failed to collect worker health metrics", exc_info=True)
|
||||
|
||||
return [active_workers, worker_up]
|
||||
113
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
113
backend/onyx/server/metrics/indexing_pipeline_setup.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Setup function for indexing pipeline Prometheus collectors.
|
||||
|
||||
Called once by the monitoring celery worker after Redis and DB are ready.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from celery import Celery
|
||||
from prometheus_client.registry import REGISTRY
|
||||
from redis import Redis
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import RedisHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import WorkerHealthCollector
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Module-level singletons — these are lightweight objects (no connections or DB
|
||||
# state) until configure() / set_redis_factory() is called. Keeping them at
|
||||
# module level ensures they survive the lifetime of the worker process and are
|
||||
# only registered with the Prometheus registry once.
|
||||
_queue_collector = QueueDepthCollector()
|
||||
_attempt_collector = IndexAttemptCollector()
|
||||
_connector_collector = ConnectorHealthCollector()
|
||||
_redis_health_collector = RedisHealthCollector()
|
||||
_worker_health_collector = WorkerHealthCollector()
|
||||
|
||||
|
||||
def _make_broker_redis_factory(celery_app: Celery) -> Callable[[], Redis]:
|
||||
"""Create a factory that returns a cached broker Redis client.
|
||||
|
||||
Reuses a single connection across scrapes to avoid leaking connections.
|
||||
Reconnects automatically if the cached connection becomes stale.
|
||||
"""
|
||||
_cached_client: list[Redis | None] = [None]
|
||||
# Keep a reference to the Kombu Connection so we can close it on
|
||||
# reconnect (the raw Redis client outlives the Kombu wrapper).
|
||||
_cached_kombu_conn: list[Any] = [None]
|
||||
|
||||
def _close_client(client: Redis) -> None:
|
||||
"""Best-effort close of a Redis client."""
|
||||
try:
|
||||
client.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close stale Redis client", exc_info=True)
|
||||
|
||||
def _close_kombu_conn() -> None:
|
||||
"""Best-effort close of the cached Kombu Connection."""
|
||||
conn = _cached_kombu_conn[0]
|
||||
if conn is not None:
|
||||
try:
|
||||
conn.close()
|
||||
except Exception:
|
||||
logger.debug("Failed to close Kombu connection", exc_info=True)
|
||||
_cached_kombu_conn[0] = None
|
||||
|
||||
def _get_broker_redis() -> Redis:
|
||||
client = _cached_client[0]
|
||||
if client is not None:
|
||||
try:
|
||||
client.ping()
|
||||
return client
|
||||
except Exception:
|
||||
logger.debug("Cached Redis client stale, reconnecting")
|
||||
_close_client(client)
|
||||
_cached_client[0] = None
|
||||
_close_kombu_conn()
|
||||
|
||||
# Get a fresh Redis client from the broker connection.
|
||||
# We hold this client long-term (cached above) rather than using a
|
||||
# context manager, because we need it to persist across scrapes.
|
||||
# The caching logic above ensures we only ever hold one connection,
|
||||
# and we close it explicitly on reconnect.
|
||||
conn = celery_app.broker_connection()
|
||||
# kombu's Channel exposes .client at runtime (the underlying Redis
|
||||
# client) but the type stubs don't declare it.
|
||||
new_client: Redis = conn.channel().client # type: ignore[attr-defined]
|
||||
_cached_client[0] = new_client
|
||||
_cached_kombu_conn[0] = conn
|
||||
return new_client
|
||||
|
||||
return _get_broker_redis
|
||||
|
||||
|
||||
def setup_indexing_pipeline_metrics(celery_app: Celery) -> None:
|
||||
"""Register all indexing pipeline collectors with the default registry.
|
||||
|
||||
Args:
|
||||
celery_app: The Celery application instance. Used to obtain a fresh
|
||||
broker Redis client on each scrape for queue depth metrics.
|
||||
"""
|
||||
redis_factory = _make_broker_redis_factory(celery_app)
|
||||
_queue_collector.set_redis_factory(redis_factory)
|
||||
_redis_health_collector.set_redis_factory(redis_factory)
|
||||
_worker_health_collector.set_celery_app(celery_app)
|
||||
_attempt_collector.configure()
|
||||
_connector_collector.configure()
|
||||
|
||||
for collector in (
|
||||
_queue_collector,
|
||||
_attempt_collector,
|
||||
_connector_collector,
|
||||
_redis_health_collector,
|
||||
_worker_health_collector,
|
||||
):
|
||||
try:
|
||||
REGISTRY.register(collector)
|
||||
except ValueError:
|
||||
logger.debug("Collector already registered: %s", type(collector).__name__)
|
||||
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
253
backend/onyx/server/metrics/indexing_task_metrics.py
Normal file
@@ -0,0 +1,253 @@
|
||||
"""Per-connector Prometheus metrics for indexing tasks.
|
||||
|
||||
Enriches the two primary indexing tasks (docfetching_proxy_task and
|
||||
docprocessing_task) with connector-level labels: source, tenant_id,
|
||||
and cc_pair_id.
|
||||
|
||||
Note: connector_name is intentionally excluded from push-based per-task
|
||||
counters because it is a user-defined free-form string that can create
|
||||
unbounded cardinality. The pull-based collectors on the monitoring worker
|
||||
(see indexing_pipeline.py) include connector_name since they have bounded
|
||||
cardinality (one series per connector, not per task execution).
|
||||
|
||||
Uses an in-memory cache for cc_pair_id → (source, name) lookups.
|
||||
Connectors never change source type, and names change rarely, so the
|
||||
cache is safe to hold for the worker's lifetime.
|
||||
|
||||
Usage in a worker app module:
|
||||
from onyx.server.metrics.indexing_task_metrics import (
|
||||
on_indexing_task_prerun,
|
||||
on_indexing_task_postrun,
|
||||
)
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
|
||||
from celery import Task
|
||||
from prometheus_client import Counter
|
||||
from prometheus_client import Histogram
|
||||
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.server.metrics.celery_task_metrics import _MAX_START_TIME_AGE_SECONDS
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ConnectorInfo:
|
||||
"""Cached connector metadata for metric labels."""
|
||||
|
||||
source: str
|
||||
name: str
|
||||
|
||||
|
||||
_UNKNOWN_CONNECTOR = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
# (tenant_id, cc_pair_id) → ConnectorInfo (populated on first encounter).
|
||||
# Keyed by tenant to avoid cross-tenant cache poisoning in multi-tenant
|
||||
# deployments where different tenants can share the same cc_pair_id value.
|
||||
_connector_cache: dict[tuple[str, int], ConnectorInfo] = {}
|
||||
|
||||
# Lock protecting _connector_cache — multiple thread-pool workers may
|
||||
# resolve connectors concurrently.
|
||||
_connector_cache_lock = threading.Lock()
|
||||
|
||||
# Only enrich these task types with per-connector labels
|
||||
_INDEXING_TASK_NAMES: frozenset[str] = frozenset(
|
||||
{
|
||||
OnyxCeleryTask.CONNECTOR_DOC_FETCHING_TASK,
|
||||
OnyxCeleryTask.DOCPROCESSING_TASK,
|
||||
}
|
||||
)
|
||||
|
||||
# connector_name is intentionally excluded — see module docstring.
|
||||
INDEXING_TASK_STARTED = Counter(
|
||||
"onyx_indexing_task_started_total",
|
||||
"Indexing tasks started per connector",
|
||||
["task_name", "source", "tenant_id", "cc_pair_id"],
|
||||
)
|
||||
|
||||
INDEXING_TASK_COMPLETED = Counter(
|
||||
"onyx_indexing_task_completed_total",
|
||||
"Indexing tasks completed per connector",
|
||||
[
|
||||
"task_name",
|
||||
"source",
|
||||
"tenant_id",
|
||||
"cc_pair_id",
|
||||
"outcome",
|
||||
],
|
||||
)
|
||||
|
||||
INDEXING_TASK_DURATION = Histogram(
|
||||
"onyx_indexing_task_duration_seconds",
|
||||
"Indexing task duration by connector type",
|
||||
["task_name", "source", "tenant_id"],
|
||||
buckets=[1, 5, 15, 30, 60, 120, 300, 600, 1800, 3600],
|
||||
)
|
||||
|
||||
# task_id → monotonic start time (for indexing tasks only)
|
||||
_indexing_start_times: dict[str, float] = {}
|
||||
|
||||
# Lock protecting _indexing_start_times — prerun, postrun, and eviction may
|
||||
# run concurrently on thread-pool workers.
|
||||
_indexing_start_times_lock = threading.Lock()
|
||||
|
||||
|
||||
def _evict_stale_start_times() -> None:
|
||||
"""Remove _indexing_start_times entries older than _MAX_START_TIME_AGE_SECONDS.
|
||||
|
||||
Must be called while holding _indexing_start_times_lock.
|
||||
"""
|
||||
now = time.monotonic()
|
||||
stale_ids = [
|
||||
tid
|
||||
for tid, start in _indexing_start_times.items()
|
||||
if now - start > _MAX_START_TIME_AGE_SECONDS
|
||||
]
|
||||
for tid in stale_ids:
|
||||
_indexing_start_times.pop(tid, None)
|
||||
|
||||
|
||||
def _resolve_connector(cc_pair_id: int) -> ConnectorInfo:
|
||||
"""Resolve cc_pair_id to ConnectorInfo, using cache when possible.
|
||||
|
||||
On cache miss, does a single DB query with eager connector load.
|
||||
On any failure, returns _UNKNOWN_CONNECTOR without caching, so that
|
||||
subsequent calls can retry the lookup once the DB is available.
|
||||
|
||||
Note on tenant_id source: we read CURRENT_TENANT_ID_CONTEXTVAR for the
|
||||
cache key. The Celery tenant-aware middleware sets this contextvar before
|
||||
task execution, and it always matches kwargs["tenant_id"] (which is set
|
||||
at task dispatch time). They are guaranteed to agree for a given task
|
||||
execution context.
|
||||
"""
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get("") or ""
|
||||
cache_key = (tenant_id, cc_pair_id)
|
||||
|
||||
with _connector_cache_lock:
|
||||
cached = _connector_cache.get(cache_key)
|
||||
if cached is not None:
|
||||
return cached
|
||||
|
||||
try:
|
||||
from onyx.db.connector_credential_pair import (
|
||||
get_connector_credential_pair_from_id,
|
||||
)
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session,
|
||||
cc_pair_id,
|
||||
eager_load_connector=True,
|
||||
)
|
||||
if cc_pair is None:
|
||||
# DB lookup succeeded but cc_pair doesn't exist — don't cache,
|
||||
# it may appear later (race with connector creation).
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
info = ConnectorInfo(
|
||||
source=cc_pair.connector.source.value,
|
||||
name=cc_pair.name,
|
||||
)
|
||||
with _connector_cache_lock:
|
||||
_connector_cache[cache_key] = info
|
||||
return info
|
||||
except Exception:
|
||||
logger.debug(
|
||||
f"Failed to resolve connector info for cc_pair_id={cc_pair_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
return _UNKNOWN_CONNECTOR
|
||||
|
||||
|
||||
def on_indexing_task_prerun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
) -> None:
|
||||
"""Record per-connector metrics at task start.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES. Silently returns for
|
||||
all other tasks.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
|
||||
INDEXING_TASK_STARTED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
_evict_stale_start_times()
|
||||
_indexing_start_times[task_id] = time.monotonic()
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task prerun metrics", exc_info=True)
|
||||
|
||||
|
||||
def on_indexing_task_postrun(
|
||||
task_id: str | None,
|
||||
task: Task | None,
|
||||
kwargs: dict | None,
|
||||
state: str | None,
|
||||
) -> None:
|
||||
"""Record per-connector completion metrics.
|
||||
|
||||
Only fires for tasks in _INDEXING_TASK_NAMES.
|
||||
"""
|
||||
if task is None or task_id is None or kwargs is None:
|
||||
return
|
||||
|
||||
task_name = task.name or ""
|
||||
if task_name not in _INDEXING_TASK_NAMES:
|
||||
return
|
||||
|
||||
try:
|
||||
cc_pair_id = kwargs.get("cc_pair_id")
|
||||
tenant_id = str(kwargs.get("tenant_id", "unknown"))
|
||||
|
||||
if cc_pair_id is None:
|
||||
return
|
||||
|
||||
info = _resolve_connector(cc_pair_id)
|
||||
outcome = "success" if state == "SUCCESS" else "failure"
|
||||
|
||||
INDEXING_TASK_COMPLETED.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
cc_pair_id=str(cc_pair_id),
|
||||
outcome=outcome,
|
||||
).inc()
|
||||
|
||||
with _indexing_start_times_lock:
|
||||
start = _indexing_start_times.pop(task_id, None)
|
||||
if start is not None:
|
||||
INDEXING_TASK_DURATION.labels(
|
||||
task_name=task_name,
|
||||
source=info.source,
|
||||
tenant_id=tenant_id,
|
||||
).observe(time.monotonic() - start)
|
||||
except Exception:
|
||||
logger.debug("Failed to record indexing task postrun metrics", exc_info=True)
|
||||
89
backend/onyx/server/metrics/metrics_server.py
Normal file
89
backend/onyx/server/metrics/metrics_server.py
Normal file
@@ -0,0 +1,89 @@
|
||||
"""Standalone Prometheus metrics HTTP server for non-API processes.
|
||||
|
||||
The FastAPI API server already exposes /metrics via prometheus-fastapi-instrumentator.
|
||||
Celery workers and other background processes use this module to expose their
|
||||
own /metrics endpoint on a configurable port.
|
||||
|
||||
Usage:
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
start_metrics_server("monitoring") # reads port from env or uses default
|
||||
"""
|
||||
|
||||
import os
|
||||
import threading
|
||||
|
||||
from prometheus_client import start_http_server
|
||||
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Default ports for worker types that serve custom Prometheus metrics.
|
||||
# Only add entries here when a worker actually registers collectors.
|
||||
# In k8s each worker type runs in its own pod, so PROMETHEUS_METRICS_PORT
|
||||
# env var can override.
|
||||
_DEFAULT_PORTS: dict[str, int] = {
|
||||
"monitoring": 9096,
|
||||
"docfetching": 9092,
|
||||
"docprocessing": 9093,
|
||||
}
|
||||
|
||||
_server_started = False
|
||||
_server_lock = threading.Lock()
|
||||
|
||||
|
||||
def start_metrics_server(worker_type: str) -> int | None:
|
||||
"""Start a Prometheus metrics HTTP server in a background thread.
|
||||
|
||||
Returns the port if started, None if disabled or already started.
|
||||
|
||||
Port resolution order:
|
||||
1. PROMETHEUS_METRICS_PORT env var (explicit override)
|
||||
2. Default port for the worker type
|
||||
3. If worker type is unknown and no env var, skip
|
||||
|
||||
Set PROMETHEUS_METRICS_ENABLED=false to disable.
|
||||
"""
|
||||
global _server_started
|
||||
|
||||
with _server_lock:
|
||||
if _server_started:
|
||||
logger.debug(f"Metrics server already started for {worker_type}")
|
||||
return None
|
||||
|
||||
enabled = os.environ.get("PROMETHEUS_METRICS_ENABLED", "true").lower()
|
||||
if enabled in ("false", "0", "no"):
|
||||
logger.info(f"Prometheus metrics server disabled for {worker_type}")
|
||||
return None
|
||||
|
||||
port_str = os.environ.get("PROMETHEUS_METRICS_PORT")
|
||||
if port_str:
|
||||
try:
|
||||
port = int(port_str)
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid PROMETHEUS_METRICS_PORT '{port_str}' for {worker_type}, "
|
||||
"must be a numeric port. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
elif worker_type in _DEFAULT_PORTS:
|
||||
port = _DEFAULT_PORTS[worker_type]
|
||||
else:
|
||||
logger.info(
|
||||
f"No default metrics port for worker type '{worker_type}' "
|
||||
"and PROMETHEUS_METRICS_PORT not set. Skipping metrics server."
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
start_http_server(port)
|
||||
_server_started = True
|
||||
logger.info(
|
||||
f"Prometheus metrics server started on :{port} for {worker_type}"
|
||||
)
|
||||
return port
|
||||
except OSError as e:
|
||||
logger.warning(
|
||||
f"Failed to start metrics server on :{port} for {worker_type}: {e}"
|
||||
)
|
||||
return None
|
||||
@@ -736,7 +736,7 @@ if __name__ == "__main__":
|
||||
llm.config.model_name, llm.config.model_provider
|
||||
)
|
||||
|
||||
persona = get_default_behavior_persona(db_session)
|
||||
persona = get_default_behavior_persona(db_session, eager_load_for_tools=True)
|
||||
if persona is None:
|
||||
raise ValueError("No default persona found")
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from onyx.chat.emitter import Emitter
|
||||
from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.context.search.models import BaseFilters
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.db.enums import MCPAuthenticationPerformer
|
||||
from onyx.db.enums import MCPAuthenticationType
|
||||
from onyx.db.mcp import get_all_mcp_tools_for_server
|
||||
@@ -124,7 +125,12 @@ def construct_tools(
|
||||
) -> dict[int, list[Tool]]:
|
||||
"""Constructs tools based on persona configuration and available APIs.
|
||||
|
||||
Will simply skip tools that are not allowed/available."""
|
||||
Will simply skip tools that are not allowed/available.
|
||||
|
||||
Callers must supply a persona with ``tools``, ``document_sets``,
|
||||
``attached_documents``, and ``hierarchy_nodes`` already eager-loaded
|
||||
(e.g. via ``eager_load_persona=True`` or ``eager_load_for_tools=True``)
|
||||
to avoid lazy SQL queries after the session may have been flushed."""
|
||||
tool_dict: dict[int, list[Tool]] = {}
|
||||
|
||||
# Log which tools are attached to the persona for debugging
|
||||
@@ -143,6 +149,28 @@ def construct_tools(
|
||||
# This flow is for search so we do not get all indices.
|
||||
document_index = get_default_document_index(search_settings, None, db_session)
|
||||
|
||||
def _build_search_tool(tool_id: int, config: SearchToolConfig) -> SearchTool:
|
||||
persona_search_info = PersonaSearchInfo(
|
||||
document_set_names=[ds.name for ds in persona.document_sets],
|
||||
search_start_date=persona.search_start_date,
|
||||
attached_document_ids=[doc.id for doc in persona.attached_documents],
|
||||
hierarchy_node_ids=[node.id for node in persona.hierarchy_nodes],
|
||||
)
|
||||
return SearchTool(
|
||||
tool_id=tool_id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona_search_info=persona_search_info,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=config.user_selected_filters,
|
||||
project_id_filter=config.project_id_filter,
|
||||
persona_id_filter=config.persona_id_filter,
|
||||
bypass_acl=config.bypass_acl,
|
||||
slack_context=config.slack_context,
|
||||
enable_slack_search=config.enable_slack_search,
|
||||
)
|
||||
|
||||
added_search_tool = False
|
||||
for db_tool_model in persona.tools:
|
||||
# If allowed_tool_ids is specified, skip tools not in the allowed list
|
||||
@@ -176,22 +204,9 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[db_tool_model.id] = [search_tool]
|
||||
tool_dict[db_tool_model.id] = [
|
||||
_build_search_tool(db_tool_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Handle Image Generation Tool
|
||||
elif tool_cls.__name__ == ImageGenerationTool.__name__:
|
||||
@@ -421,26 +436,12 @@ def construct_tools(
|
||||
# Get the database tool model for SearchTool
|
||||
search_tool_db_model = get_builtin_tool(db_session, SearchTool)
|
||||
|
||||
# Use the passed-in config if available, otherwise create a new one
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
llm=llm,
|
||||
document_index=document_index,
|
||||
user_selected_filters=search_tool_config.user_selected_filters,
|
||||
project_id_filter=search_tool_config.project_id_filter,
|
||||
persona_id_filter=search_tool_config.persona_id_filter,
|
||||
bypass_acl=search_tool_config.bypass_acl,
|
||||
slack_context=search_tool_config.slack_context,
|
||||
enable_slack_search=search_tool_config.enable_slack_search,
|
||||
)
|
||||
|
||||
tool_dict[search_tool_db_model.id] = [search_tool]
|
||||
tool_dict[search_tool_db_model.id] = [
|
||||
_build_search_tool(search_tool_db_model.id, search_tool_config)
|
||||
]
|
||||
|
||||
# Always inject MemoryTool when the user has the memory tool enabled,
|
||||
# bypassing persona tool associations and allowed_tool_ids filtering
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.context.search.pipeline import merge_individual_chunks
|
||||
from onyx.context.search.pipeline import search_pipeline
|
||||
@@ -65,7 +66,6 @@ from onyx.db.federated import (
|
||||
get_federated_connector_document_set_mappings_by_document_set_names,
|
||||
)
|
||||
from onyx.db.federated import list_federated_connector_oauth_tokens
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
@@ -238,8 +238,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
emitter: Emitter,
|
||||
# Used for ACLs and federated search, anonymous users only see public docs
|
||||
user: User,
|
||||
# Used for filter settings
|
||||
persona: Persona,
|
||||
# Pre-extracted persona search configuration
|
||||
persona_search_info: PersonaSearchInfo,
|
||||
llm: LLM,
|
||||
document_index: DocumentIndex,
|
||||
# Respecting user selections
|
||||
@@ -258,7 +258,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
super().__init__(emitter=emitter)
|
||||
|
||||
self.user = user
|
||||
self.persona = persona
|
||||
self.persona_search_info = persona_search_info
|
||||
self.llm = llm
|
||||
self.document_index = document_index
|
||||
self.user_selected_filters = user_selected_filters
|
||||
@@ -289,7 +289,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# Case 1: Slack bot context — requires a Slack federated connector
|
||||
# linked via the persona's document sets
|
||||
if self.slack_context:
|
||||
document_set_names = [ds.name for ds in self.persona.document_sets]
|
||||
document_set_names = self.persona_search_info.document_set_names
|
||||
if not document_set_names:
|
||||
logger.debug(
|
||||
"Skipping Slack federated search: no document sets on persona"
|
||||
@@ -463,7 +463,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
persona_id_filter=self.persona_id_filter,
|
||||
document_index=self.document_index,
|
||||
user=self.user,
|
||||
persona=self.persona,
|
||||
persona_search_info=self.persona_search_info,
|
||||
acl_filters=acl_filters,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=federated_retrieval_infos,
|
||||
@@ -587,15 +587,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
and self.user_selected_filters.source_type
|
||||
else None
|
||||
)
|
||||
persona_document_sets = (
|
||||
[ds.name for ds in self.persona.document_sets] if self.persona else None
|
||||
)
|
||||
federated_retrieval_infos = (
|
||||
get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=self.user.id if self.user else None,
|
||||
source_types=prefetch_source_types,
|
||||
document_set_names=persona_document_sets,
|
||||
document_set_names=self.persona_search_info.document_set_names,
|
||||
)
|
||||
or []
|
||||
)
|
||||
|
||||
@@ -43,5 +43,8 @@ def test_eager_load_persona_loads_relationships(db_session: Session) -> None:
|
||||
persona_unloaded = tmp.unloaded
|
||||
assert "tools" not in persona_unloaded
|
||||
assert "user_files" not in persona_unloaded
|
||||
assert "document_sets" not in persona_unloaded
|
||||
assert "attached_documents" not in persona_unloaded
|
||||
assert "hierarchy_nodes" not in persona_unloaded
|
||||
finally:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -11,8 +11,8 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import PersonaSearchInfo
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
@@ -139,7 +139,7 @@ def use_mock_search_pipeline(
|
||||
chunk_search_request: ChunkSearchRequest,
|
||||
document_index: DocumentIndex, # noqa: ARG001
|
||||
user: User | None, # noqa: ARG001
|
||||
persona: Persona | None, # noqa: ARG001
|
||||
persona_search_info: PersonaSearchInfo | None, # noqa: ARG001
|
||||
db_session: Session | None = None, # noqa: ARG001
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
|
||||
@@ -60,4 +60,4 @@ def test_non_rate_limit_error(mock_confluence_call: Mock) -> None:
|
||||
with pytest.raises(HTTPError):
|
||||
handled_call()
|
||||
|
||||
assert mock_confluence_call.call_count == 1
|
||||
assert mock_confluence_call.call_count == 5
|
||||
|
||||
153
backend/tests/unit/server/metrics/test_celery_task_metrics.py
Normal file
153
backend/tests/unit/server/metrics/test_celery_task_metrics.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Tests for generic Celery task lifecycle Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.celery_task_metrics import _task_start_times
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_postrun
|
||||
from onyx.server.metrics.celery_task_metrics import on_celery_task_prerun
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_COMPLETED
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_DURATION
|
||||
from onyx.server.metrics.celery_task_metrics import TASK_STARTED
|
||||
from onyx.server.metrics.celery_task_metrics import TASKS_ACTIVE
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_metrics() -> Iterator[None]:
|
||||
"""Clear metric state between tests."""
|
||||
_task_start_times.clear()
|
||||
yield
|
||||
_task_start_times.clear()
|
||||
|
||||
|
||||
def _make_task(name: str = "test_task", queue: str = "test_queue") -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
task.request = MagicMock()
|
||||
task.request.delivery_info = {"routing_key": queue}
|
||||
return task
|
||||
|
||||
|
||||
class TestCeleryTaskPrerun:
|
||||
def test_increments_started_and_active(self) -> None:
|
||||
task = _make_task()
|
||||
before_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
before_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
after_started = TASK_STARTED.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
after_active = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
assert after_started == before_started + 1
|
||||
assert after_active == before_active + 1
|
||||
|
||||
def test_records_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_prerun("task-1", None)
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_id_is_none(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun(None, task)
|
||||
# Should not crash
|
||||
|
||||
def test_handles_missing_delivery_info(self) -> None:
|
||||
task = _make_task()
|
||||
task.request.delivery_info = None
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
|
||||
class TestCeleryTaskPostrun:
|
||||
def test_increments_completed_success(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="success"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_increments_completed_failure(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "FAILURE")
|
||||
|
||||
after = TASK_COMPLETED.labels(
|
||||
task_name="test_task", queue="test_queue", outcome="failure"
|
||||
)._value.get()
|
||||
assert after == before + 1
|
||||
|
||||
def test_decrements_active(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
active_before = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
active_after = TASKS_ACTIVE.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._value.get()
|
||||
assert active_after == active_before - 1
|
||||
|
||||
def test_observes_duration(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
|
||||
before_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
|
||||
after_count = TASK_DURATION.labels(
|
||||
task_name="test_task", queue="test_queue"
|
||||
)._sum.get()
|
||||
# Duration should have increased (at least slightly)
|
||||
assert after_count > before_count
|
||||
|
||||
def test_cleans_up_start_time(self) -> None:
|
||||
task = _make_task()
|
||||
on_celery_task_prerun("task-1", task)
|
||||
assert "task-1" in _task_start_times
|
||||
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
assert "task-1" not in _task_start_times
|
||||
|
||||
def test_noop_when_task_is_none(self) -> None:
|
||||
on_celery_task_postrun("task-1", None, "SUCCESS")
|
||||
|
||||
def test_handles_missing_start_time(self) -> None:
|
||||
"""Postrun without prerun should not crash."""
|
||||
task = _make_task()
|
||||
on_celery_task_postrun("task-1", task, "SUCCESS")
|
||||
# Should not raise
|
||||
@@ -0,0 +1,359 @@
|
||||
"""Tests for indexing pipeline Prometheus collectors."""
|
||||
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline import ConnectorHealthCollector
|
||||
from onyx.server.metrics.indexing_pipeline import IndexAttemptCollector
|
||||
from onyx.server.metrics.indexing_pipeline import QueueDepthCollector
|
||||
|
||||
|
||||
class TestQueueDepthCollector:
|
||||
def test_returns_empty_when_factory_not_set(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = QueueDepthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
def test_collects_queue_depths(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value={"task-1", "task-2"},
|
||||
),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 3
|
||||
depth_family = families[0]
|
||||
unacked_family = families[1]
|
||||
age_family = families[2]
|
||||
|
||||
assert depth_family.name == "onyx_queue_depth"
|
||||
assert len(depth_family.samples) > 0
|
||||
for sample in depth_family.samples:
|
||||
assert sample.value == 5
|
||||
|
||||
assert unacked_family.name == "onyx_queue_unacked"
|
||||
unacked_labels = {s.labels["queue"] for s in unacked_family.samples}
|
||||
assert "docfetching" in unacked_labels
|
||||
assert "docprocessing" in unacked_labels
|
||||
|
||||
assert age_family.name == "onyx_queue_oldest_task_age_seconds"
|
||||
for sample in unacked_family.samples:
|
||||
assert sample.value == 2
|
||||
|
||||
def test_handles_redis_error_gracefully(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("connection lost"),
|
||||
):
|
||||
families = collector.collect()
|
||||
|
||||
# Returns stale cache (empty on first call)
|
||||
assert families == []
|
||||
|
||||
def test_caching_returns_stale_within_ttl(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=60)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=5,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
first = collector.collect()
|
||||
|
||||
# Second call within TTL should return cached result without calling Redis
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("should not be called"),
|
||||
):
|
||||
second = collector.collect()
|
||||
|
||||
assert first is second # Same object, from cache
|
||||
|
||||
def test_factory_called_each_scrape(self) -> None:
|
||||
"""Verify the Redis factory is called on each fresh collect, not cached."""
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
factory = MagicMock(return_value=MagicMock())
|
||||
collector.set_redis_factory(factory)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=0,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
collector.collect()
|
||||
collector.collect()
|
||||
|
||||
assert factory.call_count == 2
|
||||
|
||||
def test_error_returns_stale_cache(self) -> None:
|
||||
collector = QueueDepthCollector(cache_ttl=0)
|
||||
mock_redis = MagicMock()
|
||||
collector.set_redis_factory(lambda: mock_redis)
|
||||
|
||||
# First call succeeds
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
return_value=10,
|
||||
),
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_unacked_task_ids",
|
||||
return_value=set(),
|
||||
),
|
||||
):
|
||||
good_result = collector.collect()
|
||||
|
||||
assert len(good_result) == 3
|
||||
assert good_result[0].samples[0].value == 10
|
||||
|
||||
# Second call fails — should return stale cache, not empty
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_pipeline.celery_get_queue_length",
|
||||
side_effect=Exception("Redis down"),
|
||||
):
|
||||
stale_result = collector.collect()
|
||||
|
||||
assert stale_result is good_result
|
||||
|
||||
|
||||
class TestIndexAttemptCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = IndexAttemptCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_index_attempts(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
from onyx.db.enums import IndexingStatus
|
||||
|
||||
mock_row = (
|
||||
IndexingStatus.IN_PROGRESS,
|
||||
MagicMock(value="web"),
|
||||
81,
|
||||
"Table Tennis Blade Guide",
|
||||
2,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.join.return_value.filter.return_value.group_by.return_value.all.return_value = [
|
||||
mock_row
|
||||
]
|
||||
|
||||
families = collector.collect()
|
||||
assert len(families) == 1
|
||||
assert families[0].name == "onyx_index_attempts_active"
|
||||
assert len(families[0].samples) == 1
|
||||
sample = families[0].samples[0]
|
||||
assert sample.labels == {
|
||||
"status": "in_progress",
|
||||
"source": "web",
|
||||
"tenant_id": "public",
|
||||
"connector_name": "Table Tennis Blade Guide",
|
||||
"cc_pair_id": "81",
|
||||
}
|
||||
assert sample.value == 2
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
# No stale cache, so returns empty
|
||||
assert families == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_skips_none_tenant_ids(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = IndexAttemptCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = [None]
|
||||
families = collector.collect()
|
||||
assert len(families) == 1 # Returns the gauge family, just with no samples
|
||||
assert len(families[0].samples) == 0
|
||||
|
||||
|
||||
class TestConnectorHealthCollector:
|
||||
def test_returns_empty_when_not_configured(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.collect() == []
|
||||
|
||||
def test_returns_empty_describe(self) -> None:
|
||||
collector = ConnectorHealthCollector()
|
||||
assert collector.describe() == []
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_collects_connector_health(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
last_success = now - timedelta(hours=2)
|
||||
|
||||
mock_status = MagicMock(value="ACTIVE")
|
||||
mock_source = MagicMock(value="google_drive")
|
||||
# Row: (id, status, in_error, last_success, name, source)
|
||||
mock_row = (
|
||||
42,
|
||||
mock_status,
|
||||
True, # in_repeated_error_state
|
||||
last_success,
|
||||
"My GDrive Connector",
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
# Mock the index attempt queries (error counts + docs counts)
|
||||
mock_session.query.return_value.filter.return_value.group_by.return_value.all.return_value = (
|
||||
[]
|
||||
)
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
assert len(families) == 6
|
||||
names = {f.name for f in families}
|
||||
assert names == {
|
||||
"onyx_connector_last_success_age_seconds",
|
||||
"onyx_connector_in_error_state",
|
||||
"onyx_connectors_by_status",
|
||||
"onyx_connectors_in_error_total",
|
||||
"onyx_connector_docs_indexed",
|
||||
"onyx_connector_error_count",
|
||||
}
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 1
|
||||
assert staleness.samples[0].value == pytest.approx(7200, abs=5)
|
||||
|
||||
error_state = next(
|
||||
f for f in families if f.name == "onyx_connector_in_error_state"
|
||||
)
|
||||
assert error_state.samples[0].value == 1.0
|
||||
|
||||
by_status = next(f for f in families if f.name == "onyx_connectors_by_status")
|
||||
assert by_status.samples[0].labels == {
|
||||
"tenant_id": "public",
|
||||
"status": "ACTIVE",
|
||||
}
|
||||
assert by_status.samples[0].value == 1
|
||||
|
||||
error_total = next(
|
||||
f for f in families if f.name == "onyx_connectors_in_error_total"
|
||||
)
|
||||
assert error_total.samples[0].value == 1
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
@patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
def test_skips_staleness_when_no_last_success(
|
||||
self,
|
||||
mock_get_session: MagicMock,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.return_value = ["public"]
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_get_session.return_value.__enter__ = MagicMock(return_value=mock_session)
|
||||
mock_get_session.return_value.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
mock_status = MagicMock(value="INITIAL_INDEXING")
|
||||
mock_source = MagicMock(value="slack")
|
||||
mock_row = (
|
||||
10,
|
||||
mock_status,
|
||||
False,
|
||||
None, # no last_successful_index_time
|
||||
0,
|
||||
mock_source,
|
||||
)
|
||||
mock_session.query.return_value.join.return_value.all.return_value = [mock_row]
|
||||
|
||||
families = collector.collect()
|
||||
|
||||
staleness = next(
|
||||
f for f in families if f.name == "onyx_connector_last_success_age_seconds"
|
||||
)
|
||||
assert len(staleness.samples) == 0
|
||||
|
||||
@patch("onyx.db.engine.tenant_utils.get_all_tenant_ids")
|
||||
def test_handles_db_error_gracefully(
|
||||
self,
|
||||
mock_get_tenants: MagicMock,
|
||||
) -> None:
|
||||
collector = ConnectorHealthCollector(cache_ttl=0)
|
||||
collector.configure()
|
||||
|
||||
mock_get_tenants.side_effect = Exception("DB down")
|
||||
families = collector.collect()
|
||||
assert families == []
|
||||
@@ -0,0 +1,96 @@
|
||||
"""Tests for indexing pipeline setup (Redis factory caching)."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.server.metrics.indexing_pipeline_setup import _make_broker_redis_factory
|
||||
|
||||
|
||||
def _make_mock_app(client: MagicMock) -> MagicMock:
|
||||
"""Create a mock Celery app whose broker_connection().channel().client
|
||||
returns the given client."""
|
||||
mock_app = MagicMock()
|
||||
mock_conn = MagicMock()
|
||||
mock_conn.channel.return_value.client = client
|
||||
|
||||
mock_app.broker_connection.return_value = mock_conn
|
||||
|
||||
return mock_app
|
||||
|
||||
|
||||
class TestMakeBrokerRedisFactory:
|
||||
def test_caches_redis_client_across_calls(self) -> None:
|
||||
"""Factory should reuse the same client on subsequent calls."""
|
||||
mock_client = MagicMock()
|
||||
mock_client.ping.return_value = True
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
client1 = factory()
|
||||
client2 = factory()
|
||||
|
||||
assert client1 is client2
|
||||
# broker_connection should only be called once
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
def test_reconnects_when_ping_fails(self) -> None:
|
||||
"""Factory should create a new client if ping fails (stale connection)."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
client1 = factory()
|
||||
assert client1 is mock_client_stale
|
||||
assert mock_app.broker_connection.call_count == 1
|
||||
|
||||
# Switch to fresh client for next connection
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails on stale, reconnects
|
||||
client2 = factory()
|
||||
assert client2 is mock_client_fresh
|
||||
assert mock_app.broker_connection.call_count == 2
|
||||
|
||||
def test_reconnect_closes_stale_client(self) -> None:
|
||||
"""When ping fails, the old client should be closed before reconnecting."""
|
||||
mock_client_stale = MagicMock()
|
||||
mock_client_stale.ping.side_effect = ConnectionError("disconnected")
|
||||
|
||||
mock_client_fresh = MagicMock()
|
||||
mock_client_fresh.ping.return_value = True
|
||||
|
||||
mock_app = _make_mock_app(mock_client_stale)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
|
||||
# First call — creates and caches
|
||||
factory()
|
||||
|
||||
# Switch to fresh client
|
||||
mock_conn_fresh = MagicMock()
|
||||
mock_conn_fresh.channel.return_value.client = mock_client_fresh
|
||||
mock_app.broker_connection.return_value = mock_conn_fresh
|
||||
|
||||
# Second call — ping fails, should close stale client
|
||||
factory()
|
||||
mock_client_stale.close.assert_called_once()
|
||||
|
||||
def test_first_call_creates_connection(self) -> None:
|
||||
"""First call should always create a new connection."""
|
||||
mock_client = MagicMock()
|
||||
mock_app = _make_mock_app(mock_client)
|
||||
|
||||
factory = _make_broker_redis_factory(mock_app)
|
||||
client = factory()
|
||||
|
||||
assert client is mock_client
|
||||
mock_app.broker_connection.assert_called_once()
|
||||
335
backend/tests/unit/server/metrics/test_indexing_task_metrics.py
Normal file
335
backend/tests/unit/server/metrics/test_indexing_task_metrics.py
Normal file
@@ -0,0 +1,335 @@
|
||||
"""Tests for per-connector indexing task Prometheus metrics."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.indexing_task_metrics import _connector_cache
|
||||
from onyx.server.metrics.indexing_task_metrics import _indexing_start_times
|
||||
from onyx.server.metrics.indexing_task_metrics import ConnectorInfo
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_COMPLETED
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_DURATION
|
||||
from onyx.server.metrics.indexing_task_metrics import INDEXING_TASK_STARTED
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_postrun
|
||||
from onyx.server.metrics.indexing_task_metrics import on_indexing_task_prerun
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_state() -> Iterator[None]:
|
||||
"""Clear caches and state between tests.
|
||||
|
||||
Sets CURRENT_TENANT_ID_CONTEXTVAR to a realistic value so cache keys
|
||||
are never keyed on an empty string.
|
||||
"""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test_tenant")
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
yield
|
||||
_connector_cache.clear()
|
||||
_indexing_start_times.clear()
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def _make_task(name: str) -> MagicMock:
|
||||
task = MagicMock()
|
||||
task.name = name
|
||||
return task
|
||||
|
||||
|
||||
def _mock_db_lookup(
|
||||
source: str = "google_drive", name: str = "My Google Drive"
|
||||
) -> tuple:
|
||||
"""Return (session_patch, cc_pair_patch) context managers for DB mocking."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = name
|
||||
mock_cc_pair.connector.source.value = source
|
||||
|
||||
session_patch = patch("onyx.db.engine.sql_engine.get_session_with_current_tenant")
|
||||
cc_pair_patch = patch(
|
||||
"onyx.db.connector_credential_pair.get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
)
|
||||
return session_patch, cc_pair_patch
|
||||
|
||||
|
||||
class TestIndexingTaskPrerun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docfetching(self) -> None:
|
||||
# Pre-populate cache to avoid DB lookup (tenant-scoped key)
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="My Google Drive"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "tenant-1"}
|
||||
|
||||
before = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
after = INDEXING_TASK_STARTED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="google_drive",
|
||||
tenant_id="tenant-1",
|
||||
cc_pair_id="42",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_emits_started_for_docprocessing(self) -> None:
|
||||
_connector_cache[("test_tenant", 10)] = ConnectorInfo(
|
||||
source="slack", name="Slack Connector"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 10, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-2", task, kwargs)
|
||||
assert "task-2" in _indexing_start_times
|
||||
|
||||
def test_cache_hit_avoids_db_call(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="confluence", name="Engineering Confluence"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No DB patches needed — cache should be used
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_db_lookup_on_cache_miss(self) -> None:
|
||||
"""On first encounter of a cc_pair_id, does a DB lookup and caches."""
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "Notion Workspace"
|
||||
mock_cc_pair.connector.source.value = "notion"
|
||||
|
||||
mock_session = MagicMock()
|
||||
mock_session.__enter__ = MagicMock(return_value=MagicMock())
|
||||
mock_session.__exit__ = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve,
|
||||
):
|
||||
mock_resolve.return_value = ConnectorInfo(
|
||||
source="notion", name="Notion Workspace"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 77, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
mock_resolve.assert_called_once_with(77)
|
||||
|
||||
def test_missing_cc_pair_returns_unknown(self) -> None:
|
||||
"""When _resolve_connector can't find the cc_pair, uses 'unknown'."""
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector"
|
||||
) as mock_resolve:
|
||||
mock_resolve.return_value = ConnectorInfo(source="unknown", name="unknown")
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 999, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" in _indexing_start_times
|
||||
|
||||
def test_skips_when_cc_pair_id_missing(self) -> None:
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"tenant_id": "public"}
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
assert "task-1" not in _indexing_start_times
|
||||
|
||||
def test_db_error_does_not_crash(self) -> None:
|
||||
with patch(
|
||||
"onyx.server.metrics.indexing_task_metrics._resolve_connector",
|
||||
side_effect=Exception("DB down"),
|
||||
):
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
# Should not raise
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
|
||||
class TestIndexingTaskPostrun:
|
||||
def test_skips_non_indexing_task(self) -> None:
|
||||
task = _make_task("some_other_task")
|
||||
kwargs = {"cc_pair_id": 1, "tenant_id": "public"}
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
# Should not raise
|
||||
|
||||
def test_emits_completed_and_duration(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="google_drive", name="Marketing Drive"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# Simulate prerun
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
before_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
after_completed = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="success",
|
||||
)._value.get()
|
||||
|
||||
after_duration = INDEXING_TASK_DURATION.labels(
|
||||
task_name="docprocessing_task",
|
||||
source="google_drive",
|
||||
tenant_id="public",
|
||||
)._sum.get()
|
||||
|
||||
assert after_completed == before_completed + 1
|
||||
assert after_duration > before_duration
|
||||
|
||||
def test_failure_outcome(self) -> None:
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("connector_doc_fetching_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
on_indexing_task_prerun("task-1", task, kwargs)
|
||||
|
||||
before = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "FAILURE")
|
||||
|
||||
after = INDEXING_TASK_COMPLETED.labels(
|
||||
task_name="connector_doc_fetching_task",
|
||||
source="slack",
|
||||
tenant_id="public",
|
||||
cc_pair_id="42",
|
||||
outcome="failure",
|
||||
)._value.get()
|
||||
|
||||
assert after == before + 1
|
||||
|
||||
def test_handles_postrun_without_prerun(self) -> None:
|
||||
"""Postrun for an indexing task without a matching prerun should not crash."""
|
||||
_connector_cache[("test_tenant", 42)] = ConnectorInfo(
|
||||
source="slack", name="Slack"
|
||||
)
|
||||
|
||||
task = _make_task("docprocessing_task")
|
||||
kwargs = {"cc_pair_id": 42, "tenant_id": "public"}
|
||||
|
||||
# No prerun — should still emit completed counter, just skip duration
|
||||
on_indexing_task_postrun("task-1", task, kwargs, "SUCCESS")
|
||||
|
||||
|
||||
class TestResolveConnector:
|
||||
def test_failed_lookup_not_cached(self) -> None:
|
||||
"""When DB lookup returns None, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=None,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(999)
|
||||
assert result.source == "unknown"
|
||||
# Should NOT be cached so subsequent calls can retry
|
||||
assert ("test-tenant", 999) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_exception_not_cached(self) -> None:
|
||||
"""When DB lookup raises, result should NOT be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
with (
|
||||
patch(
|
||||
"onyx.db.engine.sql_engine.get_session_with_current_tenant",
|
||||
side_effect=Exception("DB down"),
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(888)
|
||||
assert result.source == "unknown"
|
||||
assert ("test-tenant", 888) not in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
def test_successful_lookup_is_cached(self) -> None:
|
||||
"""When DB lookup succeeds, result should be cached."""
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set("test-tenant")
|
||||
try:
|
||||
mock_cc_pair = MagicMock()
|
||||
mock_cc_pair.name = "My Drive"
|
||||
mock_cc_pair.connector.source.value = "google_drive"
|
||||
|
||||
with (
|
||||
patch("onyx.db.engine.sql_engine.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.db.connector_credential_pair"
|
||||
".get_connector_credential_pair_from_id",
|
||||
return_value=mock_cc_pair,
|
||||
),
|
||||
):
|
||||
from onyx.server.metrics.indexing_task_metrics import _resolve_connector
|
||||
|
||||
result = _resolve_connector(777)
|
||||
assert result.source == "google_drive"
|
||||
assert result.name == "My Drive"
|
||||
assert ("test-tenant", 777) in _connector_cache
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
69
backend/tests/unit/server/metrics/test_metrics_server.py
Normal file
69
backend/tests/unit/server/metrics/test_metrics_server.py
Normal file
@@ -0,0 +1,69 @@
|
||||
"""Tests for the Prometheus metrics server module."""
|
||||
|
||||
from collections.abc import Iterator
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.server.metrics.metrics_server import _DEFAULT_PORTS
|
||||
from onyx.server.metrics.metrics_server import start_metrics_server
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def reset_server_state() -> Iterator[None]:
|
||||
"""Reset the global _server_started between tests."""
|
||||
import onyx.server.metrics.metrics_server as mod
|
||||
|
||||
mod._server_started = False
|
||||
yield
|
||||
mod._server_started = False
|
||||
|
||||
|
||||
class TestStartMetricsServer:
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_uses_default_port_for_known_worker(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == _DEFAULT_PORTS["monitoring"]
|
||||
mock_start.assert_called_once_with(_DEFAULT_PORTS["monitoring"])
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "9999"})
|
||||
def test_env_var_overrides_default(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port == 9999
|
||||
mock_start.assert_called_once_with(9999)
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_ENABLED": "false"})
|
||||
def test_disabled_via_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_unknown_worker_type_no_env_var(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("unknown_worker")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_idempotent(self, mock_start: MagicMock) -> None:
|
||||
port1 = start_metrics_server("monitoring")
|
||||
port2 = start_metrics_server("monitoring")
|
||||
assert port1 == _DEFAULT_PORTS["monitoring"]
|
||||
assert port2 is None
|
||||
mock_start.assert_called_once()
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
def test_handles_os_error(self, mock_start: MagicMock) -> None:
|
||||
mock_start.side_effect = OSError("Address already in use")
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
|
||||
@patch("onyx.server.metrics.metrics_server.start_http_server")
|
||||
@patch.dict("os.environ", {"PROMETHEUS_METRICS_PORT": "not_a_number"})
|
||||
def test_invalid_port_env_var_returns_none(self, mock_start: MagicMock) -> None:
|
||||
port = start_metrics_server("monitoring")
|
||||
assert port is None
|
||||
mock_start.assert_not_called()
|
||||
@@ -145,6 +145,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
draggable,
|
||||
footer,
|
||||
size = "lg",
|
||||
@@ -221,6 +223,8 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
pageSize: effectivePageSize,
|
||||
initialSorting,
|
||||
initialColumnVisibility,
|
||||
initialRowSelection,
|
||||
initialViewSelected,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
|
||||
@@ -103,6 +103,10 @@ interface UseDataTableOptions<TData extends RowData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. @default {} */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. @default {} */
|
||||
initialRowSelection?: RowSelectionState;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode (filtered to selected rows). @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Called whenever the set of selected row IDs changes. */
|
||||
onSelectionChange?: (selectedIds: string[]) => void;
|
||||
/** Search term for global text filtering. Rows are filtered to those containing
|
||||
@@ -195,6 +199,8 @@ export default function useDataTable<TData extends RowData>(
|
||||
columnResizeMode = "onChange",
|
||||
initialSorting = [],
|
||||
initialColumnVisibility = {},
|
||||
initialRowSelection = {},
|
||||
initialViewSelected = false,
|
||||
getRowId,
|
||||
onSelectionChange,
|
||||
searchTerm,
|
||||
@@ -206,7 +212,8 @@ export default function useDataTable<TData extends RowData>(
|
||||
|
||||
// ---- internal state -----------------------------------------------------
|
||||
const [sorting, setSorting] = useState<SortingState>(initialSorting);
|
||||
const [rowSelection, setRowSelection] = useState<RowSelectionState>({});
|
||||
const [rowSelection, setRowSelection] =
|
||||
useState<RowSelectionState>(initialRowSelection);
|
||||
const [columnSizing, setColumnSizing] = useState<ColumnSizingState>({});
|
||||
const [columnVisibility, setColumnVisibility] = useState<VisibilityState>(
|
||||
initialColumnVisibility
|
||||
@@ -216,8 +223,12 @@ export default function useDataTable<TData extends RowData>(
|
||||
pageSize: pageSizeOption,
|
||||
});
|
||||
/** Combined global filter: view-mode (selected IDs) + text search. */
|
||||
const initialSelectedIds =
|
||||
initialViewSelected && Object.keys(initialRowSelection).length > 0
|
||||
? new Set(Object.keys(initialRowSelection))
|
||||
: null;
|
||||
const [globalFilter, setGlobalFilter] = useState<GlobalFilterValue>({
|
||||
selectedIds: null,
|
||||
selectedIds: initialSelectedIds,
|
||||
searchTerm: "",
|
||||
});
|
||||
|
||||
@@ -384,6 +395,31 @@ export default function useDataTable<TData extends RowData>(
|
||||
: data.length;
|
||||
const isPaginated = isFinite(pagination.pageSize);
|
||||
|
||||
// ---- keep view-mode filter in sync with selection ----------------------
|
||||
// When in view-selected mode, deselecting a row should remove it from
|
||||
// the visible set so it disappears immediately.
|
||||
useEffect(() => {
|
||||
if (isServerSide) return;
|
||||
if (globalFilter.selectedIds == null) return;
|
||||
|
||||
const currentIds = new Set(Object.keys(rowSelection));
|
||||
// Remove any ID from the filter that is no longer selected
|
||||
let changed = false;
|
||||
const next = new Set<string>();
|
||||
globalFilter.selectedIds.forEach((id) => {
|
||||
if (currentIds.has(id)) {
|
||||
next.add(id);
|
||||
} else {
|
||||
changed = true;
|
||||
}
|
||||
});
|
||||
if (changed) {
|
||||
setGlobalFilter((prev) => ({ ...prev, selectedIds: next }));
|
||||
}
|
||||
// eslint-disable-next-line react-hooks/exhaustive-deps -- only react to
|
||||
// selection changes while in view mode
|
||||
}, [rowSelection, isServerSide]);
|
||||
|
||||
// ---- selection change callback ------------------------------------------
|
||||
const isFirstRenderRef = useRef(true);
|
||||
const onSelectionChangeRef = useRef(onSelectionChange);
|
||||
@@ -392,6 +428,10 @@ export default function useDataTable<TData extends RowData>(
|
||||
useEffect(() => {
|
||||
if (isFirstRenderRef.current) {
|
||||
isFirstRenderRef.current = false;
|
||||
// Still fire the callback on first render if there's an initial selection
|
||||
if (selectedRowIds.length > 0) {
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
}
|
||||
return;
|
||||
}
|
||||
onSelectionChangeRef.current?.(selectedRowIds);
|
||||
|
||||
@@ -146,6 +146,10 @@ export interface DataTableProps<TData> {
|
||||
initialSorting?: SortingState;
|
||||
/** Initial column visibility state. */
|
||||
initialColumnVisibility?: VisibilityState;
|
||||
/** Initial row selection state. Keys are row IDs (from `getRowId`), values are `true`. */
|
||||
initialRowSelection?: Record<string, boolean>;
|
||||
/** When true AND `initialRowSelection` is non-empty, start in view-selected mode. @default false */
|
||||
initialViewSelected?: boolean;
|
||||
/** Enable drag-and-drop row reordering. */
|
||||
draggable?: DataTableDraggableConfig;
|
||||
/** Footer configuration. */
|
||||
|
||||
@@ -158,25 +158,22 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
mutate(buildCCPairInfoUrl(ccPairId));
|
||||
}, [ccPairId]);
|
||||
|
||||
const shouldConfirmConnectorDeletion = true;
|
||||
const finishConnectorDeletion = useCallback(() => {
|
||||
router.push("/admin/indexing/status");
|
||||
}, [router]);
|
||||
|
||||
const scheduleConnectorDeletion = useCallback(async () => {
|
||||
const scheduleConnectorDeletion = useCallback(() => {
|
||||
if (!ccPair) return;
|
||||
if (isSchedulingConnectorDeletionRef.current) return;
|
||||
isSchedulingConnectorDeletionRef.current = true;
|
||||
|
||||
try {
|
||||
await deleteCCPair(ccPair.connector.id, ccPair.credential.id, () =>
|
||||
mutate(buildCCPairInfoUrl(ccPair.id))
|
||||
deleteCCPair(ccPair.connector.id, ccPair.credential.id).catch((error) => {
|
||||
toast.error(
|
||||
"Failed to schedule deletion of connector - " + error.message
|
||||
);
|
||||
refresh();
|
||||
} catch (error) {
|
||||
console.error("Error deleting connector:", error);
|
||||
} finally {
|
||||
setShowDeleteConnectorConfirmModal(false);
|
||||
isSchedulingConnectorDeletionRef.current = false;
|
||||
}
|
||||
}, [ccPair, refresh]);
|
||||
});
|
||||
finishConnectorDeletion();
|
||||
}, [ccPair, finishConnectorDeletion]);
|
||||
|
||||
const latestIndexAttempt = indexAttempts?.[0];
|
||||
const canManageInlineFileConnectorFiles =
|
||||
@@ -194,10 +191,6 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
(error) => error.index_attempt_id === latestIndexAttempt?.id
|
||||
);
|
||||
|
||||
const finishConnectorDeletion = useCallback(() => {
|
||||
router.push("/admin/indexing/status?message=connector-deleted");
|
||||
}, [router]);
|
||||
|
||||
const handleStatusUpdate = async (
|
||||
newStatus: ConnectorCredentialPairStatus
|
||||
) => {
|
||||
@@ -520,13 +513,8 @@ function Main({ ccPairId }: { ccPairId: number }) {
|
||||
)}
|
||||
{!isDeleting && (
|
||||
<DropdownMenuItemWithTooltip
|
||||
onClick={async () => {
|
||||
if (shouldConfirmConnectorDeletion) {
|
||||
setShowDeleteConnectorConfirmModal(true);
|
||||
return;
|
||||
}
|
||||
|
||||
await scheduleConnectorDeletion();
|
||||
onClick={() => {
|
||||
setShowDeleteConnectorConfirmModal(true);
|
||||
}}
|
||||
disabled={!statusIsNotCurrentlyActive(ccPair.status)}
|
||||
className="flex items-center gap-x-2 cursor-pointer px-3 py-2 text-red-600 hover:text-red-700 dark:text-red-400 dark:hover:text-red-300"
|
||||
|
||||
13
web/src/app/admin/groups/[id]/page.tsx
Normal file
13
web/src/app/admin/groups/[id]/page.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
|
||||
|
||||
export default function EditGroupRoute({
|
||||
params,
|
||||
}: {
|
||||
params: Promise<{ id: string }>;
|
||||
}) {
|
||||
const { id } = use(params);
|
||||
return <EditGroupPage groupId={Number(id)} />;
|
||||
}
|
||||
17
web/src/app/admin/groups2/[id]/page.tsx
Normal file
17
web/src/app/admin/groups2/[id]/page.tsx
Normal file
@@ -0,0 +1,17 @@
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
|
||||
|
||||
export default function EditGroupRoute({
|
||||
params,
|
||||
}: {
|
||||
params: Promise<{ id: string }>;
|
||||
}) {
|
||||
const { id } = use(params);
|
||||
const groupId = Number(id);
|
||||
if (Number.isNaN(groupId)) {
|
||||
return null;
|
||||
}
|
||||
return <EditGroupPage groupId={groupId} />;
|
||||
}
|
||||
1
web/src/app/admin/groups2/create/page.tsx
Normal file
1
web/src/app/admin/groups2/create/page.tsx
Normal file
@@ -0,0 +1 @@
|
||||
export { default } from "@/refresh-pages/admin/GroupsPage/CreateGroupPage";
|
||||
@@ -211,10 +211,6 @@ export default function Status() {
|
||||
message: "Connector created successfully",
|
||||
type: "success",
|
||||
},
|
||||
"connector-deleted": {
|
||||
message: "Connector deleted successfully",
|
||||
type: "success",
|
||||
},
|
||||
});
|
||||
|
||||
return (
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
import { ConnectorStatus } from "@/lib/types";
|
||||
import { ConnectorMultiSelect } from "@/components/ConnectorMultiSelect";
|
||||
|
||||
interface ConnectorEditorProps {
|
||||
selectedCCPairIds: number[];
|
||||
setSetCCPairIds: (ccPairId: number[]) => void;
|
||||
allCCPairs: ConnectorStatus<any, any>[];
|
||||
}
|
||||
|
||||
export const ConnectorEditor = ({
|
||||
selectedCCPairIds,
|
||||
setSetCCPairIds,
|
||||
allCCPairs,
|
||||
}: ConnectorEditorProps) => {
|
||||
// Filter out public docs, since they don't make sense as part of a group
|
||||
const privateCCPairs = allCCPairs.filter(
|
||||
(ccPair) => ccPair.access_type === "private"
|
||||
);
|
||||
|
||||
return (
|
||||
<ConnectorMultiSelect
|
||||
name="connectors"
|
||||
label="Connectors"
|
||||
connectors={privateCCPairs}
|
||||
selectedIds={selectedCCPairIds}
|
||||
onChange={setSetCCPairIds}
|
||||
placeholder="Search for connectors..."
|
||||
showError={true}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -1,87 +0,0 @@
|
||||
import { User } from "@/lib/types";
|
||||
import { FiX } from "react-icons/fi";
|
||||
import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboBox";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
|
||||
interface UserEditorProps {
|
||||
selectedUserIds: string[];
|
||||
setSelectedUserIds: (userIds: string[]) => void;
|
||||
allUsers: User[];
|
||||
existingUsers: User[];
|
||||
onSubmit?: (users: User[]) => void;
|
||||
}
|
||||
|
||||
export const UserEditor = ({
|
||||
selectedUserIds,
|
||||
setSelectedUserIds,
|
||||
allUsers,
|
||||
existingUsers,
|
||||
onSubmit,
|
||||
}: UserEditorProps) => {
|
||||
const selectedUsers = allUsers.filter((user) =>
|
||||
selectedUserIds.includes(user.id)
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="mb-2 flex flex-wrap gap-x-2">
|
||||
{selectedUsers.length > 0 &&
|
||||
selectedUsers.map((selectedUser) => (
|
||||
<div
|
||||
key={selectedUser.id}
|
||||
onClick={() => {
|
||||
setSelectedUserIds(
|
||||
selectedUserIds.filter((userId) => userId !== selectedUser.id)
|
||||
);
|
||||
}}
|
||||
className={`
|
||||
flex
|
||||
rounded-lg
|
||||
px-2
|
||||
py-1
|
||||
border
|
||||
border-border
|
||||
hover:bg-accent-background
|
||||
cursor-pointer`}
|
||||
>
|
||||
{selectedUser.email} <FiX className="ml-1 my-auto" />
|
||||
</div>
|
||||
))}
|
||||
</div>
|
||||
|
||||
<div className="flex">
|
||||
<InputComboBox
|
||||
placeholder="Search..."
|
||||
value=""
|
||||
onChange={() => {}}
|
||||
onValueChange={(selectedValue) => {
|
||||
setSelectedUserIds([
|
||||
...Array.from(new Set([...selectedUserIds, selectedValue])),
|
||||
]);
|
||||
}}
|
||||
options={allUsers
|
||||
.filter(
|
||||
(user) =>
|
||||
!selectedUserIds.includes(user.id) &&
|
||||
!existingUsers.map((user) => user.id).includes(user.id)
|
||||
)
|
||||
.map((user) => ({
|
||||
label: user.email,
|
||||
value: user.id,
|
||||
}))}
|
||||
strict
|
||||
leftSearchIcon
|
||||
/>
|
||||
{onSubmit && (
|
||||
// TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved
|
||||
<Button
|
||||
className="ml-3 flex-nowrap w-32"
|
||||
onClick={() => onSubmit(selectedUsers)}
|
||||
>
|
||||
Add Users
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
};
|
||||
@@ -1,153 +0,0 @@
|
||||
import { Form, Formik } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { ConnectorStatus, User, UserGroup } from "@/lib/types";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { createUserGroup } from "./lib";
|
||||
import { UserEditor } from "./UserEditor";
|
||||
import { ConnectorEditor } from "./ConnectorEditor";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgUsers } from "@opal/icons";
|
||||
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
|
||||
export interface UserGroupCreationFormProps {
|
||||
onClose: () => void;
|
||||
users: User[];
|
||||
ccPairs: ConnectorStatus<any, any>[];
|
||||
existingUserGroup?: UserGroup;
|
||||
}
|
||||
|
||||
export default function UserGroupCreationForm({
|
||||
onClose,
|
||||
users,
|
||||
ccPairs,
|
||||
existingUserGroup,
|
||||
}: UserGroupCreationFormProps) {
|
||||
const isUpdate = existingUserGroup !== undefined;
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
|
||||
const privateCcPairs = ccPairs.filter(
|
||||
(ccPair) => ccPair.access_type === "private"
|
||||
);
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content>
|
||||
<Modal.Header
|
||||
icon={SvgUsers}
|
||||
title={isUpdate ? "Update a User Group" : "Create a new User Group"}
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Separator />
|
||||
|
||||
<Formik
|
||||
initialValues={{
|
||||
name: existingUserGroup ? existingUserGroup.name : "",
|
||||
user_ids: [] as string[],
|
||||
cc_pair_ids: [] as number[],
|
||||
}}
|
||||
validationSchema={Yup.object().shape({
|
||||
name: Yup.string().required("Please enter a name for the group"),
|
||||
user_ids: Yup.array().of(Yup.string().required()),
|
||||
cc_pair_ids: Yup.array().of(Yup.number().required()),
|
||||
})}
|
||||
onSubmit={async (values, formikHelpers) => {
|
||||
formikHelpers.setSubmitting(true);
|
||||
let response;
|
||||
response = await createUserGroup(values);
|
||||
formikHelpers.setSubmitting(false);
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
isUpdate
|
||||
? "Successfully updated user group!"
|
||||
: "Successfully created user group!"
|
||||
);
|
||||
onClose();
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
toast.error(
|
||||
isUpdate
|
||||
? `Error updating user group - ${errorMsg}`
|
||||
: `Error creating user group - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ isSubmitting, values, setFieldValue }) => (
|
||||
<Form>
|
||||
<TextFormField
|
||||
name="name"
|
||||
label="Name:"
|
||||
placeholder="A name for the User Group"
|
||||
disabled={isUpdate}
|
||||
/>
|
||||
|
||||
<Separator />
|
||||
|
||||
{vectorDbEnabled ? (
|
||||
<>
|
||||
<Text as="p" className="font-medium">
|
||||
Select which private connectors this group has access to:
|
||||
</Text>
|
||||
<Text as="p" text02>
|
||||
All documents indexed by the selected connectors will be
|
||||
visible to users in this group.
|
||||
</Text>
|
||||
|
||||
<ConnectorEditor
|
||||
allCCPairs={privateCcPairs}
|
||||
selectedCCPairIds={values.cc_pair_ids}
|
||||
setSetCCPairIds={(ccPairsIds) =>
|
||||
setFieldValue("cc_pair_ids", ccPairsIds)
|
||||
}
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Text as="p" text03>
|
||||
Connectors are not available in Onyx Lite. Redeploy Onyx
|
||||
with DISABLE_VECTOR_DB=false to index knowledge via
|
||||
connectors.
|
||||
</Text>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
<Text as="p" className="font-medium">
|
||||
Select which Users should be a part of this Group.
|
||||
</Text>
|
||||
<Text as="p" text02>
|
||||
All selected users will be able to search through all
|
||||
documents indexed by the selected connectors.
|
||||
</Text>
|
||||
<div className="mb-3 gap-2">
|
||||
<UserEditor
|
||||
selectedUserIds={values.user_ids}
|
||||
setSelectedUserIds={(userIds) =>
|
||||
setFieldValue("user_ids", userIds)
|
||||
}
|
||||
allUsers={users}
|
||||
existingUsers={[]}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex">
|
||||
{/* TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved */}
|
||||
<Button
|
||||
type="submit"
|
||||
disabled={isSubmitting}
|
||||
className="mx-auto w-64"
|
||||
>
|
||||
{isUpdate ? "Update!" : "Create!"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,177 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import {
|
||||
Table,
|
||||
TableHead,
|
||||
TableRow,
|
||||
TableBody,
|
||||
TableCell,
|
||||
} from "@/components/ui/table";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
|
||||
import { deleteUserGroup } from "./lib";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { FiEdit2, FiUser } from "react-icons/fi";
|
||||
import { User, UserGroup } from "@/lib/types";
|
||||
import Link from "next/link";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import { TableHeader } from "@/components/ui/table";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SvgEdit } from "@opal/icons";
|
||||
const MAX_USERS_TO_DISPLAY = 6;
|
||||
|
||||
const SimpleUserDisplay = ({ user }: { user: User }) => {
|
||||
return (
|
||||
<div className="flex my-0.5">
|
||||
<FiUser className="mr-2 my-auto" /> {user.email}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
interface UserGroupsTableProps {
|
||||
userGroups: UserGroup[];
|
||||
refresh: () => void;
|
||||
}
|
||||
|
||||
export const UserGroupsTable = ({
|
||||
userGroups,
|
||||
refresh,
|
||||
}: UserGroupsTableProps) => {
|
||||
const router = useRouter();
|
||||
|
||||
// sort by name for consistent ordering
|
||||
userGroups.sort((a, b) => {
|
||||
if (a.name < b.name) {
|
||||
return -1;
|
||||
} else if (a.name > b.name) {
|
||||
return 1;
|
||||
} else {
|
||||
return 0;
|
||||
}
|
||||
});
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>Connectors</TableHead>
|
||||
<TableHead>Users</TableHead>
|
||||
<TableHead>Status</TableHead>
|
||||
<TableHead>Delete</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{userGroups
|
||||
.filter((userGroup) => !userGroup.is_up_for_deletion)
|
||||
.map((userGroup) => {
|
||||
return (
|
||||
<TableRow key={userGroup.id}>
|
||||
<TableCell>
|
||||
{/* TODO(@raunakab): migrate to opal Button once className/iconClassName is resolved */}
|
||||
<Button
|
||||
internal
|
||||
leftIcon={SvgEdit}
|
||||
href={`/admin/groups/${userGroup.id}`}
|
||||
className="truncate"
|
||||
>
|
||||
{userGroup.name}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{userGroup.cc_pairs.length > 0 ? (
|
||||
<div>
|
||||
{userGroup.cc_pairs.map((ccPairDescriptor, ind) => {
|
||||
return (
|
||||
<div
|
||||
className={
|
||||
ind !== userGroup.cc_pairs.length - 1
|
||||
? "mb-3"
|
||||
: ""
|
||||
}
|
||||
key={ccPairDescriptor.id}
|
||||
>
|
||||
<ConnectorTitle
|
||||
connector={ccPairDescriptor.connector}
|
||||
ccPairId={ccPairDescriptor.id}
|
||||
ccPairName={ccPairDescriptor.name}
|
||||
showMetadata={false}
|
||||
/>
|
||||
</div>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{userGroup.users.length > 0 ? (
|
||||
<div>
|
||||
{userGroup.users.length <= MAX_USERS_TO_DISPLAY ? (
|
||||
userGroup.users.map((user) => {
|
||||
return (
|
||||
<SimpleUserDisplay key={user.id} user={user} />
|
||||
);
|
||||
})
|
||||
) : (
|
||||
<div>
|
||||
{userGroup.users
|
||||
.slice(0, MAX_USERS_TO_DISPLAY)
|
||||
.map((user) => {
|
||||
return (
|
||||
<SimpleUserDisplay
|
||||
key={user.id}
|
||||
user={user}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
<div>
|
||||
+ {userGroup.users.length - MAX_USERS_TO_DISPLAY}{" "}
|
||||
more
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
"-"
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{userGroup.is_up_to_date ? (
|
||||
<div className="text-success">Up to date!</div>
|
||||
) : (
|
||||
<div className="w-10">
|
||||
<LoadingAnimation text="Syncing" />
|
||||
</div>
|
||||
)}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DeleteButton
|
||||
onClick={async (event) => {
|
||||
event.stopPropagation();
|
||||
const response = await deleteUserGroup(userGroup.id);
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
`User Group "${userGroup.name}" deleted`
|
||||
);
|
||||
} else {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
toast.error(
|
||||
`Failed to delete User Group - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
refresh();
|
||||
}}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,83 +0,0 @@
|
||||
import { Button } from "@opal/components";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { useState } from "react";
|
||||
import { updateUserGroup } from "./lib";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { ConnectorStatus, UserGroup } from "@/lib/types";
|
||||
import { ConnectorMultiSelect } from "@/components/ConnectorMultiSelect";
|
||||
import { SvgPlus } from "@opal/icons";
|
||||
export interface AddConnectorFormProps {
|
||||
ccPairs: ConnectorStatus<any, any>[];
|
||||
userGroup: UserGroup;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function AddConnectorForm({
|
||||
ccPairs,
|
||||
userGroup,
|
||||
onClose,
|
||||
}: AddConnectorFormProps) {
|
||||
const [selectedCCPairIds, setSelectedCCPairIds] = useState<number[]>([]);
|
||||
|
||||
// Filter out ccPairs that are already in the user group and are not private
|
||||
const availableCCPairs = ccPairs
|
||||
.filter(
|
||||
(ccPair) =>
|
||||
!userGroup.cc_pairs
|
||||
.map((userGroupCCPair) => userGroupCCPair.id)
|
||||
.includes(ccPair.cc_pair_id)
|
||||
)
|
||||
.filter((ccPair) => ccPair.access_type === "private");
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgPlus}
|
||||
title="Add New Connector"
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<ConnectorMultiSelect
|
||||
name="connectors"
|
||||
label="Select Connectors"
|
||||
connectors={availableCCPairs}
|
||||
selectedIds={selectedCCPairIds}
|
||||
onChange={setSelectedCCPairIds}
|
||||
placeholder="Search for connectors to add..."
|
||||
showError={false}
|
||||
/>
|
||||
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const newCCPairIds = [
|
||||
...Array.from(
|
||||
new Set(
|
||||
userGroup.cc_pairs
|
||||
.map((ccPair) => ccPair.id)
|
||||
.concat(selectedCCPairIds)
|
||||
)
|
||||
),
|
||||
];
|
||||
const response = await updateUserGroup(userGroup.id, {
|
||||
user_ids: userGroup.users.map((user) => user.id),
|
||||
cc_pair_ids: newCCPairIds,
|
||||
});
|
||||
if (response.ok) {
|
||||
toast.success("Successfully added connectors to group");
|
||||
onClose();
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
toast.error(`Failed to add connectors to group - ${errorMsg}`);
|
||||
onClose();
|
||||
}
|
||||
}}
|
||||
>
|
||||
Add Connectors
|
||||
</Button>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,64 +0,0 @@
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { updateUserGroup } from "./lib";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { User, UserGroup } from "@/lib/types";
|
||||
import { UserEditor } from "../UserEditor";
|
||||
import { useState } from "react";
|
||||
import { SvgUserPlus } from "@opal/icons";
|
||||
export interface AddMemberFormProps {
|
||||
users: User[];
|
||||
userGroup: UserGroup;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function AddMemberForm({
|
||||
users,
|
||||
userGroup,
|
||||
onClose,
|
||||
}: AddMemberFormProps) {
|
||||
const [selectedUserIds, setSelectedUserIds] = useState<string[]>([]);
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={onClose}>
|
||||
<Modal.Content width="sm" height="sm">
|
||||
<Modal.Header
|
||||
icon={SvgUserPlus}
|
||||
title="Add New User"
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body>
|
||||
<UserEditor
|
||||
selectedUserIds={selectedUserIds}
|
||||
setSelectedUserIds={setSelectedUserIds}
|
||||
allUsers={users}
|
||||
existingUsers={userGroup.users}
|
||||
onSubmit={async (selectedUsers) => {
|
||||
const newUserIds = [
|
||||
...Array.from(
|
||||
new Set(
|
||||
userGroup.users
|
||||
.map((user) => user.id)
|
||||
.concat(selectedUsers.map((user) => user.id))
|
||||
)
|
||||
),
|
||||
];
|
||||
const response = await updateUserGroup(userGroup.id, {
|
||||
user_ids: newUserIds,
|
||||
cc_pair_ids: userGroup.cc_pairs.map((ccPair) => ccPair.id),
|
||||
});
|
||||
if (response.ok) {
|
||||
toast.success("Successfully added users to group");
|
||||
onClose();
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg = responseJson.detail || responseJson.message;
|
||||
toast.error(`Failed to add users to group - ${errorMsg}`);
|
||||
onClose();
|
||||
}
|
||||
}}
|
||||
/>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,57 +0,0 @@
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import CreateRateLimitModal from "../../../../admin/token-rate-limits/CreateRateLimitModal";
|
||||
import { Scope } from "../../../../admin/token-rate-limits/types";
|
||||
import { insertGroupTokenRateLimit } from "../../../../admin/token-rate-limits/lib";
|
||||
import { mutate } from "swr";
|
||||
|
||||
interface AddMemberFormProps {
|
||||
isOpen: boolean;
|
||||
setIsOpen: (isOpen: boolean) => void;
|
||||
userGroupId: number;
|
||||
}
|
||||
|
||||
const handleCreateGroupTokenRateLimit = async (
|
||||
period_hours: number,
|
||||
token_budget: number,
|
||||
group_id: number = -1
|
||||
) => {
|
||||
const tokenRateLimitArgs = {
|
||||
enabled: true,
|
||||
token_budget: token_budget,
|
||||
period_hours: period_hours,
|
||||
};
|
||||
return await insertGroupTokenRateLimit(tokenRateLimitArgs, group_id);
|
||||
};
|
||||
|
||||
export const AddTokenRateLimitForm: React.FC<AddMemberFormProps> = ({
|
||||
isOpen,
|
||||
setIsOpen,
|
||||
userGroupId,
|
||||
}) => {
|
||||
const handleSubmit = (
|
||||
_: Scope,
|
||||
period_hours: number,
|
||||
token_budget: number,
|
||||
group_id: number = -1
|
||||
) => {
|
||||
handleCreateGroupTokenRateLimit(period_hours, token_budget, group_id)
|
||||
.then(() => {
|
||||
setIsOpen(false);
|
||||
toast.success("Token rate limit created!");
|
||||
mutate(`/api/admin/token-rate-limits/user-group/${userGroupId}`);
|
||||
})
|
||||
.catch((error) => {
|
||||
toast.error(error.message);
|
||||
});
|
||||
};
|
||||
|
||||
return (
|
||||
<CreateRateLimitModal
|
||||
isOpen={isOpen}
|
||||
setIsOpen={setIsOpen}
|
||||
onSubmit={handleSubmit}
|
||||
forSpecificScope={Scope.USER_GROUP}
|
||||
forSpecificUserGroup={userGroupId}
|
||||
/>
|
||||
);
|
||||
};
|
||||
@@ -1,483 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useState } from "react";
|
||||
import { ConnectorTitle } from "@/components/admin/connectors/ConnectorTitle";
|
||||
import AddMemberForm from "./AddMemberForm";
|
||||
import { updateUserGroup, updateCuratorStatus } from "./lib";
|
||||
import { LoadingAnimation } from "@/components/Loading";
|
||||
import {
|
||||
User,
|
||||
UserGroup,
|
||||
UserRole,
|
||||
USER_ROLE_LABELS,
|
||||
ConnectorStatus,
|
||||
} from "@/lib/types";
|
||||
import AddConnectorForm from "./AddConnectorForm";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import Text from "@/components/ui/text";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
TableCell,
|
||||
TableHead,
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import { Bubble } from "@/components/Bubble";
|
||||
import { BookmarkIcon, RobotIcon } from "@/components/icons/icons";
|
||||
import { AddTokenRateLimitForm } from "./AddTokenRateLimitForm";
|
||||
import { GenericTokenRateLimitTable } from "@/app/admin/token-rate-limits/TokenRateLimitTables";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import GenericConfirmModal from "@/components/modals/GenericConfirmModal";
|
||||
import Spacer from "@/refresh-components/Spacer";
|
||||
|
||||
interface GroupDisplayProps {
|
||||
users: User[];
|
||||
ccPairs: ConnectorStatus<any, any>[];
|
||||
userGroup: UserGroup;
|
||||
refreshUserGroup: () => void;
|
||||
}
|
||||
|
||||
const UserRoleDropdown = ({
|
||||
user,
|
||||
group,
|
||||
onSuccess,
|
||||
onError,
|
||||
isAdmin,
|
||||
}: {
|
||||
user: User;
|
||||
group: UserGroup;
|
||||
onSuccess: () => void;
|
||||
onError: (message: string) => void;
|
||||
isAdmin: boolean;
|
||||
}) => {
|
||||
const [localRole, setLocalRole] = useState(() => {
|
||||
if (user.role === UserRole.CURATOR) {
|
||||
return group.curator_ids.includes(user.id)
|
||||
? UserRole.CURATOR
|
||||
: UserRole.BASIC;
|
||||
}
|
||||
return user.role;
|
||||
});
|
||||
const [isSettingRole, setIsSettingRole] = useState(false);
|
||||
const [showDemoteConfirm, setShowDemoteConfirm] = useState(false);
|
||||
const [pendingRoleChange, setPendingRoleChange] = useState<string | null>(
|
||||
null
|
||||
);
|
||||
const { user: currentUser } = useUser();
|
||||
|
||||
const applyRoleChange = async (value: string) => {
|
||||
if (value === localRole) return;
|
||||
if (value === UserRole.BASIC || value === UserRole.CURATOR) {
|
||||
setIsSettingRole(true);
|
||||
setLocalRole(value);
|
||||
try {
|
||||
const response = await updateCuratorStatus(group.id, {
|
||||
user_id: user.id,
|
||||
is_curator: value === UserRole.CURATOR,
|
||||
});
|
||||
if (response.ok) {
|
||||
onSuccess();
|
||||
user.role = value;
|
||||
} else {
|
||||
const errorData = await response.json();
|
||||
throw new Error(errorData.detail || "Failed to update user role");
|
||||
}
|
||||
} catch (error: any) {
|
||||
onError(error.message);
|
||||
setLocalRole(user.role);
|
||||
} finally {
|
||||
setIsSettingRole(false);
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
const handleChange = (value: string) => {
|
||||
if (value === UserRole.BASIC && user.id === currentUser?.id) {
|
||||
setPendingRoleChange(value);
|
||||
setShowDemoteConfirm(true);
|
||||
} else {
|
||||
applyRoleChange(value);
|
||||
}
|
||||
};
|
||||
|
||||
const isEditable =
|
||||
user.role === UserRole.BASIC || user.role === UserRole.CURATOR;
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Confirmation modal - only shown when users try to demote themselves */}
|
||||
{showDemoteConfirm && pendingRoleChange && (
|
||||
<GenericConfirmModal
|
||||
title="Remove Yourself as a Curator for this Group?"
|
||||
message="Are you sure you want to change your role to Basic? This will remove your ability to curate this group."
|
||||
confirmText="Yes, set me to Basic"
|
||||
onClose={() => {
|
||||
// Cancel the role change if user dismisses modal
|
||||
setShowDemoteConfirm(false);
|
||||
setPendingRoleChange(null);
|
||||
}}
|
||||
onConfirm={() => {
|
||||
// Apply the role change if user confirms
|
||||
setShowDemoteConfirm(false);
|
||||
applyRoleChange(pendingRoleChange);
|
||||
setPendingRoleChange(null);
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
{isEditable ? (
|
||||
<InputSelect
|
||||
value={localRole}
|
||||
onValueChange={handleChange}
|
||||
disabled={isSettingRole}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select role" />
|
||||
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value={UserRole.BASIC}>Basic</InputSelect.Item>
|
||||
<InputSelect.Item value={UserRole.CURATOR}>
|
||||
Curator
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
) : (
|
||||
<div>{USER_ROLE_LABELS[localRole]}</div>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
};
|
||||
|
||||
export const GroupDisplay = ({
|
||||
users,
|
||||
ccPairs,
|
||||
userGroup,
|
||||
refreshUserGroup,
|
||||
}: GroupDisplayProps) => {
|
||||
const [addMemberFormVisible, setAddMemberFormVisible] = useState(false);
|
||||
const [addConnectorFormVisible, setAddConnectorFormVisible] = useState(false);
|
||||
const [addRateLimitFormVisible, setAddRateLimitFormVisible] = useState(false);
|
||||
|
||||
const { isAdmin } = useUser();
|
||||
|
||||
const onRoleChangeSuccess = () =>
|
||||
toast.success("User role updated successfully!");
|
||||
const onRoleChangeError = (errorMsg: string) =>
|
||||
toast.error(`Unable to update user role - ${errorMsg}`);
|
||||
|
||||
return (
|
||||
<div>
|
||||
<div className="text-sm mb-3 flex">
|
||||
<Text className="mr-1">Status:</Text>{" "}
|
||||
{userGroup.is_up_to_date ? (
|
||||
<div className="text-success font-bold">Up to date</div>
|
||||
) : (
|
||||
<div className="text-accent font-bold">
|
||||
<LoadingAnimation text="Syncing" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full">
|
||||
<h2 className="text-xl font-bold">Users</h2>
|
||||
</div>
|
||||
|
||||
<div className="mt-2">
|
||||
{userGroup.users.length > 0 ? (
|
||||
<>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Email</TableHead>
|
||||
<TableHead>Role</TableHead>
|
||||
<TableHead className="flex w-full">
|
||||
<div className="ml-auto">Remove User</div>
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{userGroup.users.map((groupMember) => {
|
||||
return (
|
||||
<TableRow key={groupMember.id}>
|
||||
<TableCell className="whitespace-normal break-all">
|
||||
{groupMember.email}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<UserRoleDropdown
|
||||
user={groupMember}
|
||||
group={userGroup}
|
||||
onSuccess={onRoleChangeSuccess}
|
||||
onError={onRoleChangeError}
|
||||
isAdmin={isAdmin}
|
||||
/>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex w-full">
|
||||
<div className="ml-auto m-2">
|
||||
{(isAdmin ||
|
||||
!userGroup.curator_ids.includes(
|
||||
groupMember.id
|
||||
)) && (
|
||||
<DeleteButton
|
||||
onClick={async () => {
|
||||
const response = await updateUserGroup(
|
||||
userGroup.id,
|
||||
{
|
||||
user_ids: userGroup.users
|
||||
.filter(
|
||||
(userGroupUser) =>
|
||||
userGroupUser.id !== groupMember.id
|
||||
)
|
||||
.map(
|
||||
(userGroupUser) => userGroupUser.id
|
||||
),
|
||||
cc_pair_ids: userGroup.cc_pairs.map(
|
||||
(ccPair) => ccPair.id
|
||||
),
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
"Successfully removed user from group"
|
||||
);
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg =
|
||||
responseJson.detail ||
|
||||
responseJson.message;
|
||||
toast.error(
|
||||
`Error removing user from group - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
refreshUserGroup();
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-sm">No users in this group...</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<SimpleTooltip
|
||||
tooltip="Cannot update group while sync is occurring"
|
||||
disabled={userGroup.is_up_to_date}
|
||||
>
|
||||
<Disabled disabled={!userGroup.is_up_to_date}>
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (userGroup.is_up_to_date) {
|
||||
setAddMemberFormVisible(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Add Users
|
||||
</Button>
|
||||
</Disabled>
|
||||
</SimpleTooltip>
|
||||
{addMemberFormVisible && (
|
||||
<AddMemberForm
|
||||
users={users}
|
||||
userGroup={userGroup}
|
||||
onClose={() => {
|
||||
setAddMemberFormVisible(false);
|
||||
refreshUserGroup();
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8">Connectors</h2>
|
||||
<div className="mt-2">
|
||||
{userGroup.cc_pairs.length > 0 ? (
|
||||
<>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Connector</TableHead>
|
||||
<TableHead className="flex w-full">
|
||||
<div className="ml-auto">Remove Connector</div>
|
||||
</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{userGroup.cc_pairs.map((ccPair) => {
|
||||
return (
|
||||
<TableRow key={ccPair.id}>
|
||||
<TableCell className="whitespace-normal break-all">
|
||||
<ConnectorTitle
|
||||
connector={ccPair.connector}
|
||||
ccPairId={ccPair.id}
|
||||
ccPairName={ccPair.name}
|
||||
/>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<div className="flex w-full">
|
||||
<div className="ml-auto m-2">
|
||||
<DeleteButton
|
||||
onClick={async () => {
|
||||
const response = await updateUserGroup(
|
||||
userGroup.id,
|
||||
{
|
||||
user_ids: userGroup.users.map(
|
||||
(userGroupUser) => userGroupUser.id
|
||||
),
|
||||
cc_pair_ids: userGroup.cc_pairs
|
||||
.filter(
|
||||
(userGroupCCPair) =>
|
||||
userGroupCCPair.id != ccPair.id
|
||||
)
|
||||
.map((ccPair) => ccPair.id),
|
||||
}
|
||||
);
|
||||
if (response.ok) {
|
||||
toast.success(
|
||||
"Successfully removed connector from group"
|
||||
);
|
||||
} else {
|
||||
const responseJson = await response.json();
|
||||
const errorMsg =
|
||||
responseJson.detail || responseJson.message;
|
||||
toast.error(
|
||||
`Error removing connector from group - ${errorMsg}`
|
||||
);
|
||||
}
|
||||
refreshUserGroup();
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
);
|
||||
})}
|
||||
</TableBody>
|
||||
</Table>
|
||||
</>
|
||||
) : (
|
||||
<div className="text-sm">No connectors in this group...</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<SimpleTooltip
|
||||
tooltip="Cannot update group while sync is occurring"
|
||||
disabled={userGroup.is_up_to_date}
|
||||
>
|
||||
<Disabled disabled={!userGroup.is_up_to_date}>
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (userGroup.is_up_to_date) {
|
||||
setAddConnectorFormVisible(true);
|
||||
}
|
||||
}}
|
||||
>
|
||||
Add Connectors
|
||||
</Button>
|
||||
</Disabled>
|
||||
</SimpleTooltip>
|
||||
|
||||
{addConnectorFormVisible && (
|
||||
<AddConnectorForm
|
||||
ccPairs={ccPairs}
|
||||
userGroup={userGroup}
|
||||
onClose={() => {
|
||||
setAddConnectorFormVisible(false);
|
||||
refreshUserGroup();
|
||||
}}
|
||||
/>
|
||||
)}
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Document Sets</h2>
|
||||
|
||||
<div>
|
||||
{userGroup.document_sets.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{userGroup.document_sets.map((documentSet) => {
|
||||
return (
|
||||
<Bubble isSelected key={documentSet.id}>
|
||||
<div className="flex">
|
||||
<BookmarkIcon />
|
||||
<Text className="ml-1">{documentSet.name}</Text>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<Text>No document sets in this group...</Text>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Agents</h2>
|
||||
|
||||
<div>
|
||||
{userGroup.document_sets.length > 0 ? (
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{userGroup.personas.map((persona) => {
|
||||
return (
|
||||
<Bubble isSelected key={persona.id}>
|
||||
<div className="flex">
|
||||
<RobotIcon />
|
||||
<Text className="ml-1">{persona.name}</Text>
|
||||
</div>
|
||||
</Bubble>
|
||||
);
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<Text>No Agents in this group...</Text>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<h2 className="text-xl font-bold mt-8 mb-2">Token Rate Limits</h2>
|
||||
|
||||
<AddTokenRateLimitForm
|
||||
isOpen={addRateLimitFormVisible}
|
||||
setIsOpen={setAddRateLimitFormVisible}
|
||||
userGroupId={userGroup.id}
|
||||
/>
|
||||
|
||||
<GenericTokenRateLimitTable
|
||||
fetchUrl={`/api/admin/token-rate-limits/user-group/${userGroup.id}`}
|
||||
hideHeading
|
||||
isAdmin={isAdmin}
|
||||
/>
|
||||
|
||||
{isAdmin && (
|
||||
<>
|
||||
<Spacer rem={0.75} />
|
||||
<Button onClick={() => setAddRateLimitFormVisible(true)}>
|
||||
Create a Token Rate Limit
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
};
|
||||
@@ -1,12 +0,0 @@
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
|
||||
export const useSpecificUserGroup = (groupId: string) => {
|
||||
const { data, isLoading, error, refreshUserGroups } = useUserGroups();
|
||||
const userGroup = data?.find((group) => group.id.toString() === groupId);
|
||||
return {
|
||||
userGroup,
|
||||
isLoading,
|
||||
error,
|
||||
refreshUserGroup: refreshUserGroups,
|
||||
};
|
||||
};
|
||||
@@ -1,29 +0,0 @@
|
||||
import { UserGroupUpdate, SetCuratorRequest } from "../types";
|
||||
|
||||
export const updateUserGroup = async (
|
||||
groupId: number,
|
||||
userGroup: UserGroupUpdate
|
||||
) => {
|
||||
const url = `/api/manage/admin/user-group/${groupId}`;
|
||||
return await fetch(url, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(userGroup),
|
||||
});
|
||||
};
|
||||
|
||||
export const updateCuratorStatus = async (
|
||||
groupId: number,
|
||||
curatorRequest: SetCuratorRequest
|
||||
) => {
|
||||
const url = `/api/manage/admin/user-group/${groupId}/set-curator`;
|
||||
return await fetch(url, {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(curatorRequest),
|
||||
});
|
||||
};
|
||||
@@ -1,87 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import { GroupDisplay } from "./GroupDisplay";
|
||||
import { useSpecificUserGroup } from "./hook";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { useConnectorStatus } from "@/lib/hooks";
|
||||
import useUsers from "@/hooks/useUsers";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
|
||||
|
||||
const route = ADMIN_ROUTES.GROUPS;
|
||||
|
||||
function Main({ groupId }: { groupId: string }) {
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
const {
|
||||
userGroup,
|
||||
isLoading: userGroupIsLoading,
|
||||
error: userGroupError,
|
||||
refreshUserGroup,
|
||||
} = useSpecificUserGroup(groupId);
|
||||
const {
|
||||
data: users,
|
||||
isLoading: userIsLoading,
|
||||
error: usersError,
|
||||
} = useUsers({ includeApiKeys: true });
|
||||
const {
|
||||
data: ccPairs,
|
||||
isLoading: isCCPairsLoading,
|
||||
error: ccPairsError,
|
||||
} = useConnectorStatus(30000, vectorDbEnabled);
|
||||
|
||||
if (
|
||||
userGroupIsLoading ||
|
||||
userIsLoading ||
|
||||
(vectorDbEnabled && isCCPairsLoading)
|
||||
) {
|
||||
return (
|
||||
<div className="h-full">
|
||||
<div className="my-auto">
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!userGroup || userGroupError) {
|
||||
return <div>Error loading user group</div>;
|
||||
}
|
||||
if (!users || usersError) {
|
||||
return <div>Error loading users</div>;
|
||||
}
|
||||
if (vectorDbEnabled && (!ccPairs || ccPairsError)) {
|
||||
return <div>Error loading connectors</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Header
|
||||
icon={route.icon}
|
||||
title={userGroup.name || "Unknown"}
|
||||
separator
|
||||
backButton
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<GroupDisplay
|
||||
users={users.accepted}
|
||||
ccPairs={ccPairs ?? []}
|
||||
userGroup={userGroup}
|
||||
refreshUserGroup={refreshUserGroup}
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Page(props: { params: Promise<{ groupId: string }> }) {
|
||||
const params = use(props.params);
|
||||
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<Main groupId={params.groupId} />
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
13
web/src/app/ee/admin/groups/[id]/page.tsx
Normal file
13
web/src/app/ee/admin/groups/[id]/page.tsx
Normal file
@@ -0,0 +1,13 @@
|
||||
"use client";
|
||||
|
||||
import { use } from "react";
|
||||
import EditGroupPage from "@/refresh-pages/admin/GroupsPage/EditGroupPage";
|
||||
|
||||
export default function EditGroupRoute({
|
||||
params,
|
||||
}: {
|
||||
params: Promise<{ id: string }>;
|
||||
}) {
|
||||
const { id } = use(params);
|
||||
return <EditGroupPage groupId={Number(id)} />;
|
||||
}
|
||||
@@ -1,17 +0,0 @@
|
||||
import { UserGroupCreation } from "./types";
|
||||
|
||||
export const createUserGroup = async (userGroup: UserGroupCreation) => {
|
||||
return fetch("/api/manage/admin/user-group", {
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify(userGroup),
|
||||
});
|
||||
};
|
||||
|
||||
export const deleteUserGroup = async (userGroupId: number) => {
|
||||
return fetch(`/api/manage/admin/user-group/${userGroupId}`, {
|
||||
method: "DELETE",
|
||||
});
|
||||
};
|
||||
@@ -1,89 +1 @@
|
||||
"use client";
|
||||
|
||||
import { UserGroupsTable } from "./UserGroupsTable";
|
||||
import UserGroupCreationForm from "./UserGroupCreationForm";
|
||||
import { useState } from "react";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { useConnectorStatus, useUserGroups } from "@/lib/hooks";
|
||||
import useUsers from "@/hooks/useUsers";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { ADMIN_ROUTES } from "@/lib/admin-routes";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
|
||||
|
||||
const route = ADMIN_ROUTES.GROUPS;
|
||||
|
||||
function Main() {
|
||||
const [showForm, setShowForm] = useState(false);
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
|
||||
const { data, isLoading, error, refreshUserGroups } = useUserGroups();
|
||||
|
||||
const {
|
||||
data: ccPairs,
|
||||
isLoading: isCCPairsLoading,
|
||||
error: ccPairsError,
|
||||
} = useConnectorStatus(30000, vectorDbEnabled);
|
||||
|
||||
const {
|
||||
data: users,
|
||||
isLoading: userIsLoading,
|
||||
error: usersError,
|
||||
} = useUsers({ includeApiKeys: true });
|
||||
|
||||
const { isAdmin } = useUser();
|
||||
|
||||
if (isLoading || (vectorDbEnabled && isCCPairsLoading) || userIsLoading) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (error || !data) {
|
||||
return <div className="text-red-600">Error loading users</div>;
|
||||
}
|
||||
|
||||
if (vectorDbEnabled && (ccPairsError || !ccPairs)) {
|
||||
return <div className="text-red-600">Error loading connectors</div>;
|
||||
}
|
||||
|
||||
if (usersError || !users) {
|
||||
return <div className="text-red-600">Error loading users</div>;
|
||||
}
|
||||
|
||||
return (
|
||||
<>
|
||||
{isAdmin && (
|
||||
<CreateButton onClick={() => setShowForm(true)}>
|
||||
Create New User Group
|
||||
</CreateButton>
|
||||
)}
|
||||
{data.length > 0 && (
|
||||
<div className="mt-2">
|
||||
<UserGroupsTable userGroups={data} refresh={refreshUserGroups} />
|
||||
</div>
|
||||
)}
|
||||
{showForm && (
|
||||
<UserGroupCreationForm
|
||||
onClose={() => {
|
||||
refreshUserGroups();
|
||||
setShowForm(false);
|
||||
}}
|
||||
users={users.accepted}
|
||||
ccPairs={ccPairs ?? []}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header icon={route.icon} title={route.title} separator />
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
<Main />
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
export { default } from "@/refresh-pages/admin/GroupsPage";
|
||||
|
||||
@@ -1,15 +0,0 @@
|
||||
export interface UserGroupUpdate {
|
||||
user_ids: string[];
|
||||
cc_pair_ids: number[];
|
||||
}
|
||||
|
||||
export interface SetCuratorRequest {
|
||||
user_id: string;
|
||||
is_curator: boolean;
|
||||
}
|
||||
|
||||
export interface UserGroupCreation {
|
||||
name: string;
|
||||
user_ids: string[];
|
||||
cc_pair_ids: number[];
|
||||
}
|
||||
@@ -1,15 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { HookResponse } from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
export function useHooks() {
|
||||
const { data, isLoading, error, mutate } = useSWR<HookResponse[]>(
|
||||
"/api/admin/hooks",
|
||||
errorHandlingFetcher,
|
||||
{ revalidateOnFocus: false }
|
||||
);
|
||||
|
||||
return { hooks: data, isLoading, error, mutate };
|
||||
}
|
||||
@@ -27,20 +27,17 @@ export async function scheduleDeletionJobForConnector(
|
||||
export async function deleteCCPair(
|
||||
connectorId: number,
|
||||
credentialId: number,
|
||||
onCompletion: () => void
|
||||
onCompletion?: () => void
|
||||
) {
|
||||
const deletionScheduleError = await scheduleDeletionJobForConnector(
|
||||
connectorId,
|
||||
credentialId
|
||||
);
|
||||
if (deletionScheduleError) {
|
||||
toast.error(
|
||||
"Failed to schedule deletion of connector - " + deletionScheduleError
|
||||
);
|
||||
} else {
|
||||
toast.success("Scheduled deletion of connector!");
|
||||
throw new Error(deletionScheduleError);
|
||||
}
|
||||
onCompletion();
|
||||
toast.success("Scheduled deletion of connector!");
|
||||
onCompletion?.();
|
||||
}
|
||||
|
||||
export function isCurrentlyDeleting(
|
||||
|
||||
@@ -115,9 +115,10 @@ export async function refreshToken(
|
||||
}
|
||||
|
||||
export function getUserDisplayName(user: User | null): string {
|
||||
// Prioritize custom personal name if set
|
||||
// Prioritize custom personal name, if set.
|
||||
if (!!user?.personalization?.name) return user.personalization.name;
|
||||
// Then, prioritize personal email
|
||||
|
||||
// Then, prioritize personal email.
|
||||
if (!!user?.email) {
|
||||
const atIndex = user.email.indexOf("@");
|
||||
if (atIndex > 0) {
|
||||
@@ -129,6 +130,14 @@ export function getUserDisplayName(user: User | null): string {
|
||||
return "Anonymous";
|
||||
}
|
||||
|
||||
export function getUserEmail(user: User | null): string {
|
||||
// Prioritize personal email.
|
||||
if (!!user?.email) return user.email;
|
||||
|
||||
// If nothing works, then fall back to anonymous email.
|
||||
return "anonymous@email.com";
|
||||
}
|
||||
|
||||
/**
|
||||
* Derive display initials from a user's name or email.
|
||||
*
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import { DEFAULT_AVATAR_SIZE_PX } from "@/lib/constants";
|
||||
import { getUserInitials } from "@/lib/user";
|
||||
import { getUserEmail, getUserInitials } from "@/lib/user";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import type { User } from "@/lib/types";
|
||||
|
||||
export interface UserAvatarProps {
|
||||
user: User;
|
||||
user: User | null;
|
||||
size?: number;
|
||||
}
|
||||
|
||||
@@ -13,16 +13,17 @@ export default function UserAvatar({
|
||||
user,
|
||||
size = DEFAULT_AVATAR_SIZE_PX,
|
||||
}: UserAvatarProps) {
|
||||
const initials = getUserInitials(
|
||||
user.personalization?.name ?? null,
|
||||
user.email
|
||||
const userEmail = getUserEmail(user);
|
||||
const userInitials = getUserInitials(
|
||||
user?.personalization?.name ?? null,
|
||||
userEmail
|
||||
);
|
||||
|
||||
if (!initials) {
|
||||
if (!userInitials) {
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
aria-label={`${userEmail} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-tint-01"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
@@ -34,7 +35,7 @@ export default function UserAvatar({
|
||||
return (
|
||||
<div
|
||||
role="img"
|
||||
aria-label={`${user.email} avatar`}
|
||||
aria-label={`${userEmail} avatar`}
|
||||
className="flex items-center justify-center rounded-full bg-background-neutral-inverted-00"
|
||||
style={{ width: size, height: size }}
|
||||
>
|
||||
@@ -45,7 +46,7 @@ export default function UserAvatar({
|
||||
className="select-none"
|
||||
style={{ fontSize: size * 0.4 }}
|
||||
>
|
||||
{initials}
|
||||
{userInitials}
|
||||
</Text>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import React from "react";
|
||||
import type { ButtonType, IconFunctionComponent, IconProps } from "@opal/types";
|
||||
import type { ButtonType, IconFunctionComponent } from "@opal/types";
|
||||
import type { Route } from "next";
|
||||
import { Interactive } from "@opal/core";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
@@ -19,7 +19,7 @@ export interface SidebarTabProps {
|
||||
onClick?: React.MouseEventHandler<HTMLElement>;
|
||||
href?: string;
|
||||
type?: ButtonType;
|
||||
icon?: React.FunctionComponent<IconProps>;
|
||||
icon?: IconFunctionComponent;
|
||||
children?: React.ReactNode;
|
||||
rightChildren?: React.ReactNode;
|
||||
}
|
||||
|
||||
@@ -134,6 +134,7 @@ export default function ActionLineItem({
|
||||
icon={SvgSlash}
|
||||
onClick={noProp(onToggle)}
|
||||
internal
|
||||
aria-label={disabled ? "Enable" : "Disable"}
|
||||
className={cn(
|
||||
!disabled && "invisible group-hover/LineItem:visible",
|
||||
// Hide when showing source count (it has its own hover behavior)
|
||||
@@ -180,6 +181,11 @@ export default function ActionLineItem({
|
||||
|
||||
{isSearchToolAndNotInProject && (
|
||||
<Button
|
||||
aria-label={
|
||||
isSearchToolWithNoConnectors
|
||||
? "Add Connectors"
|
||||
: "Configure Connectors"
|
||||
}
|
||||
icon={
|
||||
isSearchToolWithNoConnectors ? SvgSettings : SvgChevronRight
|
||||
}
|
||||
|
||||
@@ -848,18 +848,18 @@ export default function ActionsPopover({
|
||||
|
||||
if (toolId === searchToolId) {
|
||||
if (wasDisabled) {
|
||||
// Enabling - restore previous sources or enable all (no persistence)
|
||||
// Enabling - restore previous sources or enable all (persisted to localStorage)
|
||||
const previous = previouslyEnabledSourcesRef.current;
|
||||
if (previous.length > 0) {
|
||||
setSelectedSources(previous);
|
||||
enableSources(previous);
|
||||
} else {
|
||||
setSelectedSources(configuredSources);
|
||||
baseEnableAllSources();
|
||||
}
|
||||
previouslyEnabledSourcesRef.current = [];
|
||||
} else {
|
||||
// Disabling - store current sources then disable all (no persistence)
|
||||
// Disabling - store current sources then disable all (persisted to localStorage)
|
||||
previouslyEnabledSourcesRef.current = [...selectedSources];
|
||||
setSelectedSources([]);
|
||||
baseDisableAllSources();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
@@ -17,7 +17,12 @@ import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { ApiKeyDescriptor, MemberRow } from "./interfaces";
|
||||
import { createGroup } from "./svc";
|
||||
import {
|
||||
createGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
saveTokenLimits,
|
||||
} from "./svc";
|
||||
import { apiKeyToMemberRow, memberTableColumns, PAGE_SIZE } from "./shared";
|
||||
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
|
||||
import TokenLimitSection from "./TokenLimitSection";
|
||||
@@ -62,7 +67,14 @@ function CreateGroupPage() {
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
await createGroup(trimmed, selectedUserIds, selectedCcPairIds);
|
||||
const groupId = await createGroup(
|
||||
trimmed,
|
||||
selectedUserIds,
|
||||
selectedCcPairIds
|
||||
);
|
||||
await updateAgentGroupSharing(groupId, [], selectedAgentIds);
|
||||
await updateDocSetGroupSharing(groupId, [], selectedDocSetIds);
|
||||
await saveTokenLimits(groupId, tokenLimits, []);
|
||||
toast.success(`Group "${trimmed}" created`);
|
||||
router.push("/admin/groups");
|
||||
} catch (e) {
|
||||
|
||||
517
web/src/refresh-pages/admin/GroupsPage/EditGroupPage.tsx
Normal file
517
web/src/refresh-pages/admin/GroupsPage/EditGroupPage.tsx
Normal file
@@ -0,0 +1,517 @@
|
||||
"use client";
|
||||
|
||||
import { useCallback, useEffect, useMemo, useRef, useState } from "react";
|
||||
import { useRouter } from "next/navigation";
|
||||
import useSWR, { useSWRConfig } from "swr";
|
||||
import { Table, Button } from "@opal/components";
|
||||
import { IllustrationContent } from "@opal/layouts";
|
||||
import { SvgUsers, SvgTrash, SvgMinusCircle, SvgPlusCircle } from "@opal/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import SvgNoResult from "@opal/illustrations/no-result";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import useAdminUsers from "@/hooks/useAdminUsers";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import type {
|
||||
ApiKeyDescriptor,
|
||||
MemberRow,
|
||||
TokenRateLimitDisplay,
|
||||
} from "./interfaces";
|
||||
import {
|
||||
apiKeyToMemberRow,
|
||||
baseColumns,
|
||||
memberTableColumns,
|
||||
tc,
|
||||
PAGE_SIZE,
|
||||
} from "./shared";
|
||||
import {
|
||||
USER_GROUP_URL,
|
||||
renameGroup,
|
||||
updateGroup,
|
||||
deleteGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
saveTokenLimits,
|
||||
} from "./svc";
|
||||
import SharedGroupResources from "@/refresh-pages/admin/GroupsPage/SharedGroupResources";
|
||||
import TokenLimitSection from "./TokenLimitSection";
|
||||
import type { TokenLimit } from "./TokenLimitSection";
|
||||
|
||||
const addModeColumns = memberTableColumns;
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface EditGroupPageProps {
|
||||
groupId: number;
|
||||
}
|
||||
|
||||
function EditGroupPage({ groupId }: EditGroupPageProps) {
|
||||
const router = useRouter();
|
||||
const { mutate } = useSWRConfig();
|
||||
|
||||
// Fetch the group data — poll every 5s while syncing so the UI updates
|
||||
// automatically when the backend finishes processing the previous edit.
|
||||
const {
|
||||
data: groups,
|
||||
isLoading: groupLoading,
|
||||
error: groupError,
|
||||
} = useSWR<UserGroup[]>(USER_GROUP_URL, errorHandlingFetcher, {
|
||||
refreshInterval: (latestData) => {
|
||||
const g = latestData?.find((g) => g.id === groupId);
|
||||
return g && !g.is_up_to_date ? 5000 : 0;
|
||||
},
|
||||
});
|
||||
|
||||
const group = useMemo(
|
||||
() => groups?.find((g) => g.id === groupId) ?? null,
|
||||
[groups, groupId]
|
||||
);
|
||||
|
||||
const isSyncing = group != null && !group.is_up_to_date;
|
||||
|
||||
// Fetch token rate limits for this group
|
||||
const { data: tokenRateLimits, isLoading: tokenLimitsLoading } = useSWR<
|
||||
TokenRateLimitDisplay[]
|
||||
>(`/api/admin/token-rate-limits/user-group/${groupId}`, errorHandlingFetcher);
|
||||
|
||||
// Form state
|
||||
const [groupName, setGroupName] = useState("");
|
||||
const [selectedUserIds, setSelectedUserIds] = useState<string[]>([]);
|
||||
const [searchTerm, setSearchTerm] = useState("");
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
const isSubmittingRef = useRef(false);
|
||||
const [selectedCcPairIds, setSelectedCcPairIds] = useState<number[]>([]);
|
||||
const [selectedDocSetIds, setSelectedDocSetIds] = useState<number[]>([]);
|
||||
const [selectedAgentIds, setSelectedAgentIds] = useState<number[]>([]);
|
||||
const [tokenLimits, setTokenLimits] = useState<TokenLimit[]>([
|
||||
{ tokenBudget: null, periodHours: null },
|
||||
]);
|
||||
const [showDeleteModal, setShowDeleteModal] = useState(false);
|
||||
const [isDeleting, setIsDeleting] = useState(false);
|
||||
const [initialized, setInitialized] = useState(false);
|
||||
const [isAddingMembers, setIsAddingMembers] = useState(false);
|
||||
const initialAgentIdsRef = useRef<number[]>([]);
|
||||
const initialDocSetIdsRef = useRef<number[]>([]);
|
||||
|
||||
// Users and API keys
|
||||
const { users, isLoading: usersLoading, error: usersError } = useAdminUsers();
|
||||
|
||||
const {
|
||||
data: apiKeys,
|
||||
isLoading: apiKeysLoading,
|
||||
error: apiKeysError,
|
||||
} = useSWR<ApiKeyDescriptor[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const isLoading =
|
||||
groupLoading || usersLoading || apiKeysLoading || tokenLimitsLoading;
|
||||
const error = groupError ?? usersError ?? apiKeysError;
|
||||
|
||||
// Pre-populate form when group data loads
|
||||
useEffect(() => {
|
||||
if (group && !initialized) {
|
||||
setGroupName(group.name);
|
||||
setSelectedUserIds(group.users.map((u) => u.id));
|
||||
setSelectedCcPairIds(group.cc_pairs.map((cc) => cc.id));
|
||||
const docSetIds = group.document_sets.map((ds) => ds.id);
|
||||
setSelectedDocSetIds(docSetIds);
|
||||
initialDocSetIdsRef.current = docSetIds;
|
||||
const agentIds = group.personas.map((p) => p.id);
|
||||
setSelectedAgentIds(agentIds);
|
||||
initialAgentIdsRef.current = agentIds;
|
||||
setInitialized(true);
|
||||
}
|
||||
}, [group, initialized]);
|
||||
|
||||
// Pre-populate token limits when fetched
|
||||
useEffect(() => {
|
||||
if (tokenRateLimits && tokenRateLimits.length > 0) {
|
||||
setTokenLimits(
|
||||
tokenRateLimits.map((trl) => ({
|
||||
tokenBudget: trl.token_budget,
|
||||
periodHours: trl.period_hours,
|
||||
}))
|
||||
);
|
||||
}
|
||||
}, [tokenRateLimits]);
|
||||
|
||||
const allRows = useMemo(() => {
|
||||
const activeUsers = users.filter((u) => u.is_active);
|
||||
const serviceAccountRows = (apiKeys ?? []).map(apiKeyToMemberRow);
|
||||
return [...activeUsers, ...serviceAccountRows];
|
||||
}, [users, apiKeys]);
|
||||
|
||||
const memberRows = useMemo(() => {
|
||||
const selected = new Set(selectedUserIds);
|
||||
return allRows.filter((r) => selected.has(r.id ?? r.email));
|
||||
}, [allRows, selectedUserIds]);
|
||||
|
||||
const currentRowSelection = useMemo(() => {
|
||||
const sel: Record<string, boolean> = {};
|
||||
for (const id of selectedUserIds) sel[id] = true;
|
||||
return sel;
|
||||
}, [selectedUserIds]);
|
||||
|
||||
const handleRemoveMember = useCallback((userId: string) => {
|
||||
setSelectedUserIds((prev) => prev.filter((id) => id !== userId));
|
||||
}, []);
|
||||
|
||||
const memberColumns = useMemo(
|
||||
() => [
|
||||
...baseColumns,
|
||||
tc.actions({
|
||||
showSorting: false,
|
||||
showColumnVisibility: false,
|
||||
cell: (row: MemberRow) => (
|
||||
<IconButton
|
||||
icon={SvgMinusCircle}
|
||||
tertiary
|
||||
onClick={(e) => {
|
||||
e.stopPropagation();
|
||||
handleRemoveMember(row.id ?? row.email);
|
||||
}}
|
||||
/>
|
||||
),
|
||||
}),
|
||||
],
|
||||
[handleRemoveMember]
|
||||
);
|
||||
|
||||
// IDs of members not visible in the add-mode table (e.g. inactive users).
|
||||
// We preserve these so they aren't silently removed when the table fires
|
||||
// onSelectionChange with only the visible rows.
|
||||
const hiddenMemberIds = useMemo(() => {
|
||||
const visibleIds = new Set(allRows.map((r) => r.id ?? r.email));
|
||||
return selectedUserIds.filter((id) => !visibleIds.has(id));
|
||||
}, [allRows, selectedUserIds]);
|
||||
|
||||
// Guard onSelectionChange: ignore updates until the form is fully initialized.
|
||||
// Without this, TanStack fires onSelectionChange before all rows are loaded,
|
||||
// which overwrites selectedUserIds with a partial set.
|
||||
const handleSelectionChange = useCallback(
|
||||
(ids: string[]) => {
|
||||
if (!initialized) return;
|
||||
setSelectedUserIds([...ids, ...hiddenMemberIds]);
|
||||
},
|
||||
[initialized, hiddenMemberIds]
|
||||
);
|
||||
|
||||
async function handleSave() {
|
||||
if (isSubmittingRef.current) return;
|
||||
|
||||
const trimmed = groupName.trim();
|
||||
if (!trimmed) {
|
||||
toast.error("Group name is required");
|
||||
return;
|
||||
}
|
||||
|
||||
// Re-fetch group to check sync status before saving
|
||||
const freshGroups = await fetch(USER_GROUP_URL).then((r) => r.json());
|
||||
const freshGroup = freshGroups.find((g: UserGroup) => g.id === groupId);
|
||||
if (freshGroup && !freshGroup.is_up_to_date) {
|
||||
toast.error(
|
||||
"This group is currently syncing. Please wait a moment and try again."
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
isSubmittingRef.current = true;
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
// Rename if name changed
|
||||
if (group && trimmed !== group.name) {
|
||||
await renameGroup(group.id, trimmed);
|
||||
}
|
||||
|
||||
// Update members and cc_pairs
|
||||
await updateGroup(groupId, selectedUserIds, selectedCcPairIds);
|
||||
|
||||
// Update agent sharing (add/remove this group from changed agents)
|
||||
await updateAgentGroupSharing(
|
||||
groupId,
|
||||
initialAgentIdsRef.current,
|
||||
selectedAgentIds
|
||||
);
|
||||
|
||||
// Update document set sharing (add/remove this group from changed doc sets)
|
||||
await updateDocSetGroupSharing(
|
||||
groupId,
|
||||
initialDocSetIdsRef.current,
|
||||
selectedDocSetIds
|
||||
);
|
||||
|
||||
// Save token rate limits (create/update/delete)
|
||||
await saveTokenLimits(groupId, tokenLimits, tokenRateLimits ?? []);
|
||||
|
||||
// Update refs so subsequent saves diff correctly
|
||||
initialAgentIdsRef.current = selectedAgentIds;
|
||||
initialDocSetIdsRef.current = selectedDocSetIds;
|
||||
|
||||
mutate(USER_GROUP_URL);
|
||||
mutate(`/api/admin/token-rate-limits/user-group/${groupId}`);
|
||||
toast.success(`Group "${trimmed}" updated`);
|
||||
router.push("/admin/groups");
|
||||
} catch (e) {
|
||||
toast.error(e instanceof Error ? e.message : "Failed to update group");
|
||||
} finally {
|
||||
isSubmittingRef.current = false;
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
async function handleDelete() {
|
||||
setIsDeleting(true);
|
||||
try {
|
||||
await deleteGroup(groupId);
|
||||
mutate(USER_GROUP_URL);
|
||||
toast.success(`Group "${group?.name}" deleted`);
|
||||
router.push("/admin/groups");
|
||||
} catch (e) {
|
||||
toast.error(e instanceof Error ? e.message : "Failed to delete group");
|
||||
} finally {
|
||||
setIsDeleting(false);
|
||||
setShowDeleteModal(false);
|
||||
}
|
||||
}
|
||||
|
||||
// 404 state
|
||||
if (!isLoading && !error && !group) {
|
||||
return (
|
||||
<SettingsLayouts.Root width="sm">
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgUsers}
|
||||
title="Group Not Found"
|
||||
separator
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="Group not found"
|
||||
description="This group doesn't exist or may have been deleted."
|
||||
/>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
const headerActions = (
|
||||
<Section flexDirection="row" gap={0.5} width="auto" height="auto">
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
onClick={() => router.push("/admin/groups")}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleSave}
|
||||
disabled={!groupName.trim() || isSubmitting || isSyncing}
|
||||
tooltip={
|
||||
isSyncing
|
||||
? "Document embeddings are being updated due to recent changes to this group."
|
||||
: undefined
|
||||
}
|
||||
>
|
||||
{isSubmitting ? "Saving..." : isSyncing ? "Syncing..." : "Save Changes"}
|
||||
</Button>
|
||||
</Section>
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<SettingsLayouts.Root width="sm">
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgUsers}
|
||||
title="Edit Group"
|
||||
separator
|
||||
rightChildren={headerActions}
|
||||
/>
|
||||
|
||||
<SettingsLayouts.Body>
|
||||
{isLoading && <SimpleLoader />}
|
||||
|
||||
{error && (
|
||||
<Text as="p" secondaryBody text03>
|
||||
Failed to load group data.
|
||||
</Text>
|
||||
)}
|
||||
|
||||
{!isLoading && !error && group && (
|
||||
<>
|
||||
{/* Group Name */}
|
||||
<Section
|
||||
gap={0.5}
|
||||
height="auto"
|
||||
alignItems="stretch"
|
||||
justifyContent="start"
|
||||
>
|
||||
<Text mainUiBody text04>
|
||||
Group Name
|
||||
</Text>
|
||||
<InputTypeIn
|
||||
placeholder="Name your group"
|
||||
value={groupName}
|
||||
onChange={(e) => setGroupName(e.target.value)}
|
||||
/>
|
||||
</Section>
|
||||
|
||||
<Separator noPadding />
|
||||
|
||||
{/* Members table */}
|
||||
<Section
|
||||
gap={0.75}
|
||||
height="auto"
|
||||
alignItems="stretch"
|
||||
justifyContent="start"
|
||||
>
|
||||
<Section
|
||||
flexDirection="row"
|
||||
gap={0.5}
|
||||
height="auto"
|
||||
alignItems="center"
|
||||
justifyContent="start"
|
||||
>
|
||||
<InputTypeIn
|
||||
value={searchTerm}
|
||||
onChange={(e) => setSearchTerm(e.target.value)}
|
||||
placeholder={
|
||||
isAddingMembers
|
||||
? "Search users and accounts..."
|
||||
: "Search members..."
|
||||
}
|
||||
leftSearchIcon
|
||||
className="flex-1"
|
||||
/>
|
||||
{isAddingMembers ? (
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => setIsAddingMembers(false)}
|
||||
>
|
||||
Done
|
||||
</Button>
|
||||
) : (
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
icon={SvgPlusCircle}
|
||||
onClick={() => setIsAddingMembers(true)}
|
||||
>
|
||||
Add
|
||||
</Button>
|
||||
)}
|
||||
</Section>
|
||||
|
||||
{isAddingMembers ? (
|
||||
<Table
|
||||
key="add-members"
|
||||
data={allRows as MemberRow[]}
|
||||
columns={addModeColumns}
|
||||
getRowId={(row) => row.id ?? row.email}
|
||||
pageSize={PAGE_SIZE}
|
||||
searchTerm={searchTerm}
|
||||
selectionBehavior="multi-select"
|
||||
initialRowSelection={currentRowSelection}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
footer={{}}
|
||||
emptyState={
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="No users found"
|
||||
description="No users match your search."
|
||||
/>
|
||||
}
|
||||
/>
|
||||
) : (
|
||||
<Table
|
||||
data={memberRows}
|
||||
columns={memberColumns}
|
||||
getRowId={(row) => row.id ?? row.email}
|
||||
pageSize={PAGE_SIZE}
|
||||
searchTerm={searchTerm}
|
||||
footer={{}}
|
||||
emptyState={
|
||||
<IllustrationContent
|
||||
illustration={SvgNoResult}
|
||||
title="No members"
|
||||
description="Add members to this group."
|
||||
/>
|
||||
}
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
|
||||
<SharedGroupResources
|
||||
selectedCcPairIds={selectedCcPairIds}
|
||||
onCcPairIdsChange={setSelectedCcPairIds}
|
||||
selectedDocSetIds={selectedDocSetIds}
|
||||
onDocSetIdsChange={setSelectedDocSetIds}
|
||||
selectedAgentIds={selectedAgentIds}
|
||||
onAgentIdsChange={setSelectedAgentIds}
|
||||
/>
|
||||
|
||||
<TokenLimitSection
|
||||
limits={tokenLimits}
|
||||
onLimitsChange={setTokenLimits}
|
||||
/>
|
||||
|
||||
{/* Delete This Group */}
|
||||
<Card>
|
||||
<InputLayouts.Horizontal
|
||||
title="Delete This Group"
|
||||
description="Members will lose access to any resources shared with this group."
|
||||
center
|
||||
nonInteractive
|
||||
>
|
||||
<Button
|
||||
variant="danger"
|
||||
prominence="secondary"
|
||||
icon={SvgTrash}
|
||||
onClick={() => setShowDeleteModal(true)}
|
||||
>
|
||||
Delete Group
|
||||
</Button>
|
||||
</InputLayouts.Horizontal>
|
||||
</Card>
|
||||
</>
|
||||
)}
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
{showDeleteModal && (
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgTrash}
|
||||
title="Delete Group"
|
||||
onClose={() => setShowDeleteModal(false)}
|
||||
submit={
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={handleDelete}
|
||||
disabled={isDeleting}
|
||||
>
|
||||
{isDeleting ? "Deleting..." : "Delete"}
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Text as="p" text03>
|
||||
Members of group{" "}
|
||||
<Text as="span" text05>
|
||||
{group?.name}
|
||||
</Text>{" "}
|
||||
will lose access to any resources shared with this group, unless
|
||||
they have been granted access directly. Deletion cannot be undone.
|
||||
</Text>
|
||||
</ConfirmationModalLayout>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
export default EditGroupPage;
|
||||
@@ -1,5 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { useRouter } from "next/navigation";
|
||||
import type { UserGroup } from "@/lib/types";
|
||||
import { SvgChevronRight, SvgUserManage, SvgUsers } from "@opal/icons";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
@@ -21,6 +22,7 @@ interface GroupCardProps {
|
||||
}
|
||||
|
||||
function GroupCard({ group }: GroupCardProps) {
|
||||
const router = useRouter();
|
||||
const { mutate } = useSWRConfig();
|
||||
const builtIn = isBuiltInGroup(group);
|
||||
const isAdmin = group.name === "Admin";
|
||||
@@ -39,7 +41,7 @@ function GroupCard({ group }: GroupCardProps) {
|
||||
}
|
||||
|
||||
return (
|
||||
<Card padding={0.5}>
|
||||
<Card padding={0.5} data-card>
|
||||
<ContentAction
|
||||
icon={isAdmin ? SvgUserManage : SvgUsers}
|
||||
title={group.name}
|
||||
@@ -53,13 +55,17 @@ function GroupCard({ group }: GroupCardProps) {
|
||||
<Section flexDirection="row" alignItems="start" gap={0}>
|
||||
<div className="py-1">
|
||||
<Text mainUiBody text03>
|
||||
{formatMemberCount(group.users.length)}
|
||||
{formatMemberCount(
|
||||
group.users.filter((u) => u.is_active).length
|
||||
)}
|
||||
</Text>
|
||||
</div>
|
||||
<Button
|
||||
icon={SvgChevronRight}
|
||||
prominence="tertiary"
|
||||
tooltip="View group"
|
||||
aria-label="View group"
|
||||
onClick={() => router.push(`/admin/groups/${group.id}`)}
|
||||
/>
|
||||
</Section>
|
||||
}
|
||||
|
||||
@@ -28,28 +28,32 @@ function ResourceContent({
|
||||
}: ResourceContentProps) {
|
||||
return (
|
||||
<div className="flex flex-1 gap-0.5 items-start p-1.5 rounded-08 bg-background-tint-01 min-w-[240px] max-w-[302px]">
|
||||
<div className="flex flex-1 gap-1 p-0.5 items-start min-w-0">
|
||||
<div className="flex flex-1 gap-1 p-0.5 items-center min-w-0">
|
||||
{leftContent ? (
|
||||
<>
|
||||
{leftContent}
|
||||
<div className="flex-1 min-w-0">
|
||||
<Content
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
</>
|
||||
) : (
|
||||
<div className="flex-1 min-w-0">
|
||||
<Content
|
||||
icon={icon}
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</>
|
||||
) : (
|
||||
<Content
|
||||
icon={icon}
|
||||
title={title}
|
||||
description={description}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
{infoContent}
|
||||
</div>
|
||||
{infoContent}
|
||||
<IconButton small icon={SvgX} onClick={onRemove} className="shrink-0" />
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -23,11 +23,7 @@ function ResourcePopover({
|
||||
|
||||
return (
|
||||
<Popover open={open} onOpenChange={setOpen}>
|
||||
<Popover.Trigger
|
||||
onClick={(e) => {
|
||||
e.preventDefault();
|
||||
}}
|
||||
>
|
||||
<Popover.Anchor>
|
||||
<InputTypeIn
|
||||
placeholder={placeholder}
|
||||
value={searchValue}
|
||||
@@ -37,7 +33,7 @@ function ResourcePopover({
|
||||
}}
|
||||
onFocus={() => setOpen(true)}
|
||||
/>
|
||||
</Popover.Trigger>
|
||||
</Popover.Anchor>
|
||||
<Popover.Content
|
||||
width="trigger"
|
||||
align="start"
|
||||
|
||||
@@ -40,7 +40,11 @@ function SharedBadge() {
|
||||
);
|
||||
}
|
||||
|
||||
function SourceIconStack({ sources }: { sources: { source: ValidSources }[] }) {
|
||||
interface SourceIconStackProps {
|
||||
sources: { source: ValidSources }[];
|
||||
}
|
||||
|
||||
function SourceIconStack({ sources }: SourceIconStackProps) {
|
||||
if (sources.length === 0) return null;
|
||||
|
||||
const unique = Array.from(
|
||||
@@ -51,16 +55,17 @@ function SourceIconStack({ sources }: { sources: { source: ValidSources }[] }) {
|
||||
<Section
|
||||
flexDirection="row"
|
||||
alignItems="center"
|
||||
width="auto"
|
||||
height="auto"
|
||||
gap={0}
|
||||
className="shrink-0 px-0.5"
|
||||
className="shrink-0 p-0.5"
|
||||
>
|
||||
{unique.map((s, i) => {
|
||||
const Icon = getSourceMetadata(s.source).icon;
|
||||
return (
|
||||
<div
|
||||
key={s.source}
|
||||
className="flex items-center justify-center size-4 rounded-04 bg-background-tint-00 border border-border-01 overflow-hidden"
|
||||
className="flex items-center justify-center size-4 rounded-04 bg-background-tint-00 border border-border-01 overflow-hidden [&_img]:!size-4 [&_img]:!m-0 [&_svg]:size-4"
|
||||
style={{ zIndex: unique.length - i, marginLeft: i > 0 ? -6 : 0 }}
|
||||
>
|
||||
<Icon />
|
||||
@@ -316,7 +321,7 @@ function SharedGroupResources({
|
||||
key={`d-${ds.id}`}
|
||||
icon={SvgFiles}
|
||||
title={ds.name}
|
||||
description={`Document Set - ${ds.cc_pair_summaries.length} Sources`}
|
||||
description="Document Set"
|
||||
infoContent={
|
||||
<SourceIconStack sources={ds.cc_pair_summaries} />
|
||||
}
|
||||
|
||||
@@ -31,7 +31,10 @@ function GroupsPage() {
|
||||
{/* This is the sticky header for the groups page. It is used to display
|
||||
* the groups page title and search input when scrolling down.
|
||||
*/}
|
||||
<div className="sticky top-0 z-settings-header bg-background-tint-01">
|
||||
<div
|
||||
className="sticky top-0 z-settings-header bg-background-tint-01"
|
||||
data-testid="groups-page-heading"
|
||||
>
|
||||
<SettingsLayouts.Header icon={SvgUsers} title="Groups" separator />
|
||||
|
||||
<Section flexDirection="row" padding={1}>
|
||||
|
||||
@@ -13,3 +13,10 @@ export interface ApiKeyDescriptor {
|
||||
export interface MemberRow extends UserRow {
|
||||
api_key_display?: string;
|
||||
}
|
||||
|
||||
export interface TokenRateLimitDisplay {
|
||||
token_id: number;
|
||||
enabled: boolean;
|
||||
token_budget: number;
|
||||
period_hours: number;
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ function renderAccountTypeColumn(_value: unknown, row: MemberRow) {
|
||||
// Columns
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
const tc = createTableColumns<MemberRow>();
|
||||
export const tc = createTableColumns<MemberRow>();
|
||||
|
||||
export const baseColumns = [
|
||||
tc.qualifier(),
|
||||
|
||||
@@ -40,6 +40,27 @@ async function createGroup(
|
||||
return group.id;
|
||||
}
|
||||
|
||||
async function updateGroup(
|
||||
groupId: number,
|
||||
userIds: string[],
|
||||
ccPairIds: number[]
|
||||
): Promise<void> {
|
||||
const res = await fetch(`${USER_GROUP_URL}/${groupId}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({
|
||||
user_ids: userIds,
|
||||
cc_pair_ids: ccPairIds,
|
||||
}),
|
||||
});
|
||||
if (!res.ok) {
|
||||
const detail = await res.json().catch(() => null);
|
||||
throw new Error(
|
||||
detail?.detail ?? `Failed to update group: ${res.statusText}`
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
async function deleteGroup(groupId: number): Promise<void> {
|
||||
const res = await fetch(`${USER_GROUP_URL}/${groupId}`, {
|
||||
method: "DELETE",
|
||||
@@ -262,6 +283,7 @@ export {
|
||||
USER_GROUP_URL,
|
||||
renameGroup,
|
||||
createGroup,
|
||||
updateGroup,
|
||||
deleteGroup,
|
||||
updateAgentGroupSharing,
|
||||
updateDocSetGroupSharing,
|
||||
|
||||
@@ -1,334 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgHookNodes } from "@opal/icons";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { createHook, updateHook } from "@/refresh-pages/admin/HooksPage/svc";
|
||||
import type {
|
||||
HookFailStrategy,
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface HookFormModalProps {
|
||||
open: boolean;
|
||||
onOpenChange: (open: boolean) => void;
|
||||
/** When provided, the modal is in edit mode for this hook. */
|
||||
hook?: HookResponse;
|
||||
/** When provided (create mode), the hook point is pre-selected and locked. */
|
||||
spec?: HookPointMeta;
|
||||
onSuccess: (hook: HookResponse) => void;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FormState {
|
||||
name: string;
|
||||
endpoint_url: string;
|
||||
api_key: string;
|
||||
fail_strategy: HookFailStrategy;
|
||||
timeout_seconds: string;
|
||||
}
|
||||
|
||||
function buildInitialState(
|
||||
hook: HookResponse | undefined,
|
||||
spec: HookPointMeta | undefined
|
||||
): FormState {
|
||||
if (hook) {
|
||||
return {
|
||||
name: hook.name,
|
||||
endpoint_url: hook.endpoint_url ?? "",
|
||||
api_key: hook.api_key_masked ?? "",
|
||||
fail_strategy: hook.fail_strategy,
|
||||
timeout_seconds: String(hook.timeout_seconds),
|
||||
};
|
||||
}
|
||||
return {
|
||||
name: "",
|
||||
endpoint_url: "",
|
||||
api_key: "",
|
||||
fail_strategy: spec?.default_fail_strategy ?? "soft",
|
||||
timeout_seconds: spec ? String(spec.default_timeout_seconds) : "5",
|
||||
};
|
||||
}
|
||||
|
||||
const SOFT_DESCRIPTION =
|
||||
"If the endpoint returns an error, Onyx logs it and continues the pipeline as normal, ignoring the hook result.";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-components
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface FieldProps {
|
||||
label: string;
|
||||
required?: boolean;
|
||||
description?: string;
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
function Field({ label, required, description, children }: FieldProps) {
|
||||
return (
|
||||
<div className="flex flex-col gap-1 w-full">
|
||||
<span className="font-main-ui-action text-text-04 px-[0.125rem]">
|
||||
{label}
|
||||
{required && <span className="text-status-error-05 ml-0.5">*</span>}
|
||||
</span>
|
||||
{children}
|
||||
{description && (
|
||||
<Text secondaryBody text03>
|
||||
{description}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function HookFormModal({
|
||||
open,
|
||||
onOpenChange,
|
||||
hook,
|
||||
spec,
|
||||
onSuccess,
|
||||
}: HookFormModalProps) {
|
||||
const isEdit = !!hook;
|
||||
const [form, setForm] = useState<FormState>(() =>
|
||||
buildInitialState(hook, spec)
|
||||
);
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
function handleOpenChange(next: boolean) {
|
||||
if (!next) {
|
||||
setTimeout(() => {
|
||||
setForm(buildInitialState(hook, spec));
|
||||
setIsSubmitting(false);
|
||||
}, 200);
|
||||
}
|
||||
onOpenChange(next);
|
||||
}
|
||||
|
||||
function set<K extends keyof FormState>(key: K, value: FormState[K]) {
|
||||
setForm((prev) => ({ ...prev, [key]: value }));
|
||||
}
|
||||
|
||||
const timeoutNum = parseFloat(form.timeout_seconds);
|
||||
const isValid =
|
||||
form.name.trim().length > 0 &&
|
||||
form.endpoint_url.trim().length > 0 &&
|
||||
!isNaN(timeoutNum) &&
|
||||
timeoutNum > 0;
|
||||
|
||||
async function handleSubmit() {
|
||||
if (!isValid) return;
|
||||
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
let result: HookResponse;
|
||||
if (isEdit && hook) {
|
||||
const req: Record<string, unknown> = {};
|
||||
if (form.name !== hook.name) req.name = form.name;
|
||||
if (form.endpoint_url !== (hook.endpoint_url ?? ""))
|
||||
req.endpoint_url = form.endpoint_url;
|
||||
if (form.fail_strategy !== hook.fail_strategy)
|
||||
req.fail_strategy = form.fail_strategy;
|
||||
if (timeoutNum !== hook.timeout_seconds)
|
||||
req.timeout_seconds = timeoutNum;
|
||||
const maskedPlaceholder = hook.api_key_masked ?? "";
|
||||
if (form.api_key !== maskedPlaceholder) {
|
||||
req.api_key = form.api_key || null;
|
||||
}
|
||||
if (Object.keys(req).length === 0) {
|
||||
handleOpenChange(false);
|
||||
return;
|
||||
}
|
||||
result = await updateHook(hook.id, req);
|
||||
} else {
|
||||
const hookPoint = spec!.hook_point;
|
||||
result = await createHook({
|
||||
name: form.name,
|
||||
hook_point: hookPoint,
|
||||
endpoint_url: form.endpoint_url,
|
||||
...(form.api_key ? { api_key: form.api_key } : {}),
|
||||
fail_strategy: form.fail_strategy,
|
||||
timeout_seconds: timeoutNum,
|
||||
});
|
||||
}
|
||||
toast.success(isEdit ? "Hook updated." : "Hook created.");
|
||||
onSuccess(result);
|
||||
handleOpenChange(false);
|
||||
} catch (err) {
|
||||
toast.error(err instanceof Error ? err.message : "Something went wrong.");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}
|
||||
|
||||
const hookPointDisplayName = isEdit
|
||||
? hook!.hook_point
|
||||
: spec?.display_name ?? spec?.hook_point ?? "";
|
||||
const hookPointDescription = isEdit ? undefined : spec?.description;
|
||||
const docsUrl = isEdit ? undefined : spec?.docs_url;
|
||||
|
||||
const failStrategyDescription =
|
||||
form.fail_strategy === "soft"
|
||||
? SOFT_DESCRIPTION
|
||||
: spec?.fail_hard_description ?? undefined;
|
||||
|
||||
return (
|
||||
<Modal open={open} onOpenChange={handleOpenChange}>
|
||||
<Modal.Content width="md" height="fit">
|
||||
<Modal.Header
|
||||
icon={SvgHookNodes}
|
||||
title="Set Up Hook Extension"
|
||||
description="Connect a external API endpoints to extend the hook point."
|
||||
onClose={() => handleOpenChange(false)}
|
||||
/>
|
||||
|
||||
<Modal.Body>
|
||||
{/* Hook point section header */}
|
||||
<div className="flex flex-row items-start justify-between gap-1">
|
||||
<div className="flex flex-col flex-1 min-w-0">
|
||||
<span className="font-main-ui-action text-text-04 px-[0.125rem]">
|
||||
{hookPointDisplayName}
|
||||
</span>
|
||||
{hookPointDescription && (
|
||||
<span className="font-secondary-body text-text-03 px-[0.125rem]">
|
||||
{hookPointDescription}
|
||||
</span>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex flex-col items-end shrink-0 gap-1">
|
||||
<div className="flex items-center gap-1">
|
||||
<SvgHookNodes
|
||||
style={{ width: "1rem", height: "1rem" }}
|
||||
className="text-text-03 shrink-0 p-0.5"
|
||||
/>
|
||||
<span className="font-secondary-body text-text-03">
|
||||
Hook Point
|
||||
</span>
|
||||
</div>
|
||||
{docsUrl && (
|
||||
<a
|
||||
href={docsUrl}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="font-secondary-body text-text-03 underline"
|
||||
>
|
||||
Documentation
|
||||
</a>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Field label="Display Name" required>
|
||||
<InputTypeIn
|
||||
value={form.name}
|
||||
onChange={(e) => set("name", e.target.value)}
|
||||
placeholder="Name your extension at this hook point"
|
||||
variant={isSubmitting ? "disabled" : undefined}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field label="Fail Strategy" description={failStrategyDescription}>
|
||||
<InputSelect
|
||||
value={form.fail_strategy}
|
||||
onValueChange={(v) => set("fail_strategy", v as HookFailStrategy)}
|
||||
disabled={isSubmitting}
|
||||
>
|
||||
<InputSelect.Trigger placeholder="Select strategy" />
|
||||
<InputSelect.Content>
|
||||
<InputSelect.Item value="soft">
|
||||
Log Error and Continue (Default)
|
||||
</InputSelect.Item>
|
||||
<InputSelect.Item value="hard">
|
||||
Block Pipeline on Failure
|
||||
</InputSelect.Item>
|
||||
</InputSelect.Content>
|
||||
</InputSelect>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
label="Timeout (seconds)"
|
||||
required
|
||||
description="Maximum time Onyx will wait for the endpoint to respond before applying the fail strategy."
|
||||
>
|
||||
<InputTypeIn
|
||||
type="number"
|
||||
value={form.timeout_seconds}
|
||||
onChange={(e) => set("timeout_seconds", e.target.value)}
|
||||
placeholder="5"
|
||||
variant={isSubmitting ? "disabled" : undefined}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
label="External API Endpoint URL"
|
||||
required
|
||||
description="Only connect to servers you trust. You are responsible for actions taken and data shared with this connection."
|
||||
>
|
||||
<InputTypeIn
|
||||
value={form.endpoint_url}
|
||||
onChange={(e) => set("endpoint_url", e.target.value)}
|
||||
placeholder="https://your-api-endpoint.com"
|
||||
variant={isSubmitting ? "disabled" : undefined}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Field
|
||||
label="API Key"
|
||||
description="Onyx will use this key to authenticate with your API endpoint."
|
||||
>
|
||||
<PasswordInputTypeIn
|
||||
value={form.api_key}
|
||||
onChange={(e) => set("api_key", e.target.value)}
|
||||
placeholder={
|
||||
isEdit && hook?.api_key_masked
|
||||
? "Leave blank to keep current key"
|
||||
: undefined
|
||||
}
|
||||
disabled={isSubmitting}
|
||||
/>
|
||||
</Field>
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
cancel={
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
onClick={() => handleOpenChange(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
submit={
|
||||
<Disabled disabled={isSubmitting || !isValid}>
|
||||
<Button onClick={handleSubmit}>
|
||||
{isEdit ? "Save" : "Connect"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -3,34 +3,20 @@
|
||||
import { useState, useEffect } from "react";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useHookSpecs } from "@/hooks/useHookSpecs";
|
||||
import { useHooks } from "@/hooks/useHooks";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import InputSearch from "@/refresh-components/inputs/InputSearch";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SvgArrowExchange,
|
||||
SvgBubbleText,
|
||||
SvgCheckCircle,
|
||||
SvgExternalLink,
|
||||
SvgFileBroadcast,
|
||||
SvgHookNodes,
|
||||
SvgRefreshCw,
|
||||
SvgSettings,
|
||||
SvgXCircle,
|
||||
} from "@opal/icons";
|
||||
import { IconFunctionComponent } from "@opal/types";
|
||||
import HookFormModal from "@/refresh-pages/admin/HooksPage/HookFormModal";
|
||||
import type {
|
||||
HookPointMeta,
|
||||
HookResponse,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
import {
|
||||
activateHook,
|
||||
deactivateHook,
|
||||
} from "@/refresh-pages/admin/HooksPage/svc";
|
||||
|
||||
const HOOK_POINT_ICONS: Record<string, IconFunctionComponent> = {
|
||||
document_ingestion: SvgFileBroadcast,
|
||||
@@ -41,152 +27,22 @@ function getHookPointIcon(hookPoint: string): IconFunctionComponent {
|
||||
return HOOK_POINT_ICONS[hookPoint] ?? SvgHookNodes;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sub-component: connected hook card
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface ConnectedHookCardProps {
|
||||
hook: HookResponse;
|
||||
spec: HookPointMeta | undefined;
|
||||
onEdit: () => void;
|
||||
onDeleted: () => void;
|
||||
onToggled: (updated: HookResponse) => void;
|
||||
}
|
||||
|
||||
function ConnectedHookCard({
|
||||
hook,
|
||||
spec,
|
||||
onEdit,
|
||||
onDeleted: _onDeleted,
|
||||
onToggled,
|
||||
}: ConnectedHookCardProps) {
|
||||
const [isBusy, setIsBusy] = useState(false);
|
||||
|
||||
async function handleToggle() {
|
||||
setIsBusy(true);
|
||||
try {
|
||||
const updated = hook.is_active
|
||||
? await deactivateHook(hook.id)
|
||||
: await activateHook(hook.id);
|
||||
onToggled(updated);
|
||||
} catch (err) {
|
||||
toast.error(
|
||||
err instanceof Error ? err.message : "Failed to update hook status."
|
||||
);
|
||||
} finally {
|
||||
setIsBusy(false);
|
||||
}
|
||||
}
|
||||
|
||||
const HookIcon = getHookPointIcon(hook.hook_point);
|
||||
|
||||
return (
|
||||
<Card variant="primary" padding={0.5} gap={0}>
|
||||
<div className="flex flex-row items-start w-full">
|
||||
{/* Left: manually replicate ContentMd main-content layout so the docs
|
||||
link sits as a natural third line with zero artificial gap. */}
|
||||
<div className="flex flex-row flex-1 min-w-0 items-start gap-1 p-2">
|
||||
<div className="shrink-0 p-0.5 text-text-04">
|
||||
<HookIcon style={{ width: "1rem", height: "1rem" }} />
|
||||
</div>
|
||||
<div className="flex flex-col items-start min-w-0 flex-1">
|
||||
<span
|
||||
className="font-main-ui-action text-text-04"
|
||||
style={{ height: "1.25rem" }}
|
||||
>
|
||||
{hook.name}
|
||||
</span>
|
||||
{/* matches opal-content-md-description: font-secondary-body px-[0.125rem] */}
|
||||
<div className="font-secondary-body text-text-03">
|
||||
{`Hook Point: ${spec?.display_name ?? hook.hook_point}`}
|
||||
</div>
|
||||
{spec?.docs_url && (
|
||||
<div className="font-secondary-body text-text-03">
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit"
|
||||
>
|
||||
<span className="underline">Documentation</span>
|
||||
<SvgExternalLink size={12} className="shrink-0" />
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Right: status + actions, top-aligned */}
|
||||
<div className="flex flex-col items-end shrink-0">
|
||||
<div className="flex items-center gap-1 p-2">
|
||||
<span className="font-main-ui-action text-text-03">
|
||||
{hook.is_active ? "Connected" : "Inactive"}
|
||||
</span>
|
||||
{hook.is_active ? (
|
||||
<SvgCheckCircle size={16} className="text-status-success-05" />
|
||||
) : (
|
||||
<SvgXCircle size={16} className="text-text-03" />
|
||||
)}
|
||||
</div>
|
||||
<div className="flex items-center gap-0.5 pl-2 pr-0.5">
|
||||
<Disabled disabled={isBusy}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
icon={SvgRefreshCw}
|
||||
onClick={handleToggle}
|
||||
tooltip={hook.is_active ? "Deactivate" : "Activate"}
|
||||
aria-label={
|
||||
hook.is_active ? "Deactivate hook" : "Activate hook"
|
||||
}
|
||||
/>
|
||||
</Disabled>
|
||||
<Disabled disabled={isBusy}>
|
||||
<Button
|
||||
prominence="tertiary"
|
||||
size="md"
|
||||
icon={SvgSettings}
|
||||
onClick={onEdit}
|
||||
aria-label="Configure hook"
|
||||
/>
|
||||
</Disabled>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Main component
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
export default function HooksContent() {
|
||||
const [search, setSearch] = useState("");
|
||||
const [connectSpec, setConnectSpec] = useState<HookPointMeta | null>(null);
|
||||
const [editHook, setEditHook] = useState<HookResponse | null>(null);
|
||||
|
||||
const { specs, isLoading: specsLoading, error: specsError } = useHookSpecs();
|
||||
const {
|
||||
hooks,
|
||||
isLoading: hooksLoading,
|
||||
error: hooksError,
|
||||
mutate,
|
||||
} = useHooks();
|
||||
const { specs, isLoading, error } = useHookSpecs();
|
||||
|
||||
useEffect(() => {
|
||||
if (specsError) toast.error("Failed to load hook specifications.");
|
||||
}, [specsError]);
|
||||
if (error) {
|
||||
toast.error("Failed to load hook specifications.");
|
||||
}
|
||||
}, [error]);
|
||||
|
||||
useEffect(() => {
|
||||
if (hooksError) toast.error("Failed to load hooks.");
|
||||
}, [hooksError]);
|
||||
|
||||
if (specsLoading || hooksLoading) {
|
||||
if (isLoading) {
|
||||
return <SimpleLoader />;
|
||||
}
|
||||
|
||||
if (specsError) {
|
||||
if (error) {
|
||||
return (
|
||||
<Text text03 secondaryBody>
|
||||
Failed to load hook specifications. Please refresh the page.
|
||||
@@ -194,179 +50,68 @@ export default function HooksContent() {
|
||||
);
|
||||
}
|
||||
|
||||
const hooksByPoint: Record<string, HookResponse[]> = {};
|
||||
for (const hook of hooks ?? []) {
|
||||
(hooksByPoint[hook.hook_point] ??= []).push(hook);
|
||||
}
|
||||
|
||||
const searchLower = search.toLowerCase();
|
||||
|
||||
// Connected hooks sorted alphabetically by hook name
|
||||
const connectedHooks = (hooks ?? [])
|
||||
.filter(
|
||||
(hook) =>
|
||||
!searchLower ||
|
||||
hook.name.toLowerCase().includes(searchLower) ||
|
||||
(hooksByPoint[hook.hook_point] &&
|
||||
specs
|
||||
?.find((s) => s.hook_point === hook.hook_point)
|
||||
?.display_name.toLowerCase()
|
||||
.includes(searchLower))
|
||||
)
|
||||
.sort((a, b) => a.name.localeCompare(b.name));
|
||||
|
||||
// Unconnected hook point specs sorted alphabetically
|
||||
const unconnectedSpecs = (specs ?? [])
|
||||
.filter(
|
||||
(spec) =>
|
||||
(hooksByPoint[spec.hook_point]?.length ?? 0) === 0 &&
|
||||
(!searchLower ||
|
||||
spec.display_name.toLowerCase().includes(searchLower) ||
|
||||
spec.description.toLowerCase().includes(searchLower))
|
||||
)
|
||||
.sort((a, b) => a.display_name.localeCompare(b.display_name));
|
||||
|
||||
function handleHookSuccess(updated: HookResponse) {
|
||||
mutate((prev) => {
|
||||
if (!prev) return [updated];
|
||||
const idx = prev.findIndex((h) => h.id === updated.id);
|
||||
if (idx >= 0) {
|
||||
const next = [...prev];
|
||||
next[idx] = updated;
|
||||
return next;
|
||||
}
|
||||
return [...prev, updated];
|
||||
});
|
||||
}
|
||||
|
||||
function handleHookDeleted(id: number) {
|
||||
mutate((prev) => prev?.filter((h) => h.id !== id));
|
||||
}
|
||||
|
||||
const connectSpec_ =
|
||||
connectSpec ??
|
||||
(editHook
|
||||
? specs?.find((s) => s.hook_point === editHook.hook_point)
|
||||
: undefined);
|
||||
const filtered = (specs ?? []).filter(
|
||||
(spec) =>
|
||||
spec.display_name.toLowerCase().includes(search.toLowerCase()) ||
|
||||
spec.description.toLowerCase().includes(search.toLowerCase())
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="flex flex-col gap-6">
|
||||
<InputSearch
|
||||
placeholder="Search hooks..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
/>
|
||||
<div className="flex flex-col gap-6">
|
||||
<InputSearch
|
||||
placeholder="Search hooks..."
|
||||
value={search}
|
||||
onChange={(e) => setSearch(e.target.value)}
|
||||
/>
|
||||
|
||||
<div className="flex flex-col gap-2">
|
||||
{connectedHooks.length === 0 && unconnectedSpecs.length === 0 ? (
|
||||
<Text text03 secondaryBody>
|
||||
{search
|
||||
? "No hooks match your search."
|
||||
: "No hook points are available."}
|
||||
</Text>
|
||||
) : (
|
||||
<>
|
||||
{connectedHooks.map((hook) => {
|
||||
const spec = specs?.find(
|
||||
(s) => s.hook_point === hook.hook_point
|
||||
);
|
||||
return (
|
||||
<ConnectedHookCard
|
||||
key={hook.id}
|
||||
hook={hook}
|
||||
spec={spec}
|
||||
onEdit={() => setEditHook(hook)}
|
||||
onDeleted={() => handleHookDeleted(hook.id)}
|
||||
onToggled={handleHookSuccess}
|
||||
/>
|
||||
);
|
||||
})}
|
||||
{unconnectedSpecs.map((spec) => {
|
||||
const UnconnectedIcon = getHookPointIcon(spec.hook_point);
|
||||
return (
|
||||
<Card
|
||||
key={spec.hook_point}
|
||||
variant="secondary"
|
||||
padding={0.5}
|
||||
gap={0}
|
||||
<div className="flex flex-col gap-2">
|
||||
{filtered.length === 0 ? (
|
||||
<Text text03 secondaryBody>
|
||||
{search
|
||||
? "No hooks match your search."
|
||||
: "No hook points are available."}
|
||||
</Text>
|
||||
) : (
|
||||
filtered.map((spec) => (
|
||||
<Card
|
||||
key={spec.hook_point}
|
||||
variant="secondary"
|
||||
padding={0.5}
|
||||
gap={0}
|
||||
>
|
||||
<ContentAction
|
||||
icon={getHookPointIcon(spec.hook_point)}
|
||||
title={spec.display_name}
|
||||
description={spec.description}
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
paddingVariant="fit"
|
||||
rightChildren={
|
||||
// TODO(Bo-Onyx): wire up Connect — open modal to create/edit hook
|
||||
<Button prominence="tertiary" rightIcon={SvgArrowExchange}>
|
||||
Connect
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
{spec.docs_url && (
|
||||
<div className="pl-7 pt-1">
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit text-text-03"
|
||||
>
|
||||
<div className="flex flex-row items-start w-full">
|
||||
<div className="flex flex-row flex-1 min-w-0 items-start gap-1 p-2">
|
||||
<div className="shrink-0 p-0.5 text-text-04">
|
||||
<UnconnectedIcon
|
||||
style={{ width: "1rem", height: "1rem" }}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col items-start min-w-0 flex-1">
|
||||
<span
|
||||
className="font-main-ui-action text-text-04"
|
||||
style={{ height: "1.25rem" }}
|
||||
>
|
||||
{spec.display_name}
|
||||
</span>
|
||||
<div className="font-secondary-body text-text-03">
|
||||
{spec.description}
|
||||
</div>
|
||||
{spec.docs_url && (
|
||||
<div className="font-secondary-body text-text-03">
|
||||
<a
|
||||
href={spec.docs_url}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
className="flex items-center gap-1 w-fit"
|
||||
>
|
||||
<span className="underline">Documentation</span>
|
||||
<SvgExternalLink
|
||||
size={12}
|
||||
className="shrink-0"
|
||||
/>
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
<div
|
||||
className="flex items-center gap-1 p-2 cursor-pointer"
|
||||
onClick={() => setConnectSpec(spec)}
|
||||
>
|
||||
<span className="font-main-ui-action text-text-03">
|
||||
Connect
|
||||
</span>
|
||||
<SvgArrowExchange
|
||||
size={16}
|
||||
className="text-text-03 shrink-0"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
</Card>
|
||||
);
|
||||
})}
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
<Text as="span" secondaryBody text03 className="underline">
|
||||
Documentation
|
||||
</Text>
|
||||
<SvgExternalLink size={16} className="text-text-02" />
|
||||
</a>
|
||||
</div>
|
||||
)}
|
||||
</Card>
|
||||
))
|
||||
)}
|
||||
</div>
|
||||
|
||||
{/* Create modal */}
|
||||
<HookFormModal
|
||||
open={!!connectSpec}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) setConnectSpec(null);
|
||||
}}
|
||||
spec={connectSpec ?? undefined}
|
||||
onSuccess={handleHookSuccess}
|
||||
/>
|
||||
|
||||
{/* Edit modal */}
|
||||
<HookFormModal
|
||||
open={!!editHook}
|
||||
onOpenChange={(open) => {
|
||||
if (!open) setEditHook(null);
|
||||
}}
|
||||
hook={editHook ?? undefined}
|
||||
spec={connectSpec_ ?? undefined}
|
||||
onSuccess={handleHookSuccess}
|
||||
/>
|
||||
</>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -18,8 +18,6 @@ export interface HookResponse {
|
||||
name: string;
|
||||
hook_point: HookPoint;
|
||||
endpoint_url: string | null;
|
||||
/** Partially-masked API key (e.g. "abcd••••••••wxyz"), or null if no key is set. */
|
||||
api_key_masked: string | null;
|
||||
fail_strategy: HookFailStrategy;
|
||||
timeout_seconds: number;
|
||||
is_active: boolean;
|
||||
|
||||
@@ -1,81 +0,0 @@
|
||||
import {
|
||||
HookCreateRequest,
|
||||
HookResponse,
|
||||
HookUpdateRequest,
|
||||
} from "@/refresh-pages/admin/HooksPage/interfaces";
|
||||
|
||||
async function parseErrorDetail(
|
||||
res: Response,
|
||||
fallback: string
|
||||
): Promise<string> {
|
||||
try {
|
||||
const body = await res.json();
|
||||
return body?.detail ?? fallback;
|
||||
} catch {
|
||||
return fallback;
|
||||
}
|
||||
}
|
||||
|
||||
export async function listHooks(): Promise<HookResponse[]> {
|
||||
const res = await fetch("/api/admin/hooks");
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to load hooks"));
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
export async function createHook(
|
||||
req: HookCreateRequest
|
||||
): Promise<HookResponse> {
|
||||
const res = await fetch("/api/admin/hooks", {
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to create hook"));
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
export async function updateHook(
|
||||
id: number,
|
||||
req: HookUpdateRequest
|
||||
): Promise<HookResponse> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(req),
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to update hook"));
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
export async function deleteHook(id: number): Promise<void> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}`, { method: "DELETE" });
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to delete hook"));
|
||||
}
|
||||
}
|
||||
|
||||
export async function activateHook(id: number): Promise<HookResponse> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}/activate`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to activate hook"));
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
|
||||
export async function deactivateHook(id: number): Promise<HookResponse> {
|
||||
const res = await fetch(`/api/admin/hooks/${id}/deactivate`, {
|
||||
method: "POST",
|
||||
});
|
||||
if (!res.ok) {
|
||||
throw new Error(await parseErrorDetail(res, "Failed to deactivate hook"));
|
||||
}
|
||||
return res.json();
|
||||
}
|
||||
@@ -256,6 +256,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
|
||||
title="Featured"
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
widthVariant="fit"
|
||||
/>
|
||||
)}
|
||||
<Content
|
||||
@@ -264,6 +265,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
widthVariant="fit"
|
||||
/>
|
||||
{agent.is_public && (
|
||||
<Content
|
||||
@@ -272,6 +274,7 @@ export default function AgentViewerModal({ agent }: AgentViewerModalProps) {
|
||||
sizePreset="main-ui"
|
||||
variant="body"
|
||||
prominence="muted"
|
||||
widthVariant="fit"
|
||||
/>
|
||||
)}
|
||||
</Section>
|
||||
|
||||
@@ -7,6 +7,7 @@ import {
|
||||
SvgOrganization,
|
||||
SvgShare,
|
||||
SvgTag,
|
||||
SvgUser,
|
||||
SvgUsers,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
@@ -18,7 +19,6 @@ import InputComboBox from "@/refresh-components/inputs/InputComboBox/InputComboB
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import SwitchField from "@/refresh-components/form/SwitchField";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { SvgUser } from "@opal/icons";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import useShareableUsers from "@/hooks/useShareableUsers";
|
||||
|
||||
@@ -7,12 +7,9 @@ import useSWR, { preload } from "swr";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import { checkUserIsNoAuthUser, getUserDisplayName, logout } from "@/lib/user";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import InputAvatar from "@/refresh-components/inputs/InputAvatar";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import Popover, { PopoverMenu } from "@/refresh-components/Popover";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { cn } from "@/lib/utils";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import NotificationsPopover from "@/sections/sidebar/NotificationsPopover";
|
||||
import {
|
||||
@@ -26,6 +23,7 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useVectorDbEnabled } from "@/providers/SettingsProvider";
|
||||
import UserAvatar from "@/refresh-components/avatars/UserAvatar";
|
||||
|
||||
interface SettingsPopoverProps {
|
||||
onUserSettingsClick: () => void;
|
||||
@@ -150,7 +148,6 @@ export default function UserAvatarPopover({
|
||||
"Settings" | "Notifications" | undefined
|
||||
>(undefined);
|
||||
const { user } = useUser();
|
||||
const router = useRouter();
|
||||
const appFocus = useAppFocus();
|
||||
const vectorDbEnabled = useVectorDbEnabled();
|
||||
|
||||
@@ -186,18 +183,10 @@ export default function UserAvatarPopover({
|
||||
<Popover.Trigger asChild>
|
||||
<div id="onyx-user-dropdown">
|
||||
<SidebarTab
|
||||
icon={({ className }) => (
|
||||
<InputAvatar
|
||||
className={cn(
|
||||
"flex items-center justify-center bg-background-neutral-inverted-00",
|
||||
className,
|
||||
"w-5 h-5"
|
||||
)}
|
||||
>
|
||||
<Text as="p" inverted secondaryBody>
|
||||
{userDisplayName[0]?.toUpperCase()}
|
||||
</Text>
|
||||
</InputAvatar>
|
||||
icon={() => (
|
||||
<div className="w-[16px] flex flex-col justify-center items-center">
|
||||
<UserAvatar user={user} size={18} />
|
||||
</div>
|
||||
)}
|
||||
rightChildren={
|
||||
hasNotifications ? (
|
||||
|
||||
@@ -84,7 +84,7 @@ const ADMIN_PAGES: AdminPageSnapshot[] = [
|
||||
{
|
||||
name: "User Management - Groups",
|
||||
path: "groups",
|
||||
pageTitle: "Manage User Groups",
|
||||
pageTitle: "Groups",
|
||||
},
|
||||
{
|
||||
name: "Appearance & Theming",
|
||||
|
||||
252
web/tests/e2e/admin/groups/GroupsAdminPage.ts
Normal file
252
web/tests/e2e/admin/groups/GroupsAdminPage.ts
Normal file
@@ -0,0 +1,252 @@
|
||||
/**
|
||||
* Page Object Model for the Admin Groups page (/admin/groups).
|
||||
*
|
||||
* Covers the list page, create page, and edit page interactions.
|
||||
*/
|
||||
|
||||
import { type Page, type Locator, expect } from "@playwright/test";
|
||||
|
||||
/** URL pattern that matches the groups data fetch. */
|
||||
const GROUPS_API = /\/api\/manage\/admin\/user-group/;
|
||||
|
||||
export class GroupsAdminPage {
|
||||
readonly page: Page;
|
||||
|
||||
constructor(page: Page) {
|
||||
this.page = page;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Navigation
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async goto() {
|
||||
await this.page.goto("/admin/groups");
|
||||
await expect(this.newGroupButton).toBeVisible({ timeout: 15000 });
|
||||
}
|
||||
|
||||
async gotoCreate() {
|
||||
await this.page.goto("/admin/groups/create");
|
||||
await expect(this.page.getByText("Create Group")).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
}
|
||||
|
||||
async gotoEdit(groupId: number) {
|
||||
await this.page.goto(`/admin/groups/${groupId}`);
|
||||
// Wait for the form to be ready — avoids networkidle hanging due to SWR polling.
|
||||
await expect(this.groupNameInput).toBeVisible({ timeout: 15000 });
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// List page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** The Groups page heading container (unique to the list page). */
|
||||
get pageHeading(): Locator {
|
||||
return this.page.getByTestId("groups-page-heading");
|
||||
}
|
||||
|
||||
/** The search input on the list page. */
|
||||
get listSearchInput(): Locator {
|
||||
return this.page.getByPlaceholder("Search groups...");
|
||||
}
|
||||
|
||||
/** The "New Group" button on the list page header. */
|
||||
get newGroupButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "New Group" });
|
||||
}
|
||||
|
||||
/** Returns all group cards on the list page. */
|
||||
get groupCards(): Locator {
|
||||
return this.page.locator("[data-card]");
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns a group card by name.
|
||||
* Cards use ContentAction which renders the title as text — match by content.
|
||||
*/
|
||||
getGroupCard(name: string): Locator {
|
||||
return this.page.locator("[data-card]").filter({ hasText: name });
|
||||
}
|
||||
|
||||
/** Click into a group's edit page from the list. */
|
||||
async openGroup(name: string) {
|
||||
const card = this.getGroupCard(name);
|
||||
await card.getByRole("button", { name: "View group" }).click();
|
||||
await expect(this.groupNameInput).toBeVisible({ timeout: 15000 });
|
||||
}
|
||||
|
||||
/** Search groups on the list page. */
|
||||
async searchGroups(term: string) {
|
||||
await this.listSearchInput.fill(term);
|
||||
}
|
||||
|
||||
/** Click "New Group" to navigate to the create page. */
|
||||
async clickNewGroup() {
|
||||
await this.newGroupButton.click();
|
||||
await expect(this.page.getByText("Create Group")).toBeVisible({
|
||||
timeout: 15000,
|
||||
});
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Create page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** The group name input on create/edit pages. */
|
||||
get groupNameInput(): Locator {
|
||||
return this.page.getByPlaceholder("Name your group");
|
||||
}
|
||||
|
||||
/** The member search input on create/edit pages. */
|
||||
get memberSearchInput(): Locator {
|
||||
return this.page.getByPlaceholder("Search users and accounts...");
|
||||
}
|
||||
|
||||
/** The "Create" button on the create page. */
|
||||
get createButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Create", exact: true });
|
||||
}
|
||||
|
||||
/** The "Cancel" button on create/edit pages. */
|
||||
get cancelButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Cancel" });
|
||||
}
|
||||
|
||||
/** Fill in the group name on create/edit pages. */
|
||||
async setGroupName(name: string) {
|
||||
await this.groupNameInput.fill(name);
|
||||
}
|
||||
|
||||
/** Search for members in the members table. */
|
||||
async searchMembers(term: string) {
|
||||
await this.memberSearchInput.fill(term);
|
||||
}
|
||||
|
||||
/** Select a member row by checking their checkbox (create page / add mode). */
|
||||
async selectMember(emailOrName: string) {
|
||||
const row = this.page.getByRole("row").filter({ hasText: emailOrName });
|
||||
const checkbox = row.getByRole("checkbox");
|
||||
await checkbox.click();
|
||||
}
|
||||
|
||||
/** Submit the create form. */
|
||||
async submitCreate() {
|
||||
await this.createButton.click();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Edit page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** The "Save Changes" button on the edit page. */
|
||||
get saveButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Save Changes" });
|
||||
}
|
||||
|
||||
/** The "Add" button to enter add-members mode. */
|
||||
get addMembersButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Add", exact: true });
|
||||
}
|
||||
|
||||
/** The "Done" button to exit add-members mode. */
|
||||
get doneAddingButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Done" });
|
||||
}
|
||||
|
||||
/** The "Delete Group" button in the danger zone card. */
|
||||
get deleteGroupButton(): Locator {
|
||||
return this.page.getByRole("button", { name: "Delete Group" });
|
||||
}
|
||||
|
||||
/** Enter add-members mode on the edit page. */
|
||||
async startAddingMembers() {
|
||||
await this.addMembersButton.click();
|
||||
await expect(this.doneAddingButton).toBeVisible();
|
||||
}
|
||||
|
||||
/** Exit add-members mode. */
|
||||
async finishAddingMembers() {
|
||||
await this.doneAddingButton.click();
|
||||
await expect(this.addMembersButton).toBeVisible();
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove a member from the member view via the minus button.
|
||||
* Only works in member view (not add mode).
|
||||
*/
|
||||
async removeMember(emailOrName: string) {
|
||||
const row = this.page.getByRole("row").filter({ hasText: emailOrName });
|
||||
// The remove button is an IconButton with SvgMinusCircle in the actions column
|
||||
await row.getByRole("button").last().click();
|
||||
}
|
||||
|
||||
/** Save the edit form. */
|
||||
async submitEdit() {
|
||||
await this.saveButton.click();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Delete flow
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/** Click "Delete Group" to open the confirmation modal. */
|
||||
async clickDeleteGroup() {
|
||||
await this.deleteGroupButton.click();
|
||||
}
|
||||
|
||||
/** The delete confirmation modal. */
|
||||
get deleteModal(): Locator {
|
||||
return this.page.getByRole("dialog");
|
||||
}
|
||||
|
||||
/** Confirm deletion in the modal. */
|
||||
async confirmDelete() {
|
||||
await this.deleteModal.getByRole("button", { name: "Delete" }).click();
|
||||
}
|
||||
|
||||
/** Cancel deletion in the modal. */
|
||||
async cancelDelete() {
|
||||
// The modal close button (X icon) or clicking outside
|
||||
await this.deleteModal
|
||||
.getByRole("button")
|
||||
.filter({ hasText: /close|cancel/i })
|
||||
.first()
|
||||
.click();
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Assertions
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
async expectToast(message: string | RegExp) {
|
||||
await expect(this.page.getByText(message)).toBeVisible({ timeout: 10000 });
|
||||
}
|
||||
|
||||
/** Assert a group card exists on the list page. */
|
||||
async expectGroupVisible(name: string) {
|
||||
await expect(this.getGroupCard(name)).toBeVisible({ timeout: 10000 });
|
||||
}
|
||||
|
||||
/** Assert a group card does NOT exist on the list page. */
|
||||
async expectGroupNotVisible(name: string) {
|
||||
await expect(this.getGroupCard(name)).not.toBeVisible({ timeout: 10000 });
|
||||
}
|
||||
|
||||
/** Assert we navigated back to the groups list. */
|
||||
async expectOnListPage() {
|
||||
await expect(this.page).toHaveURL(/\/admin\/groups\/?$/);
|
||||
await expect(this.newGroupButton).toBeVisible();
|
||||
}
|
||||
|
||||
/** Assert we are on the edit page for a specific group. */
|
||||
async expectOnEditPage(groupId: number) {
|
||||
await expect(this.page).toHaveURL(`/admin/groups/${groupId}`);
|
||||
}
|
||||
|
||||
/** Wait for the groups API response after a mutation. */
|
||||
async waitForGroupsRefresh() {
|
||||
await this.page.waitForResponse(GROUPS_API);
|
||||
}
|
||||
}
|
||||
37
web/tests/e2e/admin/groups/fixtures.ts
Normal file
37
web/tests/e2e/admin/groups/fixtures.ts
Normal file
@@ -0,0 +1,37 @@
|
||||
/**
|
||||
* Playwright fixtures for Admin Groups page tests.
|
||||
*
|
||||
* Provides:
|
||||
* - Authenticated admin page
|
||||
* - OnyxApiClient for API-level setup/teardown
|
||||
* - GroupsAdminPage page object
|
||||
*/
|
||||
|
||||
import { test as base, expect, type Page } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
import { GroupsAdminPage } from "./GroupsAdminPage";
|
||||
|
||||
export const test = base.extend<{
|
||||
adminPage: Page;
|
||||
api: OnyxApiClient;
|
||||
groupsPage: GroupsAdminPage;
|
||||
}>({
|
||||
adminPage: async ({ page }, use) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
await use(page);
|
||||
},
|
||||
|
||||
api: async ({ adminPage }, use) => {
|
||||
const client = new OnyxApiClient(adminPage.request);
|
||||
await use(client);
|
||||
},
|
||||
|
||||
groupsPage: async ({ adminPage }, use) => {
|
||||
const groupsPage = new GroupsAdminPage(adminPage);
|
||||
await use(groupsPage);
|
||||
},
|
||||
});
|
||||
|
||||
export { expect };
|
||||
279
web/tests/e2e/admin/groups/groups.spec.ts
Normal file
279
web/tests/e2e/admin/groups/groups.spec.ts
Normal file
@@ -0,0 +1,279 @@
|
||||
/**
|
||||
* E2E Tests: Admin Groups Page
|
||||
*
|
||||
* Tests the full groups management page — list, create, edit, delete.
|
||||
*
|
||||
* Uses the GroupsAdminPage POM for all interactions. Groups are created via
|
||||
* OnyxApiClient for setup and cleaned up in afterAll/afterEach.
|
||||
*/
|
||||
|
||||
import { test, expect } from "./fixtures";
|
||||
import type { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
import type { Browser } from "@playwright/test";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
function uniqueGroupName(prefix: string): string {
|
||||
return `e2e-${prefix}-${Date.now()}`;
|
||||
}
|
||||
|
||||
/** Best-effort cleanup — logs failures instead of silently swallowing them. */
|
||||
async function softCleanup(fn: () => Promise<unknown>): Promise<void> {
|
||||
await fn().catch((e) => console.warn("cleanup:", e));
|
||||
}
|
||||
|
||||
/**
|
||||
* Creates an authenticated API context for beforeAll/afterAll hooks.
|
||||
*/
|
||||
async function withApiContext(
|
||||
browser: Browser,
|
||||
fn: (api: OnyxApiClient) => Promise<void>
|
||||
): Promise<void> {
|
||||
const context = await browser.newContext({
|
||||
storageState: "admin_auth.json",
|
||||
});
|
||||
try {
|
||||
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
|
||||
const api = new OnyxApiClient(context.request);
|
||||
await fn(api);
|
||||
} finally {
|
||||
await context.close();
|
||||
}
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// List page
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Groups page — layout", () => {
|
||||
let adminGroupId: number;
|
||||
let basicGroupId: number;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
adminGroupId = await api.createUserGroup("Admin");
|
||||
basicGroupId = await api.createUserGroup("Basic");
|
||||
await api.waitForGroupSync(adminGroupId);
|
||||
await api.waitForGroupSync(basicGroupId);
|
||||
});
|
||||
});
|
||||
|
||||
test.afterAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
await softCleanup(() => api.deleteUserGroup(adminGroupId));
|
||||
await softCleanup(() => api.deleteUserGroup(basicGroupId));
|
||||
});
|
||||
});
|
||||
|
||||
test("renders page title, search, and new group button", async ({
|
||||
groupsPage,
|
||||
}) => {
|
||||
await groupsPage.goto();
|
||||
|
||||
await expect(groupsPage.pageHeading).toBeVisible();
|
||||
await expect(groupsPage.listSearchInput).toBeVisible();
|
||||
await expect(groupsPage.newGroupButton).toBeVisible();
|
||||
});
|
||||
|
||||
test("shows built-in groups (Admin, Basic)", async ({ groupsPage }) => {
|
||||
await groupsPage.goto();
|
||||
|
||||
await groupsPage.expectGroupVisible("Admin");
|
||||
await groupsPage.expectGroupVisible("Basic");
|
||||
});
|
||||
|
||||
test("search filters groups by name", async ({ groupsPage, api }) => {
|
||||
const name = uniqueGroupName("search");
|
||||
const groupId = await api.createUserGroup(name);
|
||||
await api.waitForGroupSync(groupId);
|
||||
|
||||
try {
|
||||
await groupsPage.goto();
|
||||
await groupsPage.expectGroupVisible(name);
|
||||
|
||||
await groupsPage.searchGroups("zzz-nonexistent-zzz");
|
||||
await groupsPage.expectGroupNotVisible(name);
|
||||
|
||||
await groupsPage.searchGroups(name);
|
||||
await groupsPage.expectGroupVisible(name);
|
||||
} finally {
|
||||
await softCleanup(() => api.deleteUserGroup(groupId));
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Create flow
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Groups page — create", () => {
|
||||
test("navigates to create page via New Group button", async ({
|
||||
groupsPage,
|
||||
}) => {
|
||||
await groupsPage.goto();
|
||||
await groupsPage.clickNewGroup();
|
||||
|
||||
await expect(groupsPage.page).toHaveURL(/\/admin\/groups\/create/);
|
||||
await expect(groupsPage.groupNameInput).toBeVisible();
|
||||
});
|
||||
|
||||
test("creates a group and redirects to list", async ({ groupsPage, api }) => {
|
||||
const name = uniqueGroupName("create");
|
||||
let groupId: number | undefined;
|
||||
|
||||
try {
|
||||
await groupsPage.gotoCreate();
|
||||
await groupsPage.setGroupName(name);
|
||||
await groupsPage.submitCreate();
|
||||
|
||||
await groupsPage.expectToast(`Group "${name}" created`);
|
||||
await groupsPage.expectOnListPage();
|
||||
|
||||
// Find the group ID for cleanup via the authenticated page context
|
||||
const res = await groupsPage.page.request.get(
|
||||
"/api/manage/admin/user-group"
|
||||
);
|
||||
const groups = await res.json();
|
||||
const group = groups.find(
|
||||
(g: { name: string; id: number }) => g.name === name
|
||||
);
|
||||
groupId = group?.id;
|
||||
} finally {
|
||||
if (groupId !== undefined) {
|
||||
await softCleanup(() => api.deleteUserGroup(groupId!));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
test("cancel returns to list without creating", async ({ groupsPage }) => {
|
||||
await groupsPage.gotoCreate();
|
||||
await groupsPage.setGroupName("should-not-be-created");
|
||||
await groupsPage.cancelButton.click();
|
||||
|
||||
await groupsPage.expectOnListPage();
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Edit flow
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Groups page — edit @exclusive", () => {
|
||||
let groupId: number;
|
||||
const groupName = uniqueGroupName("edit");
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
groupId = await api.createUserGroup(groupName);
|
||||
await api.waitForGroupSync(groupId);
|
||||
});
|
||||
});
|
||||
|
||||
test.afterAll(async ({ browser }) => {
|
||||
await withApiContext(browser, async (api) => {
|
||||
await softCleanup(() => api.deleteUserGroup(groupId));
|
||||
});
|
||||
});
|
||||
|
||||
test("navigates to edit page from list", async ({ groupsPage }) => {
|
||||
await groupsPage.goto();
|
||||
await groupsPage.openGroup(groupName);
|
||||
|
||||
await groupsPage.expectOnEditPage(groupId);
|
||||
await expect(groupsPage.saveButton).toBeVisible();
|
||||
});
|
||||
|
||||
test("edit page shows group name and save/cancel buttons", async ({
|
||||
groupsPage,
|
||||
}) => {
|
||||
await groupsPage.gotoEdit(groupId);
|
||||
|
||||
await expect(groupsPage.groupNameInput).toHaveValue(groupName);
|
||||
await expect(groupsPage.saveButton).toBeVisible();
|
||||
await expect(groupsPage.cancelButton).toBeVisible();
|
||||
});
|
||||
|
||||
test("can toggle add-members mode", async ({ groupsPage }) => {
|
||||
await groupsPage.gotoEdit(groupId);
|
||||
|
||||
await expect(groupsPage.addMembersButton).toBeVisible();
|
||||
await groupsPage.startAddingMembers();
|
||||
await expect(groupsPage.doneAddingButton).toBeVisible();
|
||||
await groupsPage.finishAddingMembers();
|
||||
await expect(groupsPage.addMembersButton).toBeVisible();
|
||||
});
|
||||
|
||||
test("cancel returns to list without saving", async ({ groupsPage }) => {
|
||||
await groupsPage.gotoEdit(groupId);
|
||||
await groupsPage.cancelButton.click();
|
||||
|
||||
await groupsPage.expectOnListPage();
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Delete flow
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Groups page — delete", () => {
|
||||
test("delete group via edit page", async ({ groupsPage, api }) => {
|
||||
const name = uniqueGroupName("delete");
|
||||
const groupId = await api.createUserGroup(name);
|
||||
await api.waitForGroupSync(groupId);
|
||||
|
||||
await groupsPage.gotoEdit(groupId);
|
||||
await groupsPage.clickDeleteGroup();
|
||||
|
||||
// Modal should show the group name
|
||||
await expect(groupsPage.deleteModal).toBeVisible();
|
||||
await expect(groupsPage.deleteModal.getByText(name)).toBeVisible();
|
||||
|
||||
await groupsPage.confirmDelete();
|
||||
await groupsPage.expectToast(`Group "${name}" deleted`);
|
||||
await groupsPage.expectOnListPage();
|
||||
});
|
||||
});
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Sync status (No Vector DB)
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
test.describe("Groups page — sync @lite", () => {
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
const context = await browser.newContext({
|
||||
storageState: "admin_auth.json",
|
||||
});
|
||||
try {
|
||||
const { OnyxApiClient } = await import("@tests/e2e/utils/onyxApiClient");
|
||||
const client = new OnyxApiClient(context.request);
|
||||
const vectorDbEnabled = await client.isVectorDbEnabled();
|
||||
test.skip(
|
||||
vectorDbEnabled,
|
||||
"Skipped: vector DB is enabled in this deployment"
|
||||
);
|
||||
} finally {
|
||||
await context.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("newly created group syncs immediately", async ({ groupsPage, api }) => {
|
||||
const name = uniqueGroupName("sync");
|
||||
let groupId: number | undefined;
|
||||
|
||||
try {
|
||||
// Create via API and verify sync completes
|
||||
groupId = await api.createUserGroup(name);
|
||||
await api.waitForGroupSync(groupId);
|
||||
|
||||
// Navigate to edit page and verify it loads without error
|
||||
await groupsPage.gotoEdit(groupId);
|
||||
await expect(groupsPage.groupNameInput).toHaveValue(name);
|
||||
} finally {
|
||||
if (groupId !== undefined) {
|
||||
await softCleanup(() => api.deleteUserGroup(groupId!));
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
@@ -1,148 +0,0 @@
|
||||
import { test, expect, BrowserContext } from "@playwright/test";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
test.use({ storageState: "admin_auth.json" });
|
||||
|
||||
test.describe("User Groups - No Vector DB @lite", () => {
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
const context = await browser.newContext({
|
||||
storageState: "admin_auth.json",
|
||||
});
|
||||
try {
|
||||
const client = new OnyxApiClient(context.request);
|
||||
const vectorDbEnabled = await client.isVectorDbEnabled();
|
||||
test.skip(
|
||||
vectorDbEnabled,
|
||||
"Skipped: vector DB is enabled in this deployment"
|
||||
);
|
||||
} finally {
|
||||
await context.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("should show user group as synced immediately on creation", async ({
|
||||
page,
|
||||
}) => {
|
||||
const groupName = `E2E-NoVectorDB-Group-${Date.now()}`;
|
||||
let groupId: number | undefined;
|
||||
|
||||
try {
|
||||
await page.goto("/admin/groups");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await page.getByRole("button", { name: "Create New User Group" }).click();
|
||||
|
||||
const dialog = page.getByRole("dialog");
|
||||
await expect(dialog).toBeVisible();
|
||||
|
||||
await dialog.locator('input[name="name"]').fill(groupName);
|
||||
|
||||
await expect(
|
||||
dialog.getByText("Connectors are not available in Onyx Lite")
|
||||
).toBeVisible();
|
||||
|
||||
await dialog.getByRole("button", { name: "Create!" }).click();
|
||||
|
||||
await expect(dialog).not.toBeVisible({ timeout: 10000 });
|
||||
|
||||
const groupRow = page.getByRole("row").filter({ hasText: groupName });
|
||||
await expect(groupRow).toBeVisible({ timeout: 10000 });
|
||||
await expect(groupRow.getByText("Up to date!")).toBeVisible();
|
||||
|
||||
const groupLink = groupRow.getByRole("link", { name: groupName });
|
||||
const href = await groupLink.getAttribute("href");
|
||||
const match = href?.match(/\/admin\/groups\/(\d+)/);
|
||||
if (match) {
|
||||
groupId = parseInt(match[1] ?? "", 10);
|
||||
}
|
||||
|
||||
await groupLink.click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await expect(page.getByText("Up to date")).toBeVisible({ timeout: 5000 });
|
||||
|
||||
const addUsersButton = page.getByRole("button", {
|
||||
name: "Add Users",
|
||||
});
|
||||
await expect(addUsersButton).toBeEnabled();
|
||||
} finally {
|
||||
if (groupId !== undefined) {
|
||||
const apiClient = new OnyxApiClient(page.request);
|
||||
await apiClient.deleteUserGroup(groupId);
|
||||
}
|
||||
}
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("User Groups - Standard Deployment @exclusive", () => {
|
||||
let cleanupContext: BrowserContext | undefined;
|
||||
let client: OnyxApiClient;
|
||||
let ccPairId: number | undefined;
|
||||
let groupId: number | undefined;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
const context = await browser.newContext({
|
||||
storageState: "admin_auth.json",
|
||||
});
|
||||
try {
|
||||
client = new OnyxApiClient(context.request);
|
||||
const vectorDbEnabled = await client.isVectorDbEnabled();
|
||||
if (!vectorDbEnabled) {
|
||||
await context.close();
|
||||
test.skip(true, "Skipped: vector DB is disabled in this deployment");
|
||||
return;
|
||||
}
|
||||
cleanupContext = context;
|
||||
} catch (e) {
|
||||
await context.close();
|
||||
throw e;
|
||||
}
|
||||
});
|
||||
|
||||
test.afterAll(async () => {
|
||||
try {
|
||||
if (groupId !== undefined) {
|
||||
await client.deleteUserGroup(groupId).catch(() => {});
|
||||
}
|
||||
if (ccPairId !== undefined) {
|
||||
await client.deleteCCPair(ccPairId).catch(() => {});
|
||||
}
|
||||
} finally {
|
||||
await cleanupContext?.close();
|
||||
}
|
||||
});
|
||||
|
||||
test("should sync user group with connector", async ({ page }) => {
|
||||
const apiClient = new OnyxApiClient(page.request);
|
||||
const groupName = `E2E-Sync-Group-${Date.now()}`;
|
||||
|
||||
ccPairId = await apiClient.createFileConnector(
|
||||
`E2E-Group-Connector-${Date.now()}`,
|
||||
"private"
|
||||
);
|
||||
|
||||
groupId = await apiClient.createUserGroup(groupName, [], [ccPairId]);
|
||||
|
||||
await page.goto("/admin/groups");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const groupRow = page.getByRole("row").filter({ hasText: groupName });
|
||||
await expect(groupRow).toBeVisible({ timeout: 10000 });
|
||||
|
||||
const upToDate = groupRow.getByText("Up to date!");
|
||||
const deadline = Date.now() + 120_000;
|
||||
while (Date.now() < deadline) {
|
||||
if (await upToDate.isVisible().catch(() => false)) break;
|
||||
await page.waitForTimeout(3000);
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
}
|
||||
await expect(upToDate).toBeVisible({ timeout: 5000 });
|
||||
|
||||
const groupLink = groupRow.getByRole("link", { name: groupName });
|
||||
await groupLink.click();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await expect(page.getByText("Up to date")).toBeVisible({ timeout: 10000 });
|
||||
});
|
||||
});
|
||||
303
web/tests/e2e/chat/actions_popover.spec.ts
Normal file
303
web/tests/e2e/chat/actions_popover.spec.ts
Normal file
@@ -0,0 +1,303 @@
|
||||
import { test, expect } from "@playwright/test";
|
||||
import { loginAs } from "@tests/e2e/utils/auth";
|
||||
import {
|
||||
TOOL_IDS,
|
||||
openActionManagement,
|
||||
openSourceManagement,
|
||||
toggleToolDisabled,
|
||||
getSourceToggle,
|
||||
} from "@tests/e2e/utils/tools";
|
||||
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";
|
||||
|
||||
const LOCAL_STORAGE_KEY = "selectedInternalSearchSources";
|
||||
|
||||
test.describe("ActionsPopover Tool Toggles", () => {
|
||||
test.describe.configure({ mode: "serial" });
|
||||
|
||||
let ccPairId: number | null = null;
|
||||
let webSearchProviderId: number | null = null;
|
||||
let imageGenConfigId: string | null = null;
|
||||
|
||||
test.beforeAll(async ({ browser }) => {
|
||||
const ctx = await browser.newContext({ storageState: "admin_auth.json" });
|
||||
const page = await ctx.newPage();
|
||||
await page.goto("http://localhost:3000/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const apiClient = new OnyxApiClient(page.request);
|
||||
|
||||
// Create a file connector so internal search tool is available
|
||||
ccPairId = await apiClient.createFileConnector(
|
||||
`actions-popover-test-${Date.now()}`
|
||||
);
|
||||
|
||||
// Create providers for web search and image generation (best-effort)
|
||||
try {
|
||||
webSearchProviderId = await apiClient.createWebSearchProvider(
|
||||
"exa",
|
||||
`actions-popover-web-search-${Date.now()}`
|
||||
);
|
||||
} catch (error) {
|
||||
console.warn(`Failed to create web search provider: ${error}`);
|
||||
}
|
||||
|
||||
try {
|
||||
imageGenConfigId = await apiClient.createImageGenerationConfig(
|
||||
`actions-popover-image-gen-${Date.now()}`
|
||||
);
|
||||
} catch (error) {
|
||||
console.warn(`Failed to create image gen config: ${error}`);
|
||||
}
|
||||
|
||||
// Ensure all tools are enabled on the default agent
|
||||
const toolsResp = await page.request.get("/api/tool");
|
||||
const allTools = await toolsResp.json();
|
||||
const toolIdsByCodeId: Record<string, number> = {};
|
||||
allTools.forEach((t: any) => {
|
||||
if (t.in_code_tool_id) toolIdsByCodeId[t.in_code_tool_id] = t.id;
|
||||
});
|
||||
|
||||
const configResp = await page.request.get(
|
||||
"/api/admin/default-assistant/configuration"
|
||||
);
|
||||
const currentConfig = await configResp.json();
|
||||
|
||||
const desiredToolIds = [
|
||||
toolIdsByCodeId["SearchTool"],
|
||||
toolIdsByCodeId["WebSearchTool"],
|
||||
toolIdsByCodeId["ImageGenerationTool"],
|
||||
].filter(Boolean);
|
||||
|
||||
const uniqueToolIds = Array.from(
|
||||
new Set([...(currentConfig.tool_ids || []), ...desiredToolIds])
|
||||
);
|
||||
|
||||
await page.request.patch("/api/admin/default-assistant", {
|
||||
data: { tool_ids: uniqueToolIds },
|
||||
});
|
||||
|
||||
await ctx.close();
|
||||
});
|
||||
|
||||
test.afterAll(async ({ browser }) => {
|
||||
const ctx = await browser.newContext({ storageState: "admin_auth.json" });
|
||||
const page = await ctx.newPage();
|
||||
await page.goto("http://localhost:3000/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const apiClient = new OnyxApiClient(page.request);
|
||||
|
||||
if (ccPairId !== null) {
|
||||
try {
|
||||
await apiClient.deleteCCPair(ccPairId);
|
||||
} catch (error) {
|
||||
console.warn(`Cleanup: failed to delete connector: ${error}`);
|
||||
}
|
||||
}
|
||||
if (webSearchProviderId !== null) {
|
||||
try {
|
||||
await apiClient.deleteWebSearchProvider(webSearchProviderId);
|
||||
} catch (error) {
|
||||
console.warn(`Cleanup: failed to delete web search provider: ${error}`);
|
||||
}
|
||||
}
|
||||
if (imageGenConfigId !== null) {
|
||||
try {
|
||||
await apiClient.deleteImageGenerationConfig(imageGenConfigId);
|
||||
} catch (error) {
|
||||
console.warn(`Cleanup: failed to delete image gen config: ${error}`);
|
||||
}
|
||||
}
|
||||
|
||||
await ctx.close();
|
||||
});
|
||||
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "admin");
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// Clear source preferences for a clean slate
|
||||
await page.evaluate(
|
||||
(key) => localStorage.removeItem(key),
|
||||
LOCAL_STORAGE_KEY
|
||||
);
|
||||
});
|
||||
|
||||
test("should show internal search and other tools in popover", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openActionManagement(page);
|
||||
|
||||
// Internal search must be visible (connector was created in beforeAll)
|
||||
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// Soft-check other tools (depend on provider setup success)
|
||||
const webVisible = await page
|
||||
.locator(TOOL_IDS.webSearchOption)
|
||||
.isVisible()
|
||||
.catch(() => false);
|
||||
const imgVisible = await page
|
||||
.locator(TOOL_IDS.imageGenerationOption)
|
||||
.isVisible()
|
||||
.catch(() => false);
|
||||
console.log(`[tools] web_search=${webVisible}, image_gen=${imgVisible}`);
|
||||
});
|
||||
|
||||
test("source preferences should persist to localStorage and survive reload", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openActionManagement(page);
|
||||
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
await openSourceManagement(page);
|
||||
|
||||
// Find the first source switch
|
||||
const switches = page.locator('[role="switch"]');
|
||||
await expect(switches.first()).toBeVisible({ timeout: 5000 });
|
||||
|
||||
const firstSwitch = switches.first();
|
||||
const ariaLabel = await firstSwitch.getAttribute("aria-label");
|
||||
const sourceName = ariaLabel?.replace("Toggle ", "") || "";
|
||||
expect(sourceName).toBeTruthy();
|
||||
|
||||
// Ensure it's enabled, then disable it
|
||||
if ((await firstSwitch.getAttribute("aria-checked")) === "false") {
|
||||
await firstSwitch.click();
|
||||
await expect(firstSwitch).toHaveAttribute("aria-checked", "true");
|
||||
}
|
||||
await firstSwitch.click();
|
||||
await expect(firstSwitch).toHaveAttribute("aria-checked", "false");
|
||||
|
||||
// Verify localStorage was updated
|
||||
const stored = await page.evaluate(
|
||||
(key) => localStorage.getItem(key),
|
||||
LOCAL_STORAGE_KEY
|
||||
);
|
||||
expect(stored).toBeTruthy();
|
||||
expect(JSON.parse(stored!).sourcePreferences).toBeDefined();
|
||||
|
||||
// Reload and verify persistence
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await openActionManagement(page);
|
||||
await openSourceManagement(page);
|
||||
|
||||
const sourceToggle = getSourceToggle(page, sourceName);
|
||||
await expect(sourceToggle).toHaveAttribute("aria-checked", "false", {
|
||||
timeout: 10000,
|
||||
});
|
||||
});
|
||||
|
||||
test("disabling search tool clears sources, re-enabling restores them", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openActionManagement(page);
|
||||
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible({
|
||||
timeout: 10000,
|
||||
});
|
||||
|
||||
// Open source management and count enabled sources
|
||||
await openSourceManagement(page);
|
||||
const switches = page.locator('[role="switch"]');
|
||||
await expect(switches.first()).toBeVisible({ timeout: 5000 });
|
||||
|
||||
const totalSources = await switches.count();
|
||||
let enabledBefore = 0;
|
||||
for (let i = 0; i < totalSources; i++) {
|
||||
if ((await switches.nth(i).getAttribute("aria-checked")) === "true") {
|
||||
enabledBefore++;
|
||||
}
|
||||
}
|
||||
expect(enabledBefore).toBeGreaterThan(0);
|
||||
|
||||
// Go back to primary view
|
||||
await page.locator('button[aria-label="Back"]').click();
|
||||
await expect(page.locator(TOOL_IDS.searchOption)).toBeVisible();
|
||||
|
||||
// Disable the search tool
|
||||
await toggleToolDisabled(page, TOOL_IDS.searchOption);
|
||||
|
||||
// Verify localStorage was written (the fix being tested)
|
||||
const stored = await page.evaluate(
|
||||
(key) => localStorage.getItem(key),
|
||||
LOCAL_STORAGE_KEY
|
||||
);
|
||||
expect(stored).toBeTruthy();
|
||||
|
||||
// Re-enable the search tool
|
||||
await toggleToolDisabled(page, TOOL_IDS.searchOption);
|
||||
|
||||
// Verify sources were restored
|
||||
await openSourceManagement(page);
|
||||
const switchesAfter = page.locator('[role="switch"]');
|
||||
const totalAfter = await switchesAfter.count();
|
||||
let enabledAfter = 0;
|
||||
for (let i = 0; i < totalAfter; i++) {
|
||||
if (
|
||||
(await switchesAfter.nth(i).getAttribute("aria-checked")) === "true"
|
||||
) {
|
||||
enabledAfter++;
|
||||
}
|
||||
}
|
||||
expect(enabledAfter).toBe(enabledBefore);
|
||||
});
|
||||
|
||||
test("tool enabled and disabled states both persist across reload", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openActionManagement(page);
|
||||
const searchOption = page.locator(TOOL_IDS.searchOption);
|
||||
await expect(searchOption).toBeVisible({ timeout: 10000 });
|
||||
|
||||
// The slash button says "Disable" when the tool is enabled
|
||||
await searchOption.hover();
|
||||
const slashButton = searchOption.locator(
|
||||
'button[aria-label="Disable"], button[aria-label="Enable"]'
|
||||
);
|
||||
await expect(slashButton.first()).toHaveAttribute("aria-label", "Disable");
|
||||
|
||||
// Reload — enabled state should persist
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
await openActionManagement(page);
|
||||
await page.locator(TOOL_IDS.searchOption).hover();
|
||||
await expect(
|
||||
page
|
||||
.locator(TOOL_IDS.searchOption)
|
||||
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
|
||||
.first()
|
||||
).toHaveAttribute("aria-label", "Disable");
|
||||
|
||||
// Disable the search tool
|
||||
await toggleToolDisabled(page, TOOL_IDS.searchOption);
|
||||
|
||||
// Verify it's now disabled (slash button says "Enable")
|
||||
await page.locator(TOOL_IDS.searchOption).hover();
|
||||
await expect(
|
||||
page
|
||||
.locator(TOOL_IDS.searchOption)
|
||||
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
|
||||
.first()
|
||||
).toHaveAttribute("aria-label", "Enable");
|
||||
|
||||
// Reload — disabled state should also persist (saved to DB)
|
||||
await page.reload();
|
||||
await page.waitForLoadState("networkidle");
|
||||
await openActionManagement(page);
|
||||
await page.locator(TOOL_IDS.searchOption).hover();
|
||||
await expect(
|
||||
page
|
||||
.locator(TOOL_IDS.searchOption)
|
||||
.locator('button[aria-label="Disable"], button[aria-label="Enable"]')
|
||||
.first()
|
||||
).toHaveAttribute("aria-label", "Enable");
|
||||
|
||||
// Re-enable the tool for cleanup (serial tests follow)
|
||||
await toggleToolDisabled(page, TOOL_IDS.searchOption);
|
||||
});
|
||||
});
|
||||
@@ -37,3 +37,40 @@ export async function isActionTogglePresent(page: Page): Promise<boolean> {
|
||||
const el = await page.$(TOOL_IDS.actionToggle);
|
||||
return !!el;
|
||||
}
|
||||
|
||||
/**
|
||||
* Click the disable/enable (slash) button on a tool line item.
|
||||
* The button is hidden until hover; we hover first, then force-click
|
||||
* using aria-label which matches the button's current state.
|
||||
*/
|
||||
export async function toggleToolDisabled(
|
||||
page: Page,
|
||||
toolSelector: string
|
||||
): Promise<void> {
|
||||
const toolOption = page.locator(toolSelector);
|
||||
await toolOption.hover();
|
||||
const slashButton = toolOption.locator(
|
||||
'button[aria-label="Disable"], button[aria-label="Enable"]'
|
||||
);
|
||||
await slashButton.first().click({ force: true });
|
||||
}
|
||||
|
||||
/**
|
||||
* Open the source management secondary view for the internal search tool.
|
||||
* Assumes the ActionsPopover is already open.
|
||||
*/
|
||||
export async function openSourceManagement(page: Page): Promise<void> {
|
||||
const searchOption = page.locator(TOOL_IDS.searchOption);
|
||||
await searchOption
|
||||
.locator('button[aria-label="Configure Connectors"]')
|
||||
.click();
|
||||
// Wait for the source list Back button (indicates secondary view is open)
|
||||
await page.locator('button[aria-label="Back"]').waitFor({ timeout: 5000 });
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a source toggle Switch in the source management view by display name.
|
||||
*/
|
||||
export function getSourceToggle(page: Page, sourceName: string) {
|
||||
return page.locator(`[aria-label="Toggle ${sourceName}"]`);
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user