Compare commits

..

23 Commits

Author SHA1 Message Date
Jamison Lahman
b59f8cf453 feat(cli): onyx install-skill (#9889) 2026-04-03 12:41:39 -07:00
Bo-Onyx
456ecc7b9a feat(hook): UI improve disconnect error popover (#9877) 2026-04-03 19:15:19 +00:00
Jamison Lahman
fdc2bc9ee2 fix(fe): closed sidebar button tooltip text color (#9876) 2026-04-03 18:57:48 +00:00
Jamison Lahman
1c3f371549 fix(fe): projects buttons transition in like other sidebar items (#9875) 2026-04-03 18:50:14 +00:00
Evan Lohn
a120add37b feat: filestore delete missing error (#9878) 2026-04-03 18:19:41 +00:00
Raunak Bhagat
757e4e979b feat: cluster disabled admin sidebar tabs at the bottom (#9867) 2026-04-03 18:01:03 +00:00
Wenxi
cbcdfee56e fix(mcp server): propagate detailed error messages to mcp client instead of generic message and migrate to OnyxError (#9880) 2026-04-03 16:29:22 +00:00
Jamison Lahman
b06700314b fix(fe): fix sticky sidebar headers overlapping scrollbars (#9884) 2026-04-03 16:16:10 +00:00
roshan
01f573cdcb feat(cli): make onyx-cli agent-friendly (#9874)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-03 16:08:57 +00:00
Jamison Lahman
d4a96d70f3 fix(desktop): prefer native scrollbar styling (#9879) 2026-04-03 00:33:18 +00:00
Evan Lohn
5b000c2173 chore: remove unused db rows (#9869) 2026-04-02 22:17:10 +00:00
acaprau
d62af28e40 fix(opensearch): Doc IDs whose length would exceed OpenSearch's ID length are hashed (#9847) 2026-04-02 21:35:17 +00:00
acaprau
593678a14f fix(opensearch): Re-order migration task logic to not hold DB sessions too long (#9872) 2026-04-02 21:26:08 +00:00
roshan
e6f7c2b45c feat(install): add GitHub star prompt at end of install script (#9861) 2026-04-02 19:12:10 +00:00
Raunak Bhagat
f77128d929 refactor: move SidebarTab to Opal with disabled prop and variant/state API (v2) (#9866) 2026-04-02 19:07:52 +00:00
Jamison Lahman
1d4ca769e7 chore(playwright): stabalize icon loading, users table timestamp (#9864) 2026-04-02 18:58:28 +00:00
Raunak Bhagat
e002f6c195 Revert "refactor: move SidebarTab to opal" (#9865) 2026-04-02 11:38:03 -07:00
Raunak Bhagat
10d696262f refactor: move SidebarTab to opal (#9863) 2026-04-02 18:22:19 +00:00
Jamison Lahman
608e151443 fix(offline): fallback to system sans-serif font (#9860)
Co-authored-by: Claude Opus 4.6 <noreply@anthropic.com>
2026-04-02 17:26:57 +00:00
Raunak Bhagat
41d1a33093 refactor: simplify opal/Disabled by removing its context (#9852) 2026-04-02 17:12:01 +00:00
Bo-Onyx
f396ebbdbb feat(hook): Show connection lost status (#9848) 2026-04-02 16:58:28 +00:00
Raunak Bhagat
67c8df002e refactor: update Button to define its own internal disabled styling (#9851) 2026-04-02 16:42:35 +00:00
SubashMohan
722f7de335 feat(groups): seed default Admin and Basic user groups (#9795)
Co-authored-by: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-02 12:25:16 +00:00
224 changed files with 6994 additions and 3595 deletions

View File

@@ -1,186 +0,0 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
```bash
onyx-cli validate-config
```
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
|----------|----------|-------------|
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
## Commands
### Validate configuration
```bash
onyx-cli validate-config
```
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
Each line is a JSON object with this envelope:
```json
{"type": "<event_type>", "event": { ... }}
```
| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |
`citation_info` event shape:
```json
{
"type": "citation_info",
"event": {
"citation_number": 1,
"document_id": "abc123def456",
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
}
}
```
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## Statelessness
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

View File

@@ -0,0 +1 @@
../../../cli/internal/embedded/SKILL.md

View File

@@ -0,0 +1,108 @@
"""backfill_account_type
Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None
_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"
# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
# Reflect the table structure for use in DML
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("email", sa.String),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
)
def upgrade() -> None:
# ------------------------------------------------------------------
# Step 1: Backfill account_type from role.
# Order matters — most-specific matches first so the final catch-all
# only touches rows that haven't been classified yet.
# ------------------------------------------------------------------
# 1a. API key virtual users → SERVICE_ACCOUNT
op.execute(
sa.update(user_table)
.where(
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
user_table.c.account_type.is_(None),
)
.values(account_type=_SERVICE_ACCOUNT)
)
# 1b. Anonymous user → ANONYMOUS
op.execute(
sa.update(user_table)
.where(
user_table.c.id == ANONYMOUS_USER_ID,
user_table.c.account_type.is_(None),
)
.values(account_type=_ANONYMOUS)
)
# 1c. SLACK_USER role → BOT
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "SLACK_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_BOT)
)
# 1d. EXT_PERM_USER role → EXT_PERM_USER
op.execute(
sa.update(user_table)
.where(
user_table.c.role == "EXT_PERM_USER",
user_table.c.account_type.is_(None),
)
.values(account_type=_EXT_PERM_USER)
)
# 1e. Everything else → STANDARD
op.execute(
sa.update(user_table)
.where(user_table.c.account_type.is_(None))
.values(account_type=_STANDARD)
)
# ------------------------------------------------------------------
# Step 2: Set account_type to NOT NULL now that every row is filled.
# ------------------------------------------------------------------
op.alter_column(
"user",
"account_type",
nullable=False,
server_default="STANDARD",
)
def downgrade() -> None:
op.alter_column("user", "account_type", nullable=True, server_default=None)
op.execute(sa.update(user_table).values(account_type=None))

View File

@@ -0,0 +1,104 @@
"""add_effective_permissions
Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.
Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].
Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("effective_permissions", postgresql.JSONB),
)
user_user_group = sa.table(
"user__user_group",
sa.column("user_id", sa.Uuid),
sa.column("user_group_id", sa.Integer),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
op.add_column(
"user",
sa.Column(
"effective_permissions",
postgresql.JSONB(),
nullable=False,
server_default=sa.text("'[]'::jsonb"),
),
)
conn = op.get_bind()
# Deduplicated permissions per user
deduped = (
sa.select(
user_user_group.c.user_id,
permission_grant.c.permission,
)
.select_from(
user_user_group.join(
permission_grant,
sa.and_(
permission_grant.c.group_id == user_user_group.c.user_group_id,
permission_grant.c.is_deleted == sa.false(),
),
)
)
.distinct()
.subquery("deduped")
)
# Aggregate into JSONB array per user (order is not guaranteed;
# consumers read this as a set so ordering does not matter)
perms_per_user = (
sa.select(
deduped.c.user_id,
sa.func.jsonb_agg(
deduped.c.permission,
type_=postgresql.JSONB,
).label("perms"),
)
.group_by(deduped.c.user_id)
.subquery("sub")
)
conn.execute(
user_table.update()
.where(user_table.c.id == perms_per_user.c.user_id)
.values(effective_permissions=perms_per_user.c.perms)
)
def downgrade() -> None:
op.drop_column("user", "effective_permissions")

View File

@@ -0,0 +1,139 @@
"""seed_default_groups
Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091
"""
from typing import Any
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None
# (group_name, permission_value)
DEFAULT_GROUPS = [
("Admin", "admin"),
("Basic", "basic"),
]
CUSTOM_SUFFIX = "(Custom)"
MAX_RENAME_ATTEMPTS = 100
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_up_to_date", sa.Boolean),
sa.column("is_up_for_deletion", sa.Boolean),
sa.column("is_default", sa.Boolean),
)
permission_grant_table = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
candidate = f"{base} {CUSTOM_SUFFIX}"
attempt = 1
while attempt <= MAX_RENAME_ATTEMPTS:
exists: Any = conn.execute(
sa.select(sa.literal(1))
.select_from(user_group_table)
.where(user_group_table.c.name == candidate)
.limit(1)
).fetchone()
if exists is None:
return candidate
attempt += 1
candidate = f"{base} (Custom {attempt})"
raise RuntimeError(
f"Could not find an available name for group '{base}' "
f"after {MAX_RENAME_ATTEMPTS} attempts"
)
def upgrade() -> None:
conn = op.get_bind()
for group_name, permission_value in DEFAULT_GROUPS:
# Step 1: Rename ALL existing groups that clash with the canonical name.
conflicting = conn.execute(
sa.select(user_group_table.c.id, user_group_table.c.name).where(
user_group_table.c.name == group_name
)
).fetchall()
for row_id, row_name in conflicting:
new_name = _find_available_name(conn, row_name)
op.execute(
sa.update(user_group_table)
.where(user_group_table.c.id == row_id)
.values(name=new_name, is_up_to_date=False)
)
# Step 2: Create a fresh default group.
result = conn.execute(
user_group_table.insert()
.values(
name=group_name,
is_up_to_date=True,
is_up_for_deletion=False,
is_default=True,
)
.returning(user_group_table.c.id)
).fetchone()
assert result is not None
group_id = result[0]
# Step 3: Upsert permission grant.
op.execute(
pg_insert(permission_grant_table)
.values(
group_id=group_id,
permission=permission_value,
grant_source="SYSTEM",
)
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
)
def downgrade() -> None:
# Remove the default groups created by this migration.
# First remove user-group memberships that reference default groups
# to avoid FK violations, then delete the groups themselves.
default_group_ids = sa.select(user_group_table.c.id).where(
user_group_table.c.is_default == True # noqa: E712
)
conn = op.get_bind()
conn.execute(
sa.delete(user__user_group_table).where(
user__user_group_table.c.user_group_id.in_(default_group_ids)
)
)
conn.execute(
sa.delete(user_group_table).where(
user_group_table.c.is_default == True # noqa: E712
)
)

View File

@@ -0,0 +1,84 @@
"""grant_basic_to_existing_groups
Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.
Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498
"""
from collections.abc import Sequence
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None
user_group = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("is_default", sa.Boolean),
)
permission_grant = sa.table(
"permission_grant",
sa.column("group_id", sa.Integer),
sa.column("permission", sa.String),
sa.column("grant_source", sa.String),
sa.column("is_deleted", sa.Boolean),
)
def upgrade() -> None:
conn = op.get_bind()
already_has_basic = (
sa.select(sa.literal(1))
.select_from(permission_grant)
.where(
permission_grant.c.group_id == user_group.c.id,
permission_grant.c.permission == "basic",
)
.exists()
)
groups_needing_basic = sa.select(
user_group.c.id,
sa.literal("basic").label("permission"),
sa.literal("SYSTEM").label("grant_source"),
sa.literal(False).label("is_deleted"),
).where(
user_group.c.is_default == sa.false(),
~already_has_basic,
)
conn.execute(
permission_grant.insert().from_select(
["group_id", "permission", "grant_source", "is_deleted"],
groups_needing_basic,
)
)
def downgrade() -> None:
conn = op.get_bind()
non_default_group_ids = sa.select(user_group.c.id).where(
user_group.c.is_default == sa.false()
)
conn.execute(
permission_grant.delete().where(
permission_grant.c.permission == "basic",
permission_grant.c.grant_source == "SYSTEM",
permission_grant.c.group_id.in_(non_default_group_ids),
)
)

View File

@@ -0,0 +1,125 @@
"""assign_users_to_default_groups
Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert
# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None
# The no-auth placeholder user must NOT be assigned to default groups.
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
# user when the first real user registers; group membership rows would cause
# an FK violation on that DELETE.
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
# Reflect table structures for use in DML
user_group_table = sa.table(
"user_group",
sa.column("id", sa.Integer),
sa.column("name", sa.String),
sa.column("is_default", sa.Boolean),
)
user_table = sa.table(
"user",
sa.column("id", sa.Uuid),
sa.column("role", sa.String),
sa.column("account_type", sa.String),
sa.column("is_active", sa.Boolean),
)
user__user_group_table = sa.table(
"user__user_group",
sa.column("user_group_id", sa.Integer),
sa.column("user_id", sa.Uuid),
)
def upgrade() -> None:
conn = op.get_bind()
# Look up default group IDs
admin_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Admin",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
basic_row = conn.execute(
sa.select(user_group_table.c.id).where(
user_group_table.c.name == "Basic",
user_group_table.c.is_default == True, # noqa: E712
)
).fetchone()
if admin_row is None:
raise RuntimeError(
"Default 'Admin' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
if basic_row is None:
raise RuntimeError(
"Default 'Basic' group not found. "
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
)
# Users with role=admin → Admin group
# Include inactive users so reactivation doesn't require reconciliation.
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
admin_users = sa.select(
sa.literal(admin_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.role == "ADMIN",
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], admin_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
# Include inactive users so reactivation doesn't require reconciliation.
basic_users = sa.select(
sa.literal(basic_row[0]).label("user_group_id"),
user_table.c.id.label("user_id"),
).where(
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
sa.or_(
sa.and_(
user_table.c.account_type == "STANDARD",
user_table.c.role != "ADMIN",
),
sa.and_(
user_table.c.account_type == "SERVICE_ACCOUNT",
user_table.c.role == "BASIC",
),
),
)
op.execute(
pg_insert(user__user_group_table)
.from_select(["user_group_id", "user_id"], basic_users)
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
)
def downgrade() -> None:
# Group memberships are left in place — removing them risks
# deleting memberships that existed before this migration.
pass

View File

@@ -1,20 +1,14 @@
from datetime import datetime
from datetime import timezone
from uuid import UUID
from celery import shared_task
from celery import Task
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -29,59 +23,42 @@ logger = setup_logger()
trail=False,
)
def perform_ttl_management_task(
self: Task, retention_limit_days: int, *, tenant_id: str
self: Task, retention_limit_days: int, *, tenant_id: str # noqa: ARG001
) -> None:
task_id = self.request.id
if not task_id:
raise RuntimeError("No task id defined for this task; cannot identify it")
start_time = datetime.now(tz=timezone.utc)
user_id: UUID | None = None
session_id: UUID | None = None
try:
with get_session_with_current_tenant() as db_session:
# we generally want to move off this, but keeping for now
register_task(
db_session=db_session,
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
task_id=task_id,
status=TaskStatus.STARTED,
start_time=start_time,
)
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
try:
with get_session_with_current_tenant() as db_session:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"Failed to delete chat session "
f"user_id={user_id} session_id={session_id}, "
"continuing with remaining sessions"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=True,
)
except Exception:
logger.exception(
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
)
with get_session_with_current_tenant() as db_session:
mark_task_as_finished_with_id(
db_session=db_session,
task_id=task_id,
success=False,
)
raise

View File

@@ -36,13 +36,16 @@ from ee.onyx.server.scim.filtering import ScimFilter
from ee.onyx.server.scim.filtering import ScimFilterOperator
from ee.onyx.server.scim.models import ScimMappingFields
from onyx.db.dal import DAL
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import ScimGroupMapping
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -280,7 +283,9 @@ class ScimDAL(DAL):
query = (
select(User)
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
.where(
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
)
)
if scim_filter:
@@ -521,6 +526,22 @@ class ScimDAL(DAL):
self._session.add(group)
self._session.flush()
def add_permission_grant_to_group(
self,
group_id: int,
permission: Permission,
grant_source: GrantSource,
) -> None:
"""Grant a permission to a group and flush."""
self._session.add(
PermissionGrant(
group_id=group_id,
permission=permission,
grant_source=grant_source,
)
)
self._session.flush()
def update_group(
self,
group: UserGroup,

View File

@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -36,6 +39,7 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger
@@ -255,6 +259,7 @@ def fetch_user_groups(
db_session: Session,
only_up_to_date: bool = True,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
"""
Fetches user groups from the database.
@@ -269,6 +274,7 @@ def fetch_user_groups(
to include only up to date user groups. Defaults to `True`.
eager_load_for_snapshot: If True, adds eager loading for all relationships
needed by UserGroup.from_model snapshot creation.
include_default: If False, excludes system default groups (is_default=True).
Returns:
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -276,6 +282,8 @@ def fetch_user_groups(
stmt = select(UserGroup)
if only_up_to_date:
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -286,6 +294,7 @@ def fetch_user_groups_for_user(
user_id: UUID,
only_curator_groups: bool = False,
eager_load_for_snapshot: bool = False,
include_default: bool = True,
) -> Sequence[UserGroup]:
stmt = (
select(UserGroup)
@@ -295,6 +304,8 @@ def fetch_user_groups_for_user(
)
if only_curator_groups:
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
if eager_load_for_snapshot:
stmt = _add_user_group_snapshot_eager_loads(stmt)
return db_session.scalars(stmt).unique().all()
@@ -478,6 +489,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
db_session.add(db_user_group)
db_session.flush() # give the group an ID
# Every group gets the "basic" permission by default
db_session.add(
PermissionGrant(
group_id=db_user_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
)
db_session.flush()
_add_user__user_group_relationships__no_commit(
db_session=db_session,
user_group_id=db_user_group.id,
@@ -489,6 +510,8 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
cc_pair_ids=user_group.cc_pair_ids,
)
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
db_session.commit()
return db_user_group
@@ -796,6 +819,10 @@ def update_user_group(
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()
recompute_user_permissions__no_commit(
list(set(added_user_ids) | set(removed_user_ids)), db_session
)
db_session.commit()
return db_user_group
@@ -835,6 +862,19 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
_check_user_group_is_modifiable(db_user_group)
# Collect affected user IDs before cleanup deletes the relationships
affected_user_ids: list[UUID] = [
uid
for uid in db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == user_group_id
)
)
.scalars()
.all()
if uid is not None
]
_mark_user_group__cc_pair_relationships_outdated__no_commit(
db_session=db_session, user_group_id=user_group_id
)
@@ -863,6 +903,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
db_session=db_session, user_group_id=user_group_id
)
# Recompute permissions for affected users now that their
# membership in this group has been removed
recompute_user_permissions__no_commit(affected_user_ids, db_session)
db_user_group.is_up_to_date = False
db_user_group.is_up_for_deletion = True
db_session.commit()

View File

@@ -52,16 +52,25 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
class ScimJSONResponse(JSONResponse):
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -486,6 +495,7 @@ def create_user(
email=email,
hashed_password=_pw_helper.hash(_pw_helper.generate()),
role=UserRole.BASIC,
account_type=AccountType.STANDARD,
is_active=user_resource.active,
is_verified=True,
personal_name=personal_name,
@@ -506,13 +516,25 @@ def create_user(
scim_username=scim_username,
fields=fields,
)
dal.commit()
except IntegrityError:
dal.rollback()
return _scim_error_response(
409, f"User with email {email} already has a SCIM mapping"
)
# Assign user to default group BEFORE commit so everything is atomic.
# If this fails, the entire user creation rolls back and IdP can retry.
try:
assign_user_to_default_groups__no_commit(db_session, user)
except Exception:
dal.rollback()
logger.exception(f"Failed to assign SCIM user {email} to default groups")
return _scim_error_response(
500, f"Failed to assign user {email} to default group"
)
dal.commit()
return _scim_resource_response(
provider.build_user_resource(
user,
@@ -542,7 +564,8 @@ def replace_user(
user = result
# Handle activation (need seat check) / deactivation
if user_resource.active and not user.is_active:
is_reactivation = user_resource.active and not user.is_active
if is_reactivation:
seat_error = _check_seat_availability(dal)
if seat_error:
return _scim_error_response(403, seat_error)
@@ -556,6 +579,12 @@ def replace_user(
personal_name=personal_name,
)
# Reconcile default-group membership on reactivation
if is_reactivation:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
new_external_id = user_resource.externalId
scim_username = user_resource.userName.strip()
fields = _fields_from_resource(user_resource)
@@ -621,6 +650,7 @@ def patch_user(
return _scim_error_response(e.status, e.detail)
# Apply changes back to the DB model
is_reactivation = patched.active and not user.is_active
if patched.active != user.is_active:
if patched.active:
seat_error = _check_seat_availability(dal)
@@ -649,6 +679,12 @@ def patch_user(
personal_name=personal_name,
)
# Reconcile default-group membership on reactivation
if is_reactivation:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
# Build updated fields by merging PATCH enterprise data with current values
cf = current_fields or ScimMappingFields()
fields = ScimMappingFields(
@@ -857,6 +893,11 @@ def create_group(
dal = ScimDAL(db_session)
dal.update_token_last_used(_token.id)
if group_resource.displayName in _RESERVED_GROUP_NAMES:
return _scim_error_response(
409, f"'{group_resource.displayName}' is a reserved group name."
)
if dal.get_group_by_name(group_resource.displayName):
return _scim_error_response(
409, f"Group with name '{group_resource.displayName}' already exists"
@@ -879,8 +920,18 @@ def create_group(
409, f"Group with name '{group_resource.displayName}' already exists"
)
# Every group gets the "basic" permission by default.
dal.add_permission_grant_to_group(
group_id=db_group.id,
permission=Permission.BASIC_ACCESS,
grant_source=GrantSource.SYSTEM,
)
dal.upsert_group_members(db_group.id, member_uuids)
# Recompute permissions for initial members.
recompute_user_permissions__no_commit(member_uuids, db_session)
external_id = group_resource.externalId
if external_id:
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
@@ -911,14 +962,36 @@ def replace_group(
return result
group = result
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
return _scim_error_response(
409, f"'{group.name}' is a reserved group name and cannot be renamed."
)
if (
group_resource.displayName in _RESERVED_GROUP_NAMES
and group_resource.displayName != group.name
):
return _scim_error_response(
409, f"'{group_resource.displayName}' is a reserved group name."
)
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
if err:
return _scim_error_response(400, err)
# Capture old member IDs before replacing so we can recompute their
# permissions after they are removed from the group.
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
dal.update_group(group, name=group_resource.displayName)
dal.replace_group_members(group.id, member_uuids)
dal.sync_group_external_id(group.id, group_resource.externalId)
# Recompute permissions for current members (batch) and removed members.
recompute_permissions_for_group__no_commit(group.id, db_session)
removed_ids = list(old_member_ids - set(member_uuids))
recompute_user_permissions__no_commit(removed_ids, db_session)
dal.commit()
members = dal.get_group_members(group.id)
@@ -961,8 +1034,19 @@ def patch_group(
return _scim_error_response(e.status, e.detail)
new_name = patched.displayName if patched.displayName != group.name else None
if group.name in _RESERVED_GROUP_NAMES and new_name:
return _scim_error_response(
409, f"'{group.name}' is a reserved group name and cannot be renamed."
)
if new_name and new_name in _RESERVED_GROUP_NAMES:
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
dal.update_group(group, name=new_name)
affected_uuids: list[UUID] = []
if added_ids:
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
if add_uuids:
@@ -973,10 +1057,15 @@ def patch_group(
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
)
dal.upsert_group_members(group.id, add_uuids)
affected_uuids.extend(add_uuids)
if removed_ids:
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
dal.remove_group_members(group.id, remove_uuids)
affected_uuids.extend(remove_uuids)
# Recompute permissions for all users whose group membership changed.
recompute_user_permissions__no_commit(affected_uuids, db_session)
dal.sync_group_external_id(group.id, patched.externalId)
dal.commit()
@@ -1002,11 +1091,21 @@ def delete_group(
return result
group = result
if group.name in _RESERVED_GROUP_NAMES:
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
# Capture member IDs before deletion so we can recompute their permissions.
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
mapping = dal.get_group_mapping_by_group_id(group.id)
if mapping:
dal.delete_group_mapping(mapping.id)
dal.delete_group_with_members(group)
# Recompute permissions for users who lost this group membership.
recompute_user_permissions__no_commit(affected_user_ids, db_session)
dal.commit()
return Response(status_code=204)

View File

@@ -43,12 +43,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
@router.get("/admin/user-group")
def list_user_groups(
include_default: bool = False,
user: User = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> list[UserGroup]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(
db_session, only_up_to_date=False, eager_load_for_snapshot=True
db_session,
only_up_to_date=False,
eager_load_for_snapshot=True,
include_default=include_default,
)
else:
user_groups = fetch_user_groups_for_user(
@@ -56,27 +60,50 @@ def list_user_groups(
user_id=user.id,
only_curator_groups=user.role == UserRole.CURATOR,
eager_load_for_snapshot=True,
include_default=include_default,
)
return [UserGroup.from_model(user_group) for user_group in user_groups]
@router.get("/user-groups/minimal")
def list_minimal_user_groups(
include_default: bool = False,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
if user.role == UserRole.ADMIN:
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
user_groups = fetch_user_groups(
db_session,
only_up_to_date=False,
include_default=include_default,
)
else:
user_groups = fetch_user_groups_for_user(
db_session=db_session,
user_id=user.id,
include_default=include_default,
)
return [
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
]
@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
user_group_id: int,
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[str]:
group = fetch_user_group(db_session, user_group_id)
if group is None:
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
return [
grant.permission.value
for grant in group.permission_grants
if not grant.is_deleted
]
@router.post("/admin/user-group")
def create_user_group(
user_group: UserGroupCreate,
@@ -100,6 +127,9 @@ def rename_user_group_endpoint(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> UserGroup:
group = fetch_user_group(db_session, rename_request.id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
try:
return UserGroup.from_model(
rename_user_group(
@@ -185,6 +215,9 @@ def delete_user_group(
_: User = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
group = fetch_user_group(db_session, user_group_id)
if group and group.is_default:
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
try:
prepare_user_group_for_deletion(db_session, user_group_id)
except ValueError as e:

View File

@@ -22,6 +22,7 @@ class UserGroup(BaseModel):
personas: list[PersonaSnapshot]
is_up_to_date: bool
is_up_for_deletion: bool
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -74,18 +75,21 @@ class UserGroup(BaseModel):
],
is_up_to_date=user_group_model.is_up_to_date,
is_up_for_deletion=user_group_model.is_up_for_deletion,
is_default=user_group_model.is_default,
)
class MinimalUserGroupSnapshot(BaseModel):
id: int
name: str
is_default: bool
@classmethod
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
return cls(
id=user_group_model.id,
name=user_group_model.name,
is_default=user_group_model.is_default,
)

View File

@@ -0,0 +1,110 @@
"""
Permission resolution for group-based authorization.
Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""
from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any
from fastapi import Depends
from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger
logger = setup_logger()
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
Permission.MANAGE_AGENTS.value: {
Permission.ADD_AGENTS.value,
Permission.READ_AGENTS.value,
},
Permission.MANAGE_DOCUMENT_SETS.value: {
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_CONNECTORS.value,
},
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
Permission.MANAGE_CONNECTORS.value: {
Permission.ADD_CONNECTORS.value,
Permission.READ_CONNECTORS.value,
},
Permission.MANAGE_USER_GROUPS.value: {
Permission.READ_CONNECTORS.value,
Permission.READ_DOCUMENT_SETS.value,
Permission.READ_AGENTS.value,
Permission.READ_USERS.value,
},
}
def resolve_effective_permissions(granted: set[str]) -> set[str]:
"""Expand granted permissions with their implied permissions.
If "admin" is present, returns all 19 permissions.
"""
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
return set(ALL_PERMISSIONS)
effective = set(granted)
changed = True
while changed:
changed = False
for perm in list(effective):
implied = IMPLIED_PERMISSIONS.get(perm)
if implied and not implied.issubset(effective):
effective |= implied
changed = True
return effective
def get_effective_permissions(user: User) -> set[Permission]:
"""Read granted permissions from the column and expand implied permissions."""
granted: set[Permission] = set()
for p in user.effective_permissions:
try:
granted.add(Permission(p))
except ValueError:
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
return set(Permission)
expanded = resolve_effective_permissions({p.value for p in granted})
return {Permission(p) for p in expanded}
def require_permission(
required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
"""FastAPI dependency factory for permission-based access control.
Usage:
@router.get("/endpoint")
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
...
"""
async def dependency(user: User = Depends(current_user)) -> User:
effective = get_effective_permissions(user)
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
return user
if required not in effective:
raise OnyxError(
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
"You do not have the required permissions for this action.",
)
return user
return dependency

View File

@@ -5,6 +5,8 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override
from onyx.db.enums import AccountType
class UserRole(str, Enum):
"""
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
class UserCreate(schemas.BaseUserCreate):
role: UserRole = UserRole.BASIC
account_type: AccountType = AccountType.STANDARD
tenant_id: str | None = None
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
# Excluded from create_update_dict so it never reaches the DB layer
@@ -50,19 +53,19 @@ class UserCreate(schemas.BaseUserCreate):
def create_update_dict(self) -> dict[str, Any]:
d = super().create_update_dict()
d.pop("captcha_token", None)
# Force STANDARD for self-registration; only trusted paths
# (SCIM, API key creation) supply a different account_type directly.
d["account_type"] = AccountType.STANDARD
return d
@override
def create_update_dict_superuser(self) -> dict[str, Any]:
d = super().create_update_dict_superuser()
d.pop("captcha_token", None)
d.setdefault("account_type", self.account_type)
return d
class UserUpdateWithRole(schemas.BaseUserUpdate):
role: UserRole
class UserUpdate(schemas.BaseUserUpdate):
"""
Role updates are not allowed through the user update endpoint for security reasons

View File

@@ -80,7 +80,6 @@ from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
@@ -120,11 +119,13 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
@@ -500,18 +501,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
user = user_by_session
if (
user.role.is_web_login()
user.account_type.is_web_login()
or not isinstance(user_create, UserCreate)
or not user_create.role.is_web_login()
or not user_create.account_type.is_web_login()
):
raise exceptions.UserAlreadyExists()
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
# Cache id before expire — accessing attrs on an expired
# object triggers a sync lazy-load which raises MissingGreenlet
# in this async context.
user_id = user.id
self._upgrade_user_to_standard__sync(user_id, user_create)
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get(user_id) # type: ignore[assignment]
except exceptions.UserAlreadyExists:
user = await self.get_by_email(user_create.email)
@@ -525,18 +529,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
# Handle case where user has used product outside of web and is now creating an account through web
if (
user.role.is_web_login()
user.account_type.is_web_login()
or not isinstance(user_create, UserCreate)
or not user_create.role.is_web_login()
or not user_create.account_type.is_web_login()
):
raise exceptions.UserAlreadyExists()
user_update = UserUpdateWithRole(
password=user_create.password,
is_verified=user_create.is_verified,
role=user_create.role,
)
user = await self.update(user_update, user)
# Cache id before expire — accessing attrs on an expired
# object triggers a sync lazy-load which raises MissingGreenlet
# in this async context.
user_id = user.id
self._upgrade_user_to_standard__sync(user_id, user_create)
# Expire so the async session re-fetches the row updated by
# the sync session above.
self.user_db.session.expire(user)
user = await self.user_db.get(user_id) # type: ignore[assignment]
if user_created:
await self._assign_default_pinned_assistants(user, db_session)
remove_user_from_invited_users(user_create.email)
@@ -573,6 +580,38 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
user.pinned_assistants = default_persona_ids
def _upgrade_user_to_standard__sync(
self,
user_id: uuid.UUID,
user_create: UserCreate,
) -> None:
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
All writes happen in a single sync transaction so neither the field
update nor the group assignment is visible without the other.
"""
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
if sync_user:
sync_user.hashed_password = self.password_helper.hash(
user_create.password
)
sync_user.is_verified = user_create.is_verified or False
sync_user.role = user_create.role
sync_user.account_type = AccountType.STANDARD
assign_user_to_default_groups__no_commit(
sync_db,
sync_user,
is_admin=(user_create.role == UserRole.ADMIN),
)
sync_db.commit()
else:
logger.warning(
"User %s not found in sync session during upgrade to standard; "
"skipping upgrade",
user_id,
)
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
# Validate password according to configurable security policy (defined via environment variables)
if len(password) < PASSWORD_MIN_LENGTH:
@@ -694,6 +733,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
"email": account_email,
"hashed_password": self.password_helper.hash(password),
"is_verified": is_verified_by_default,
"account_type": AccountType.STANDARD,
}
user = await self.user_db.create(user_dict)
@@ -726,7 +766,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
)
# Handle case where user has used product outside of web and is now creating an account through web
if not user.role.is_web_login():
if not user.account_type.is_web_login():
# We must use the existing user in the session if it matches
# the user we just got by email/oauth. Note that this only applies
# to multi-tenant, due to the overwriting of the user_db
@@ -743,14 +783,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
with get_session_with_current_tenant() as sync_db:
enforce_seat_limit(sync_db)
await self.user_db.update(
user,
{
"is_verified": is_verified_by_default,
"role": UserRole.BASIC,
**({"is_active": True} if not user.is_active else {}),
},
)
# Upgrade the user and assign default groups in a single
# transaction so neither change is visible without the other.
was_inactive = not user.is_active
with get_session_with_current_tenant() as sync_db:
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
if sync_user:
sync_user.is_verified = is_verified_by_default
sync_user.role = UserRole.BASIC
sync_user.account_type = AccountType.STANDARD
if was_inactive:
sync_user.is_active = True
assign_user_to_default_groups__no_commit(sync_db, sync_user)
sync_db.commit()
# Refresh the async user object so downstream code
# (e.g. oidc_expiry check) sees the updated fields.
self.user_db.session.expire(user)
user = await self.user_db.get(user.id)
assert user is not None
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
# otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -836,6 +887,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
event=MilestoneRecordType.TENANT_CREATED,
)
# Assign user to the appropriate default group (Admin or Basic).
# Must happen inside the try block while tenant context is active,
# otherwise get_session_with_current_tenant() targets the wrong schema.
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
with get_session_with_current_tenant() as db_session:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=is_admin
)
db_session.commit()
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
@@ -975,7 +1036,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
self.password_helper.hash(credentials.password)
return None
if not user.role.is_web_login():
if not user.account_type.is_web_login():
raise BasicAuthenticationError(
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
)
@@ -1471,7 +1532,7 @@ async def _get_or_create_user_from_jwt(
if not user.is_active:
logger.warning("Inactive user %s attempted JWT login; skipping", email)
return None
if not user.role.is_web_login():
if not user.account_type.is_web_login():
raise exceptions.UserNotExists()
except exceptions.UserNotExists:
logger.info("Provisioning user %s from JWT login", email)
@@ -1492,7 +1553,7 @@ async def _get_or_create_user_from_jwt(
email,
)
return None
if not user.role.is_web_login():
if not user.account_type.is_web_login():
logger.warning(
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
email,
@@ -1554,6 +1615,7 @@ def get_anonymous_user() -> User:
is_verified=True,
is_superuser=False,
role=UserRole.LIMITED,
account_type=AccountType.ANONYMOUS,
use_memories=False,
enable_memory_tool=False,
)

View File

@@ -36,6 +36,7 @@ from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
from onyx.db.opensearch_migration import get_vespa_visit_state
from onyx.db.opensearch_migration import is_migration_completed
from onyx.db.opensearch_migration import (
mark_migration_completed_time_if_not_set_with_commit,
)
@@ -106,14 +107,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
acquired; effectively a no-op. True if the task completed
successfully. False if the task errored.
"""
# 1. Check if we should run the task.
# 1.a. If OpenSearch indexing is disabled, we don't run the task.
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
task_logger.warning(
"OpenSearch migration is not enabled, skipping chunk migration task."
)
return None
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
task_start_time = time.monotonic()
# 1.b. Only one instance per tenant of this task may run concurrently at
# once. If we fail to acquire a lock, we assume it is because another task
# has one and we exit.
r = get_redis_client()
lock: RedisLock = r.lock(
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
@@ -136,10 +142,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"Token: {lock.local.token}"
)
# 2. Prepare to migrate.
total_chunks_migrated_this_task = 0
total_chunks_errored_this_task = 0
try:
# Double check that tenant info is correct.
# 2.a. Double-check that tenant info is correct.
if tenant_id != get_current_tenant_id():
err_str = (
f"Tenant ID mismatch in the OpenSearch migration task: "
@@ -148,16 +155,62 @@ def migrate_chunks_from_vespa_to_opensearch_task(
task_logger.error(err_str)
return False
with (
get_session_with_current_tenant() as db_session,
get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client,
):
# Do as much as we can with a DB session in one spot to not hold a
# session during a migration batch.
with get_session_with_current_tenant() as db_session:
# 2.b. Immediately check to see if this tenant is done, to save
# having to do any other work. This function does not require a
# migration record to necessarily exist.
if is_migration_completed(db_session):
return True
# 2.c. Try to insert the OpenSearchTenantMigrationRecord table if it
# does not exist.
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
# 2.d. Get search settings.
search_settings = get_current_search_settings(db_session)
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
indexing_setting = IndexingSetting.from_db_model(search_settings)
# 2.e. Build sanitized to original doc ID mapping to check for
# conflicts in the event we sanitize a doc ID to an
# already-existing doc ID.
# We reconstruct this mapping for every task invocation because
# a document may have been added in the time between two tasks.
sanitized_doc_start_time = time.monotonic()
sanitized_to_original_doc_id_mapping = (
build_sanitized_to_original_doc_id_mapping(db_session)
)
task_logger.debug(
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
# 2.f. Get the current migration state.
continuation_token_map, total_chunks_migrated = get_vespa_visit_state(
db_session
)
# 2.f.1. Double-check that the migration state does not imply
# completion. Really we should never have to enter this block as we
# would expect is_migration_completed to return True, but in the
# strange event that the migration is complete but the migration
# completed time was never stamped, we do so here.
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
return True
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
)
with get_vespa_http_client(
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
) as vespa_client:
# 2.g. Create the OpenSearch and Vespa document indexes.
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
opensearch_document_index = OpenSearchDocumentIndex(
tenant_state=tenant_state,
index_name=search_settings.index_name,
@@ -171,22 +224,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
httpx_client=vespa_client,
)
sanitized_doc_start_time = time.monotonic()
# We reconstruct this mapping for every task invocation because a
# document may have been added in the time between two tasks.
sanitized_to_original_doc_id_mapping = (
build_sanitized_to_original_doc_id_mapping(db_session)
)
task_logger.debug(
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
)
# 2.h. Get the approximate chunk count in Vespa as of this time to
# update the migration record.
approx_chunk_count_in_vespa: int | None = None
get_chunk_count_start_time = time.monotonic()
try:
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
except Exception:
# This failure should not be blocking.
task_logger.exception(
"Error getting approximate chunk count in Vespa. Moving on..."
)
@@ -195,25 +240,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
)
# 3. Do the actual migration in batches until we run out of time.
while (
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
and lock.owned()
):
(
continuation_token_map,
total_chunks_migrated,
) = get_vespa_visit_state(db_session)
if is_continuation_token_done_for_all_slices(continuation_token_map):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
break
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
)
# 3.a. Get the next batch of raw chunks from Vespa.
get_vespa_chunks_start_time = time.monotonic()
raw_vespa_chunks, next_continuation_token_map = (
vespa_document_index.get_all_raw_document_chunks_paginated(
@@ -226,6 +258,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
f"seconds. Next continuation token map: {next_continuation_token_map}"
)
# 3.b. Transform the raw chunks to OpenSearch chunks in memory.
opensearch_document_chunks, errored_chunks = (
transform_vespa_chunks_to_opensearch_chunks(
raw_vespa_chunks,
@@ -240,6 +273,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
"errored."
)
# 3.c. Index the OpenSearch chunks into OpenSearch.
index_opensearch_chunks_start_time = time.monotonic()
opensearch_document_index.index_raw_chunks(
chunks=opensearch_document_chunks
@@ -251,12 +285,38 @@ def migrate_chunks_from_vespa_to_opensearch_task(
total_chunks_migrated_this_task += len(opensearch_document_chunks)
total_chunks_errored_this_task += len(errored_chunks)
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
# Do as much as we can with a DB session in one spot to not hold a
# session during a migration batch.
with get_session_with_current_tenant() as db_session:
# 3.d. Update the migration state.
update_vespa_visit_progress_with_commit(
db_session,
continuation_token_map=next_continuation_token_map,
chunks_processed=len(opensearch_document_chunks),
chunks_errored=len(errored_chunks),
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
)
# 3.e. Get the current migration state. Even thought we
# technically have it in-memory since we just wrote it, we
# want to reference the DB as the source of truth at all
# times.
continuation_token_map, total_chunks_migrated = (
get_vespa_visit_state(db_session)
)
# 3.e.1. Check if the migration is done.
if is_continuation_token_done_for_all_slices(
continuation_token_map
):
task_logger.info(
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
)
mark_migration_completed_time_if_not_set_with_commit(db_session)
return True
task_logger.debug(
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
f"Continuation token map: {continuation_token_map}"
)
except Exception:
traceback.print_exc()

View File

@@ -1,6 +1,7 @@
import uuid
from fastapi_users.password import PasswordHelper
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -10,14 +11,23 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -85,6 +95,7 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)
@@ -97,7 +108,18 @@ def insert_api_key(
)
db_session.add(api_key_row)
# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# LIMITED role service accounts should have no group membership.
if api_key_args.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
db_session.commit()
return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,
@@ -124,7 +146,33 @@ def update_api_key(
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
old_role = api_key_user.role
api_key_user.role = api_key_args.role
# Reconcile default-group membership when the role changes.
if old_role != api_key_args.role:
# Remove from all default groups first.
delete_stmt = delete(User__UserGroup).where(
User__UserGroup.user_id == api_key_user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
db_session.execute(delete_stmt)
# Re-assign to the correct default group (skip for LIMITED).
if api_key_args.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
api_key_user,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
else:
# No group assigned for LIMITED, but we still need to recompute
# since we just removed the old default-group membership above.
recompute_user_permissions__no_commit(api_key_user.id, db_session)
db_session.commit()
return ApiKeyDescriptor(

View File

@@ -199,7 +199,7 @@ def delete_messages_and_files_from_chat_session(
for _, files in messages_with_files:
file_store = get_default_file_store()
for file_info in files or []:
file_store.delete_file(file_id=file_info.get("id"))
file_store.delete_file(file_id=file_info.get("id"), error_on_missing=False)
# Delete ChatMessage records - CASCADE constraints will automatically handle:
# - ChatMessage__StandardAnswer relationship records

View File

@@ -13,19 +13,26 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""
STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"
def is_web_login(self) -> bool:
"""Whether this account type supports interactive web login."""
return self not in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
)
class GrantSource(str, PyEnum):
"""How a permission grant was created."""
USER = "user"
SCIM = "scim"
SYSTEM = "system"
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"
class IndexingStatus(str, PyEnum):

View File

@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
)
"""
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)
effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(Permission, native_enum=False), nullable=False
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False

View File

@@ -324,6 +324,15 @@ def mark_migration_completed_time_if_not_set_with_commit(
db_session.commit()
def is_migration_completed(db_session: Session) -> bool:
"""Returns True if the migration is completed.
Can be run even if the migration record does not exist.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
return record is not None and record.migration_completed_at is not None
def build_sanitized_to_original_doc_id_mapping(
db_session: Session,
) -> dict[str, str]:

View File

@@ -0,0 +1,95 @@
"""
DB operations for recomputing user effective_permissions.
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""
from collections import defaultdict
from uuid import UUID
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup
def recompute_user_permissions__no_commit(
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
) -> None:
"""Recompute granted permissions for one or more users.
Accepts a single UUID or a list. Uses a single query regardless of
how many users are passed, avoiding N+1 issues.
Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().
Does NOT commit — caller must commit the session.
"""
if isinstance(user_ids, (UUID, str)):
uid_list = [user_ids]
else:
uid_list = list(user_ids)
if not uid_list:
return
# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(uid_list),
PermissionGrant.is_deleted.is_(False),
)
).all()
# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
for uid in uid_list:
perms_by_user[uid] # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)
for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid) # type: ignore[arg-type]
.values(effective_permissions=sorted(perms))
)
def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.
Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = [
uid
for uid in db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
if uid is not None
]
if not user_ids:
return
recompute_user_permissions__no_commit(user_ids, db_session)

View File

@@ -5,11 +5,11 @@ from urllib.parse import urlencode
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NotificationType
from onyx.configs.constants import ONYX_UTM_SOURCE
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.notification import batch_create_notifications
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
db_session.scalars(
select(User.id).where( # type: ignore
User.is_active == True, # noqa: E712
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
)
).all()

View File

@@ -9,12 +9,17 @@ from sqlalchemy import update
from sqlalchemy.orm import Session
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.enums import DefaultAppMode
from onyx.db.enums import ThemePreference
from onyx.db.models import AccessToken
from onyx.db.models import Assistant__UserSpecificConfig
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import UserSpecificAssistantPreference
from onyx.utils.logger import setup_logger
@@ -23,13 +28,53 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
UserRole.SLACK_USER: AccountType.BOT,
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
}
def update_user_role(
user: User,
new_role: UserRole,
db_session: Session,
) -> None:
"""Update a user's role in the database."""
"""Update a user's role in the database.
Dual-writes account_type to keep it in sync with role and
reconciles default-group membership (Admin / Basic)."""
old_role = user.role
user.role = new_role
# Note: setting account_type to BOT or EXT_PERM_USER causes
# assign_user_to_default_groups__no_commit to early-return, which is
# intentional — these account types should not be in default groups.
if new_role in _ROLE_TO_ACCOUNT_TYPE:
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
# Upgrading from a non-web-login account type to a web role
user.account_type = AccountType.STANDARD
# Reconcile default-group membership when the role changes.
if old_role != new_role:
# Remove from all default groups first.
db_session.execute(
delete(User__UserGroup).where(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
)
# Re-assign to the correct default group (skip for LIMITED).
if new_role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
user,
is_admin=(new_role == UserRole.ADMIN),
)
recompute_user_permissions__no_commit(user.id, db_session)
db_session.commit()
@@ -47,8 +92,16 @@ def activate_user(
user: User,
db_session: Session,
) -> None:
"""Activate a user by setting is_active to True."""
"""Activate a user by setting is_active to True.
Also reconciles default-group membership — the user may have been
created while inactive or deactivated before the backfill migration.
"""
user.is_active = True
if user.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
db_session.add(user)
db_session.commit()

View File

@@ -17,8 +17,9 @@ from sqlalchemy.sql.expression import or_
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.enums import AccountType
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -27,11 +28,17 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
logger = setup_logger()
def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
requested_role: UserRole,
current_role: UserRole,
current_account_type: AccountType,
explicit_override: bool = False,
) -> None:
"""
Validate that a user role update is valid.
@@ -41,19 +48,18 @@ def validate_user_role_update(
- requested role is a slack user
- requested role is an external permissioned user
- requested role is a limited user
- current role is a slack user
- current role is an external permissioned user
- current account type is BOT (slack user)
- current account type is EXT_PERM_USER
- current role is a limited user
"""
if current_role == UserRole.SLACK_USER:
if current_account_type == AccountType.BOT:
raise HTTPException(
status_code=400,
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
)
if current_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
if current_account_type == AccountType.EXT_PERM_USER:
raise HTTPException(
status_code=400,
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
@@ -298,6 +304,7 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)
@@ -306,8 +313,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
if user.account_type == AccountType.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user
@@ -344,6 +352,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
@@ -375,6 +384,81 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users
def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.
Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.
Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return
target_group_name = "Admin" if is_admin else "Basic"
default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)
if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)
# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return
savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return
from onyx.db.permissions import recompute_user_permissions__no_commit
recompute_user_permissions__no_commit(user.id, db_session)
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
def delete_user_from_db(
user_to_delete: User,
db_session: Session,
@@ -421,13 +505,14 @@ def delete_user_from_db(
def batch_get_user_groups(
db_session: Session,
user_ids: list[UUID],
include_default: bool = False,
) -> dict[UUID, list[tuple[int, str]]]:
"""Fetch group memberships for a batch of users in a single query.
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
if not user_ids:
return {}
rows = db_session.execute(
stmt = (
select(
User__UserGroup.user_id,
UserGroup.id,
@@ -435,7 +520,11 @@ def batch_get_user_groups(
)
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
.where(User__UserGroup.user_id.in_(user_ids))
).all()
)
if not include_default:
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
rows = db_session.execute(stmt).all()
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
for user_id, group_id, group_name in rows:

View File

@@ -1,3 +1,4 @@
import hashlib
from datetime import datetime
from datetime import timezone
from typing import Any
@@ -20,9 +21,13 @@ from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
from onyx.document_index.opensearch.constants import EF_SEARCH
from onyx.document_index.opensearch.constants import M
from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError
from onyx.document_index.opensearch.string_filtering import (
filter_and_validate_document_id,
)
from onyx.document_index.opensearch.string_filtering import (
MAX_DOCUMENT_ID_ENCODED_LENGTH,
)
from onyx.utils.tenant import get_tenant_id_short_string
from shared_configs.configs import MULTI_TENANT
from shared_configs.contextvars import get_current_tenant_id
@@ -75,17 +80,50 @@ def get_opensearch_doc_chunk_id(
This will be the string used to identify the chunk in OpenSearch. Any direct
chunk queries should use this function.
If the document ID is too long, a hash of the ID is used instead.
"""
sanitized_document_id = filter_and_validate_document_id(document_id)
opensearch_doc_chunk_id = (
f"{sanitized_document_id}__{max_chunk_size}__{chunk_index}"
opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}"
encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8"))
max_encoded_permissible_doc_id_length: int = (
MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length
)
opensearch_doc_chunk_id_tenant_prefix: str = ""
if tenant_state.multitenant:
short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id)
# Use tenant ID because in multitenant mode each tenant has its own
# Documents table, so there is a very small chance that doc IDs are not
# actually unique across all tenants.
short_tenant_id = get_tenant_id_short_string(tenant_state.tenant_id)
opensearch_doc_chunk_id = f"{short_tenant_id}__{opensearch_doc_chunk_id}"
opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__"
encoded_prefix_length: int = len(
opensearch_doc_chunk_id_tenant_prefix.encode("utf-8")
)
max_encoded_permissible_doc_id_length -= encoded_prefix_length
try:
sanitized_document_id: str = filter_and_validate_document_id(
document_id, max_encoded_length=max_encoded_permissible_doc_id_length
)
except DocumentIDTooLongError:
# If the document ID is too long, use a hash instead.
# We use blake2b because it is faster and equally secure as SHA256, and
# accepts digest_size which controls the number of bytes returned in the
# hash.
# digest_size is the size of the returned hash in bytes. Since we're
# decoding the hash bytes as a hex string, the digest_size should be
# half the max target size of the hash string.
# Subtract 1 because filter_and_validate_document_id compares on >= on
# max_encoded_length.
# 64 is the max digest_size blake2b returns.
digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64)
sanitized_document_id = hashlib.blake2b(
document_id.encode("utf-8"), digest_size=digest_size
).hexdigest()
opensearch_doc_chunk_id: str = (
f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}"
)
# Do one more validation to ensure we haven't exceeded the max length.
opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id)
return opensearch_doc_chunk_id

View File

@@ -1,7 +1,15 @@
import re
MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512
def filter_and_validate_document_id(document_id: str) -> str:
class DocumentIDTooLongError(ValueError):
"""Raised when a document ID is too long for OpenSearch after filtering."""
def filter_and_validate_document_id(
document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH
) -> str:
"""
Filters and validates a document ID such that it can be used as an ID in
OpenSearch.
@@ -19,9 +27,13 @@ def filter_and_validate_document_id(document_id: str) -> str:
Args:
document_id: The document ID to filter and validate.
max_encoded_length: The maximum length of the document ID after
filtering in bytes. Compared with >= for extra resilience, so
encoded values of this length will fail.
Raises:
ValueError: If the document ID is empty or too long after filtering.
DocumentIDTooLongError: If the document ID is too long after filtering.
ValueError: If the document ID is empty after filtering.
Returns:
str: The filtered document ID.
@@ -29,6 +41,8 @@ def filter_and_validate_document_id(document_id: str) -> str:
filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
if not filtered_document_id:
raise ValueError(f"Document ID {document_id} is empty after filtering.")
if len(filtered_document_id.encode("utf-8")) >= 512:
raise ValueError(f"Document ID {document_id} is too long after filtering.")
if len(filtered_document_id.encode("utf-8")) >= max_encoded_length:
raise DocumentIDTooLongError(
f"Document ID {document_id} is too long after filtering."
)
return filtered_document_id

View File

@@ -136,12 +136,14 @@ class FileStore(ABC):
"""
@abstractmethod
def delete_file(self, file_id: str) -> None:
def delete_file(self, file_id: str, error_on_missing: bool = True) -> None:
"""
Delete a file by its ID.
Parameters:
- file_name: Name of file to delete
- file_id: ID of file to delete
- error_on_missing: If False, silently return when the file record
does not exist instead of raising.
"""
@abstractmethod
@@ -452,12 +454,23 @@ class S3BackedFileStore(FileStore):
logger.warning(f"Error getting file size for {file_id}: {e}")
return None
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
def delete_file(
self,
file_id: str,
error_on_missing: bool = True,
db_session: Session | None = None,
) -> None:
with get_session_with_current_tenant_if_none(db_session) as db_session:
try:
file_record = get_filerecord_by_file_id(
file_record = get_filerecord_by_file_id_optional(
file_id=file_id, db_session=db_session
)
if file_record is None:
if error_on_missing:
raise RuntimeError(
f"File by id {file_id} does not exist or was deleted"
)
return
if not file_record.bucket_name:
logger.error(
f"File record {file_id} with key {file_record.object_key} "

View File

@@ -222,12 +222,23 @@ class PostgresBackedFileStore(FileStore):
logger.warning(f"Error getting file size for {file_id}: {e}")
return None
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
def delete_file(
self,
file_id: str,
error_on_missing: bool = True,
db_session: Session | None = None,
) -> None:
with get_session_with_current_tenant_if_none(db_session) as session:
try:
file_content = get_file_content_by_file_id(
file_content = get_file_content_by_file_id_optional(
file_id=file_id, db_session=session
)
if file_content is None:
if error_on_missing:
raise RuntimeError(
f"File content for file_id {file_id} does not exist or was deleted"
)
return
raw_conn = _get_raw_connection(session)
try:

View File

@@ -3,6 +3,8 @@
from datetime import datetime
from typing import Any
import httpx
from onyx.configs.constants import DocumentSource
from onyx.mcp_server.api import mcp_server
from onyx.mcp_server.utils import get_http_client
@@ -15,6 +17,21 @@ from onyx.utils.variable_functionality import global_version
logger = setup_logger()
def _extract_error_detail(response: httpx.Response) -> str:
"""Extract a human-readable error message from a failed backend response.
The backend returns OnyxError responses as
``{"error_code": "...", "detail": "..."}``.
"""
try:
body = response.json()
if detail := body.get("detail"):
return str(detail)
except Exception:
pass
return f"Request failed with status {response.status_code}"
@mcp_server.tool()
async def search_indexed_documents(
query: str,
@@ -158,7 +175,14 @@ async def search_indexed_documents(
json=search_request,
headers=auth_headers,
)
response.raise_for_status()
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"documents": [],
"total_results": 0,
"query": query,
"error": error_detail,
}
result = response.json()
# Check for error in response
@@ -234,7 +258,13 @@ async def search_web(
json=request_payload,
headers={"Authorization": f"Bearer {access_token.token}"},
)
response.raise_for_status()
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": error_detail,
"results": [],
"query": query,
}
response_payload = response.json()
results = response_payload.get("results", [])
return {
@@ -280,7 +310,12 @@ async def open_urls(
json={"urls": urls},
headers={"Authorization": f"Bearer {access_token.token}"},
)
response.raise_for_status()
if not response.is_success:
error_detail = _extract_error_detail(response)
return {
"error": error_detail,
"results": [],
}
response_payload = response.json()
results = response_payload.get("results", [])
return {

View File

@@ -3,10 +3,10 @@ import datetime
from slack_sdk import WebClient
from slack_sdk.errors import SlackApiError
from onyx.auth.schemas import UserRole
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import AccountType
from onyx.db.models import SlackChannelConfig
from onyx.db.user_preferences import activate_user
from onyx.db.users import add_slack_user_if_not_exists
@@ -247,7 +247,7 @@ def handle_message(
elif (
not existing_user.is_active
and existing_user.role == UserRole.SLACK_USER
and existing_user.account_type == AccountType.BOT
):
check_seat_fn = fetch_ee_implementation_or_noop(
"onyx.db.license",

View File

@@ -1,6 +1,5 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
@@ -9,6 +8,8 @@ from onyx.db.engine.sql_engine import get_session
from onyx.db.models import User
from onyx.db.web_search import fetch_active_web_content_provider
from onyx.db.web_search import fetch_active_web_search_provider
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.server.features.web_search.models import OpenUrlsToolRequest
from onyx.server.features.web_search.models import OpenUrlsToolResponse
from onyx.server.features.web_search.models import WebSearchToolRequest
@@ -61,9 +62,10 @@ def _get_active_search_provider(
) -> tuple[WebSearchProviderView, WebSearchProvider]:
provider_model = fetch_active_web_search_provider(db_session)
if provider_model is None:
raise HTTPException(
status_code=400,
detail="No web search provider configured.",
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"No web search provider configured. Please configure one in "
"Admin > Web Search settings.",
)
provider_view = WebSearchProviderView(
@@ -76,9 +78,10 @@ def _get_active_search_provider(
)
if provider_model.api_key is None:
raise HTTPException(
status_code=400,
detail="Web search provider requires an API key.",
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Web search provider requires an API key. Please configure one in "
"Admin > Web Search settings.",
)
try:
@@ -88,7 +91,7 @@ def _get_active_search_provider(
config=provider_model.config or {},
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc
return provider_view, provider
@@ -110,9 +113,9 @@ def _get_active_content_provider(
if provider_model.api_key is None:
# TODO - this is not a great error, in fact, this key should not be nullable.
raise HTTPException(
status_code=400,
detail="Web content provider requires an API key.",
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Web content provider requires an API key.",
)
try:
@@ -125,12 +128,12 @@ def _get_active_content_provider(
config=config,
)
except ValueError as exc:
raise HTTPException(status_code=400, detail=str(exc)) from exc
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc
if provider is None:
raise HTTPException(
status_code=400,
detail="Unable to initialize the configured web content provider.",
raise OnyxError(
OnyxErrorCode.INVALID_INPUT,
"Unable to initialize the configured web content provider.",
)
provider_view = WebContentProviderView(
@@ -154,12 +157,13 @@ def _run_web_search(
for query in request.queries:
try:
search_results = provider.search(query)
except HTTPException:
except OnyxError:
raise
except Exception as exc:
logger.exception("Web search provider failed for query '%s'", query)
raise HTTPException(
status_code=502, detail="Web search provider failed to execute query."
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Web search provider failed to execute query.",
) from exc
filtered_results = filter_web_search_results_with_no_title_or_snippet(
@@ -192,12 +196,13 @@ def _open_urls(
docs = filter_web_contents_with_no_title_or_content(
list(provider.contents(urls))
)
except HTTPException:
except OnyxError:
raise
except Exception as exc:
logger.exception("Web content provider failed to fetch URLs")
raise HTTPException(
status_code=502, detail="Web content provider failed to fetch URLs."
raise OnyxError(
OnyxErrorCode.BAD_GATEWAY,
"Web content provider failed to fetch URLs.",
) from exc
results: list[LlmOpenUrlResult] = []

View File

@@ -27,6 +27,7 @@ from onyx.auth.email_utils import send_user_email_invite
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.schemas import UserRole
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
@@ -50,6 +51,7 @@ from onyx.configs.constants import PUBLIC_API_TAGS
from onyx.db.api_key import is_api_key_email_address
from onyx.db.auth import get_live_users_count
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.enums import UserFileStatus
from onyx.db.models import User
from onyx.db.models import UserFile
@@ -142,6 +144,7 @@ def set_user_role(
validate_user_role_update(
requested_role=requested_role,
current_role=current_role,
current_account_type=user_to_update.account_type,
explicit_override=user_role_update_request.explicit_override,
)
@@ -327,8 +330,8 @@ def list_all_users(
if (include_api_keys or not is_api_key_email_address(user.email))
]
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
slack_users = [user for user in users if user.account_type == AccountType.BOT]
accepted_users = [user for user in users if user.account_type != AccountType.BOT]
accepted_emails = {user.email for user in accepted_users}
slack_users_emails = {user.email for user in slack_users}
@@ -671,7 +674,7 @@ def list_all_users_basic_info(
return [
MinimalUserSnapshot(id=user.id, email=user.email)
for user in users
if user.role != UserRole.SLACK_USER
if user.account_type != AccountType.BOT
and (include_api_keys or not is_api_key_email_address(user.email))
]
@@ -774,6 +777,13 @@ def _get_token_created_at(
return get_current_token_creation_postgres(user, db_session)
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
def get_current_user_permissions(
user: User = Depends(current_user),
) -> list[str]:
return sorted(p.value for p in get_effective_permissions(user))
@router.get("/me", tags=PUBLIC_API_TAGS)
def verify_user_logged_in(
request: Request,

View File

@@ -7,6 +7,7 @@ from uuid import UUID
from pydantic import BaseModel
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.models import User
@@ -41,6 +42,7 @@ class FullUserSnapshot(BaseModel):
id: UUID
email: str
role: UserRole
account_type: AccountType
is_active: bool
password_configured: bool
personal_name: str | None
@@ -60,6 +62,7 @@ class FullUserSnapshot(BaseModel):
id=user.id,
email=user.email,
role=user.role,
account_type=user.account_type,
is_active=user.is_active,
password_configured=user.password_configured,
personal_name=user.personal_name,

View File

@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
try:
user = await user_manager.get_by_email(email)
# If user has a non-authenticated role, treat as non-existent
if not user.role.is_web_login():
if not user.account_type.is_web_login():
raise exceptions.UserNotExists()
return user
except exceptions.UserNotExists:

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.models import UserRole
from onyx.file_store.file_store import get_default_file_store
@@ -52,7 +53,12 @@ def tenant_context() -> Generator[None, None, None]:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def create_test_user(db_session: Session, email_prefix: str) -> User:
def create_test_user(
db_session: Session,
email_prefix: str,
role: UserRole = UserRole.BASIC,
account_type: AccountType = AccountType.STANDARD,
) -> User:
"""Helper to create a test user with a unique email"""
# Use UUID to ensure unique email addresses
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
@@ -68,7 +74,8 @@ def create_test_user(db_session: Session, email_prefix: str) -> User:
is_active=True,
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
role=role,
account_type=account_type,
)
db_session.add(user)
db_session.commit()

View File

@@ -13,16 +13,29 @@ from onyx.access.utils import build_ext_group_name_for_onyx
from onyx.configs.constants import DocumentSource
from onyx.connectors.models import InputType
from onyx.db.enums import AccessType
from onyx.db.enums import AccountType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.models import Connector
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import PublicExternalUserGroup
from onyx.db.models import User
from onyx.db.models import User__ExternalUserGroupId
from onyx.db.models import UserRole
from tests.external_dependency_unit.conftest import create_test_user
from tests.external_dependency_unit.constants import TEST_TENANT_ID
def _create_ext_perm_user(db_session: Session, name: str) -> User:
"""Create an external-permission user for group sync tests."""
return create_test_user(
db_session,
name,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
def _create_test_connector_credential_pair(
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
) -> ConnectorCredentialPair:
@@ -100,9 +113,9 @@ class TestPerformExternalGroupSync:
def test_initial_group_sync(self, db_session: Session) -> None:
"""Test syncing external groups for the first time (initial sync)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Mock external groups data as a generator that yields the expected groups
@@ -175,9 +188,9 @@ class TestPerformExternalGroupSync:
def test_update_existing_groups(self, db_session: Session) -> None:
"""Test updating existing groups (adding/removing users)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user3 = create_test_user(db_session, "user3")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
user3 = _create_ext_perm_user(db_session, "user3")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with original groups
@@ -272,8 +285,8 @@ class TestPerformExternalGroupSync:
def test_remove_groups(self, db_session: Session) -> None:
"""Test removing groups (groups that no longer exist in external system)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with multiple groups
@@ -357,7 +370,7 @@ class TestPerformExternalGroupSync:
def test_empty_group_sync(self, db_session: Session) -> None:
"""Test syncing when no groups are returned (all groups removed)"""
# Create test data
user1 = create_test_user(db_session, "user1")
user1 = _create_ext_perm_user(db_session, "user1")
cc_pair = _create_test_connector_credential_pair(db_session)
# Initial sync with groups
@@ -413,7 +426,7 @@ class TestPerformExternalGroupSync:
# Create many test users
users = []
for i in range(150): # More than the batch size of 100
users.append(create_test_user(db_session, f"user{i}"))
users.append(_create_ext_perm_user(db_session, f"user{i}"))
cc_pair = _create_test_connector_credential_pair(db_session)
@@ -452,8 +465,8 @@ class TestPerformExternalGroupSync:
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
"""Test syncing a mix of regular and public groups"""
# Create test data
user1 = create_test_user(db_session, "user1")
user2 = create_test_user(db_session, "user2")
user1 = _create_ext_perm_user(db_session, "user1")
user2 = _create_ext_perm_user(db_session, "user2")
cc_pair = _create_test_connector_credential_pair(db_session)
def mixed_group_sync_func(

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import AccountType
from onyx.db.enums import BuildSessionStatus
from onyx.db.models import BuildSession
from onyx.db.models import User
@@ -52,6 +53,7 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
is_superuser=False,
is_verified=True,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)
db_session.add(user)
db_session.commit()

View File

@@ -0,0 +1,51 @@
"""
Tests that account_type is correctly set when creating users through
the internal DB functions: add_slack_user_if_not_exists and
batch_add_ext_perm_user_if_not_exists.
These functions are called by background workers (Slack bot, permission sync)
and are not exposed via API endpoints, so they must be tested directly.
"""
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.models import UserRole
from onyx.db.users import add_slack_user_if_not_exists
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
users = batch_add_ext_perm_user_if_not_exists(
db_session, ["extperm_acct_type@test.com"]
)
assert len(users) == 1
user = users[0]
assert user.role == UserRole.EXT_PERM_USER
assert user.account_type == AccountType.EXT_PERM_USER
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
db_session: Session,
) -> None:
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
email = "ext_to_slack_acct_type@test.com"
# Create as ext_perm user first
batch_add_ext_perm_user_if_not_exists(db_session, [email])
# Now "upgrade" via slack path
user = add_slack_user_if_not_exists(db_session, email)
assert user.role == UserRole.SLACK_USER
assert user.account_type == AccountType.BOT

View File

@@ -8,6 +8,7 @@ import pytest
from fastapi_users.password import PasswordHelper
from sqlalchemy.orm import Session
from onyx.db.enums import AccountType
from onyx.db.llm import fetch_existing_llm_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
@@ -46,6 +47,7 @@ def _create_admin(db_session: Session) -> User:
is_superuser=True,
is_verified=True,
role=UserRole.ADMIN,
account_type=AccountType.STANDARD,
)
db_session.add(user)
db_session.commit()

View File

@@ -126,6 +126,15 @@ class UserManager:
return test_user
@staticmethod
def get_permissions(user: DATestUser) -> list[str]:
response = requests.get(
url=f"{API_SERVER_URL}/me/permissions",
headers=user.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def is_role(
user_to_verify: DATestUser,

View File

@@ -104,13 +104,30 @@ class UserGroupManager:
)
response.raise_for_status()
@staticmethod
def get_permissions(
user_group: DATestUserGroup,
user_performing_action: DATestUser,
) -> list[str]:
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
headers=user_performing_action.headers,
)
response.raise_for_status()
return response.json()
@staticmethod
def get_all(
user_performing_action: DATestUser,
include_default: bool = False,
) -> list[UserGroup]:
params: dict[str, str] = {}
if include_default:
params["include_default"] = "true"
response = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group",
headers=user_performing_action.headers,
params=params,
)
response.raise_for_status()
return [UserGroup(**ug) for ug in response.json()]

View File

@@ -1,9 +1,13 @@
from uuid import UUID
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser
@@ -33,3 +37,120 @@ def test_limited(reset: None) -> None: # noqa: ARG001
headers=api_key.headers,
)
assert response.status_code == 403
def _get_service_account_account_type(
admin_user: DATestUser,
api_key_user_id: UUID,
) -> AccountType:
"""Fetch the account_type of a service account user via the user listing API."""
response = requests.get(
f"{API_SERVER_URL}/manage/users",
headers=admin_user.headers,
params={"include_api_keys": "true"},
)
response.raise_for_status()
data = response.json()
user_id_str = str(api_key_user_id)
for user in data["accepted"]:
if user["id"] == user_id_str:
return AccountType(user["account_type"])
raise AssertionError(
f"Service account user {user_id_str} not found in user listing"
)
def _get_default_group_user_ids(
admin_user: DATestUser,
) -> tuple[set[str], set[str]]:
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
admin_ids = {str(u.id) for u in admin_group.users}
basic_ids = {str(u.id) for u in basic_group.users}
return admin_ids, basic_ids
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.LIMITED,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify no group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert (
user_id_str not in admin_ids
), "LIMITED API key should NOT be in Admin default group"
assert (
user_id_str not in basic_ids
), "LIMITED API key should NOT be in Basic default group"
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.BASIC,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Basic group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
assert (
user_id_str not in admin_ids
), "BASIC API key should NOT be in Admin default group"
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
admin_user: DATestUser = UserManager.create(name="admin_user")
api_key: DATestAPIKey = APIKeyManager.create(
api_key_role=UserRole.ADMIN,
user_performing_action=admin_user,
)
# Verify account_type
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
assert (
account_type == AccountType.SERVICE_ACCOUNT
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
# Verify Admin group membership
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
user_id_str = str(api_key.user_id)
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
assert (
user_id_str not in basic_ids
), "ADMIN API key should NOT be in Basic default group"

View File

@@ -4,11 +4,32 @@ import pytest
import requests
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def _simulate_saml_login(email: str, admin_user: DATestUser) -> dict:
"""Simulate a SAML login by calling the test upsert endpoint."""
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": email},
headers=admin_user.headers,
)
response.raise_for_status()
return response.json()
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
"""Get the set of emails of all members in the Basic default group."""
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
assert basic_default, "Basic default group not found"
return {u.email for u in basic_default[0].users}
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
@@ -49,15 +70,9 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login by calling the test endpoint
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": test_user_email},
headers=admin_user.headers, # Use admin headers for authorization
)
response.raise_for_status()
user_data = _simulate_saml_login(test_user_email, admin_user)
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify user role was changed in the database
@@ -82,16 +97,237 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
# Simulate SAML login again
response = requests.post(
f"{API_SERVER_URL}/manage/users/test-upsert-user",
json={"email": slack_user_email},
headers=admin_user.headers,
)
response.raise_for_status()
user_data = _simulate_saml_login(slack_user_email, admin_user)
# Verify the response indicates the role changed to BASIC
user_data = response.json()
assert user_data["role"] == UserRole.BASIC.value
# Verify the user's role was changed in the database
assert UserManager.is_role(slack_user, UserRole.BASIC)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_sets_account_type_and_group(
reset: None, # noqa: ARG001
) -> None:
"""
Test that SAML login sets account_type to STANDARD when converting a
non-web user (EXT_PERM_USER) and that the user receives the correct role
(BASIC) after conversion.
This validates the permissions-migration-phase2 changes which ensure that:
1. account_type is updated to 'standard' on SAML conversion
2. The converted user is assigned to the Basic default group
"""
# Create an admin user (first user is automatically admin)
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# Create a user and set them as EXT_PERM_USER
test_email = "ext_convert@example.com"
test_user = UserManager.create(email=test_email)
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
# Simulate SAML login
user_data = _simulate_saml_login(test_email, admin_user)
# Verify account_type is set to standard after conversion
assert (
user_data["account_type"] == AccountType.STANDARD.value
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
# Verify role is BASIC after conversion
assert user_data["role"] == UserRole.BASIC.value
# Verify the user was assigned to the Basic default group
assert test_email in _get_basic_group_member_emails(
admin_user
), f"Converted user '{test_email}' not found in Basic default group"
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_normal_signin_assigns_group(
reset: None, # noqa: ARG001
) -> None:
"""
Test that a brand-new user signing in via SAML for the first time
is created with the correct role, account_type, and group membership.
This validates that normal SAML sign-in (not an upgrade from
SLACK_USER/EXT_PERM_USER) correctly:
1. Creates the user with role=BASIC and account_type=STANDARD
2. Assigns the user to the Basic default group
"""
# First user becomes admin
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# New user signs in via SAML (no prior account)
new_email = "new_saml_user@example.com"
user_data = _simulate_saml_login(new_email, admin_user)
# Verify role and account_type
assert user_data["role"] == UserRole.BASIC.value
assert user_data["account_type"] == AccountType.STANDARD.value
# Verify user is in the Basic default group
assert new_email in _get_basic_group_member_emails(
admin_user
), f"New SAML user '{new_email}' not found in Basic default group"
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_restores_group_membership(
reset: None, # noqa: ARG001
) -> None:
"""
Test that SAML login restores Basic group membership when converting
a non-authenticated user (EXT_PERM_USER or SLACK_USER) to BASIC.
Group membership implies 'basic' permission (verified by
test_new_group_gets_basic_permission).
"""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# --- EXT_PERM_USER path ---
ext_email = "ext_perm_perms@example.com"
ext_user = UserManager.create(email=ext_email)
assert ext_email in _get_basic_group_member_emails(admin_user)
UserManager.set_role(
user_to_set=ext_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert ext_email not in _get_basic_group_member_emails(admin_user)
user_data = _simulate_saml_login(ext_email, admin_user)
assert user_data["role"] == UserRole.BASIC.value
assert ext_email in _get_basic_group_member_emails(
admin_user
), "EXT_PERM_USER should be back in Basic group after SAML conversion"
# --- SLACK_USER path ---
slack_email = "slack_perms@example.com"
slack_user = UserManager.create(email=slack_email)
UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert slack_email not in _get_basic_group_member_emails(admin_user)
user_data = _simulate_saml_login(slack_email, admin_user)
assert user_data["role"] == UserRole.BASIC.value
assert slack_email in _get_basic_group_member_emails(
admin_user
), "SLACK_USER should be back in Basic group after SAML conversion"
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_round_trip_group_lifecycle(
reset: None, # noqa: ARG001
) -> None:
"""
Test the full round-trip: BASIC -> EXT_PERM -> SAML(BASIC) -> EXT_PERM -> SAML(BASIC).
Verifies group membership is correctly removed and restored at each transition.
"""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
test_email = "roundtrip@example.com"
test_user = UserManager.create(email=test_email)
# Step 1: BASIC user is in Basic group
assert test_email in _get_basic_group_member_emails(admin_user)
# Step 2: Downgrade to EXT_PERM_USER — loses Basic group
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert test_email not in _get_basic_group_member_emails(admin_user)
# Step 3: SAML login — converts back to BASIC, regains Basic group
_simulate_saml_login(test_email, admin_user)
assert test_email in _get_basic_group_member_emails(
admin_user
), "Should be in Basic group after first SAML conversion"
# Step 4: Downgrade again
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert test_email not in _get_basic_group_member_emails(admin_user)
# Step 5: SAML login again — should still restore correctly
_simulate_saml_login(test_email, admin_user)
assert test_email in _get_basic_group_member_emails(
admin_user
), "Should be in Basic group after second SAML conversion"
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="SAML tests are enterprise only",
)
def test_saml_slack_user_conversion_sets_account_type_and_group(
reset: None, # noqa: ARG001
) -> None:
"""
Test that SAML login sets account_type to STANDARD and assigns Basic group
when converting a SLACK_USER (BOT account_type).
Mirrors test_saml_user_conversion_sets_account_type_and_group but for
SLACK_USER instead of EXT_PERM_USER, and additionally verifies permissions.
"""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
test_email = "slack_convert@example.com"
test_user = UserManager.create(email=test_email)
UserManager.set_role(
user_to_set=test_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert UserManager.is_role(test_user, UserRole.SLACK_USER)
# SAML login
user_data = _simulate_saml_login(test_email, admin_user)
# Verify account_type and role
assert (
user_data["account_type"] == AccountType.STANDARD.value
), f"Expected STANDARD, got {user_data['account_type']}"
assert user_data["role"] == UserRole.BASIC.value
# Verify Basic group membership (implies 'basic' permission)
assert test_email in _get_basic_group_member_emails(
admin_user
), f"Converted SLACK_USER '{test_email}' not found in Basic default group"

View File

@@ -0,0 +1,82 @@
"""Integration tests for permission propagation across auth-triggered group changes.
These tests verify that effective permissions (via /me/permissions) actually
propagate when users are added/removed from default groups through role changes.
Custom permission grant tests will be added once the permission grant API is built.
"""
import os
import pytest
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
basic_group = next(
(g for g in all_groups if g.is_default and g.name == "Basic"), None
)
assert basic_group is not None, "Basic default group not found"
return {u.email for u in basic_group.users}
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission propagation tests require enterprise features",
)
def test_basic_permission_granted_on_registration(
reset: None, # noqa: ARG001
) -> None:
"""New users should get 'basic' permission through default group assignment."""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
basic_user: DATestUser = UserManager.create(email="basic@example.com")
# Admin should have permissions from Admin group
admin_perms = UserManager.get_permissions(admin_user)
assert "basic" in admin_perms
# Basic user should have 'basic' from Basic default group
basic_perms = UserManager.get_permissions(basic_user)
assert "basic" in basic_perms
# Verify group membership matches
assert basic_user.email in _get_basic_group_member_emails(admin_user)
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="Permission propagation tests require enterprise features",
)
def test_role_downgrade_removes_basic_group_and_permission(
reset: None, # noqa: ARG001
) -> None:
"""Downgrading to EXT_PERM_USER or SLACK_USER should remove from Basic group."""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# --- EXT_PERM_USER ---
ext_user: DATestUser = UserManager.create(email="ext@example.com")
assert ext_user.email in _get_basic_group_member_emails(admin_user)
UserManager.set_role(
user_to_set=ext_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert ext_user.email not in _get_basic_group_member_emails(admin_user)
# --- SLACK_USER ---
slack_user: DATestUser = UserManager.create(email="slack@example.com")
assert slack_user.email in _get_basic_group_member_emails(admin_user)
UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
assert slack_user.email not in _get_basic_group_member_emails(admin_user)

View File

@@ -21,8 +21,15 @@ import pytest
import requests
from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
@@ -44,13 +51,6 @@ def scim_token(idp_style: str) -> str:
per IdP-style run and reuse. Uses UserManager directly to avoid
fixture-scope conflicts with the function-scoped admin_user fixture.
"""
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
try:
admin = UserManager.create(name=ADMIN_USER_NAME)
except Exception:
@@ -550,3 +550,145 @@ def test_patch_add_duplicate_member_is_idempotent(
)
assert resp.status_code == 200
assert len(resp.json()["members"]) == 1 # still just one member
def test_create_group_reserved_name_admin(scim_token: str) -> None:
"""POST /Groups with reserved name 'Admin' returns 409."""
resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin")
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_create_group_reserved_name_basic(scim_token: str) -> None:
"""POST /Groups with reserved name 'Basic' returns 409."""
resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic")
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_replace_group_cannot_rename_to_reserved(
scim_token: str, idp_style: str
) -> None:
"""PUT /Groups/{id} renaming a group to 'Admin' returns 409."""
created = _create_scim_group(
scim_token,
f"Rename To Reserved {idp_style}",
external_id=f"ext-rtr-{idp_style}",
).json()
resp = ScimClient.put(
f"/Groups/{created['id']}",
scim_token,
json=_make_group_resource(
display_name="Admin", external_id=f"ext-rtr-{idp_style}"
),
)
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None:
"""PATCH /Groups/{id} renaming a group to 'Basic' returns 409."""
created = _create_scim_group(
scim_token,
f"Patch Rename Reserved {idp_style}",
external_id=f"ext-prr-{idp_style}",
).json()
resp = ScimClient.patch(
f"/Groups/{created['id']}",
scim_token,
json=_make_patch_request(
[{"op": "replace", "path": "displayName", "value": "Basic"}],
idp_style,
),
)
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_delete_reserved_group_rejected(scim_token: str) -> None:
"""DELETE /Groups/{id} on a reserved group ('Admin') returns 409."""
# Look up the reserved 'Admin' group via SCIM filter
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
assert resp.status_code == 200
resources = resp.json()["Resources"]
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
admin_group_id = resources[0]["id"]
resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token)
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_scim_created_group_has_basic_permission(
scim_token: str, idp_style: str
) -> None:
"""POST /Groups assigns the 'basic' permission to the group itself."""
# Create a SCIM group (no members needed — we check the group's permissions)
resp = _create_scim_group(
scim_token,
f"Basic Perm Group {idp_style}",
external_id=f"ext-basic-perm-{idp_style}",
)
assert resp.status_code == 201
group_id = resp.json()["id"]
# Log in as the admin user (created by the scim_token fixture).
admin = DATestUser(
id="",
email=build_email(ADMIN_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
admin = UserManager.login_as_user(admin)
# Verify the group itself was granted the basic permission
perms_resp = requests.get(
f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions",
headers=admin.headers,
)
perms_resp.raise_for_status()
perms = perms_resp.json()
assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}"
def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None:
"""PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409."""
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
assert resp.status_code == 200
resources = resp.json()["Resources"]
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
admin_group_id = resources[0]["id"]
resp = ScimClient.put(
f"/Groups/{admin_group_id}",
scim_token,
json=_make_group_resource(
display_name="RenamedAdmin", external_id="ext-rename-from-reserved"
),
)
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()
def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None:
"""PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409."""
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
assert resp.status_code == 200
resources = resp.json()["Resources"]
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
admin_group_id = resources[0]["id"]
resp = ScimClient.patch(
f"/Groups/{admin_group_id}",
scim_token,
json=_make_patch_request(
[{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}],
idp_style,
),
)
assert resp.status_code == 409
assert "reserved" in resp.json()["detail"].lower()

View File

@@ -35,9 +35,16 @@ from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import REDIS_DB_NUMBER
from onyx.configs.app_configs import REDIS_HOST
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.enums import AccountType
from onyx.server.settings.models import ApplicationStatus
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
@@ -211,6 +218,49 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
_assert_entra_emails(body, email)
def test_create_user_default_group_and_account_type(
scim_token: str, idp_style: str
) -> None:
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
email = f"scim_defaults_{idp_style}@example.com"
ext_id = f"ext-defaults-{idp_style}"
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
assert resp.status_code == 201
user_id = resp.json()["id"]
# --- Verify group assignment via SCIM GET ---
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
assert get_resp.status_code == 200
groups = get_resp.json().get("groups", [])
group_names = {g["display"] for g in groups}
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
# --- Verify account_type via admin API ---
admin = UserManager.login_as_user(
DATestUser(
id="",
email=build_email(ADMIN_USER_NAME),
password=DEFAULT_PASSWORD,
headers=GENERAL_HEADERS,
role=UserRole.ADMIN,
is_active=True,
)
)
page = UserManager.get_user_page(
user_performing_action=admin,
search_query=email,
)
assert page.total_items >= 1
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
assert (
scim_user_snapshot is not None
), f"SCIM user {email} not found in user listing"
assert (
scim_user_snapshot.account_type == AccountType.STANDARD
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
def test_get_user(scim_token: str, idp_style: str) -> None:
"""GET /Users/{id} returns the user resource with all stored fields."""
email = f"scim_get_{idp_style}@example.com"

View File

@@ -0,0 +1,118 @@
import os
import pytest
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import UserGroup as UserGroupModel
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_user_gets_permissions_when_added_to_group(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
# basic_user starts with only "basic" from the default group
initial_permissions = UserManager.get_permissions(basic_user)
assert "basic" in initial_permissions
assert "add:agents" not in initial_permissions
# Create a new group and add basic_user
group = UserGroupManager.create(
name="perm-test-group",
user_ids=[admin_user.id, basic_user.id],
user_performing_action=admin_user,
)
# Grant a non-basic permission to the group and recompute
with get_session_with_current_tenant() as db_session:
db_group = db_session.get(UserGroupModel, group.id)
assert db_group is not None
db_session.add(
PermissionGrant(
group_id=db_group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
)
db_session.flush()
recompute_user_permissions__no_commit(basic_user.id, db_session)
db_session.commit()
# Verify the user gained the new permission (expanded includes read:agents)
updated_permissions = UserManager.get_permissions(basic_user)
assert (
"add:agents" in updated_permissions
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
assert (
"read:agents" in updated_permissions
), f"User should have implied 'read:agents', got: {updated_permissions}"
assert "basic" in updated_permissions
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_group_permission_change_propagates_to_all_members(
reset: None, # noqa: ARG001
) -> None:
admin_user: DATestUser = UserManager.create(name="admin_propagate")
user_a: DATestUser = UserManager.create(name="user_a_propagate")
user_b: DATestUser = UserManager.create(name="user_b_propagate")
group = UserGroupManager.create(
name="propagate-test-group",
user_ids=[admin_user.id, user_a.id, user_b.id],
user_performing_action=admin_user,
)
# Neither user should have add:agents yet
for u in (user_a, user_b):
assert "add:agents" not in UserManager.get_permissions(u)
# Grant add:agents to the group, then batch-recompute
with get_session_with_current_tenant() as db_session:
grant = PermissionGrant(
group_id=group.id,
permission=Permission.ADD_AGENTS,
grant_source="SYSTEM",
)
db_session.add(grant)
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
# Both users should now have the permission (plus implied read:agents)
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
# Soft-delete the grant and recompute — permission should be removed
with get_session_with_current_tenant() as db_session:
db_grant = (
db_session.query(PermissionGrant)
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
.first()
)
assert db_grant is not None
db_grant.is_deleted = True
db_session.flush()
recompute_permissions_for_group__no_commit(group.id, db_session)
db_session.commit()
for u in (user_a, user_b):
perms = UserManager.get_permissions(u)
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"

View File

@@ -0,0 +1,30 @@
import os
import pytest
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
reason="User group tests are enterprise only",
)
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
user_group = UserGroupManager.create(
name="basic-perm-test-group",
user_ids=[admin_user.id],
user_performing_action=admin_user,
)
permissions = UserGroupManager.get_permissions(
user_group=user_group,
user_performing_action=admin_user,
)
assert (
"basic" in permissions
), f"New group should have 'basic' permission, got: {permissions}"

View File

@@ -0,0 +1,78 @@
"""Integration tests for default group assignment on user registration.
Verifies that:
- The first registered user is assigned to the Admin default group
- Subsequent registered users are assigned to the Basic default group
- account_type is set to STANDARD for email/password registrations
"""
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
# Register first user — should become admin
admin_user: DATestUser = UserManager.create(name="first_user")
assert admin_user.role == UserRole.ADMIN
# Register second user — should become basic
basic_user: DATestUser = UserManager.create(name="second_user")
assert basic_user.role == UserRole.BASIC
# Fetch all groups including default ones
all_groups = UserGroupManager.get_all(
user_performing_action=admin_user,
include_default=True,
)
# Find the default Admin and Basic groups
admin_group = next(
(g for g in all_groups if g.name == "Admin" and g.is_default), None
)
basic_group = next(
(g for g in all_groups if g.name == "Basic" and g.is_default), None
)
assert admin_group is not None, "Admin default group not found"
assert basic_group is not None, "Basic default group not found"
# Verify admin user is in Admin group and NOT in Basic group
admin_group_user_ids = {str(u.id) for u in admin_group.users}
basic_group_user_ids = {str(u.id) for u in basic_group.users}
assert (
admin_user.id in admin_group_user_ids
), "First user should be in Admin default group"
assert (
admin_user.id not in basic_group_user_ids
), "First user should NOT be in Basic default group"
# Verify basic user is in Basic group and NOT in Admin group
assert (
basic_user.id in basic_group_user_ids
), "Second user should be in Basic default group"
assert (
basic_user.id not in admin_group_user_ids
), "Second user should NOT be in Admin default group"
# Verify account_type is STANDARD for both users via user listing API
paginated_result = UserManager.get_user_page(
user_performing_action=admin_user,
page_num=0,
page_size=10,
)
users_by_id = {str(u.id): u for u in paginated_result.items}
admin_snapshot = users_by_id.get(admin_user.id)
basic_snapshot = users_by_id.get(basic_user.id)
assert admin_snapshot is not None, "Admin user not found in user listing"
assert basic_snapshot is not None, "Basic user not found in user listing"
assert (
admin_snapshot.account_type == AccountType.STANDARD
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
assert (
basic_snapshot.account_type == AccountType.STANDARD
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"

View File

@@ -0,0 +1,135 @@
"""Integration tests for password signup upgrade paths.
Verifies that when a BOT or EXT_PERM_USER user signs up via email/password:
- Their account_type is upgraded to STANDARD
- They are assigned to the Basic default group
- They gain the correct effective permissions
"""
import pytest
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def _get_default_group_member_emails(
admin_user: DATestUser,
group_name: str,
) -> set[str]:
"""Get the set of emails of all members in a named default group."""
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
matched = [g for g in all_groups if g.is_default and g.name == group_name]
assert matched, f"Default group '{group_name}' not found"
return {u.email for u in matched[0].users}
@pytest.mark.parametrize(
"target_role",
[UserRole.EXT_PERM_USER, UserRole.SLACK_USER],
ids=["ext_perm_user", "slack_user"],
)
def test_password_signup_upgrade(
reset: None, # noqa: ARG001
target_role: UserRole,
) -> None:
"""When a non-web user signs up via email/password, they should be
upgraded to STANDARD account_type and assigned to the Basic default group."""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
test_email = f"{target_role.value}_upgrade@example.com"
test_user = UserManager.create(email=test_email)
test_user = UserManager.set_role(
user_to_set=test_user,
target_role=target_role,
user_performing_action=admin_user,
explicit_override=True,
)
# Verify user was removed from Basic group after downgrade
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert (
test_email not in basic_emails
), f"{target_role.value} should not be in Basic default group"
# Re-register with the same email — triggers the password signup upgrade
upgraded_user = UserManager.create(email=test_email)
assert upgraded_user.role == UserRole.BASIC
paginated = UserManager.get_user_page(
user_performing_action=admin_user,
page_num=0,
page_size=10,
)
user_snapshot = next(
(u for u in paginated.items if str(u.id) == upgraded_user.id), None
)
assert user_snapshot is not None
assert (
user_snapshot.account_type == AccountType.STANDARD
), f"Expected STANDARD, got {user_snapshot.account_type}"
# Verify user is now in the Basic default group
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert (
test_email in basic_emails
), f"Upgraded user '{test_email}' not found in Basic default group"
def test_password_signup_upgrade_propagates_permissions(
reset: None, # noqa: ARG001
) -> None:
"""When an EXT_PERM_USER or SLACK_USER signs up via password, they should
gain the 'basic' permission through the Basic default group assignment."""
admin_user: DATestUser = UserManager.create(email="admin@example.com")
# --- EXT_PERM_USER path ---
ext_email = "ext_perms_check@example.com"
ext_user = UserManager.create(email=ext_email)
initial_perms = UserManager.get_permissions(ext_user)
assert "basic" in initial_perms
ext_user = UserManager.set_role(
user_to_set=ext_user,
target_role=UserRole.EXT_PERM_USER,
user_performing_action=admin_user,
explicit_override=True,
)
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert ext_email not in basic_emails
upgraded = UserManager.create(email=ext_email)
assert upgraded.role == UserRole.BASIC
perms = UserManager.get_permissions(upgraded)
assert (
"basic" in perms
), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}"
# --- SLACK_USER path ---
slack_email = "slack_perms_check@example.com"
slack_user = UserManager.create(email=slack_email)
slack_user = UserManager.set_role(
user_to_set=slack_user,
target_role=UserRole.SLACK_USER,
user_performing_action=admin_user,
explicit_override=True,
)
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert slack_email not in basic_emails
upgraded = UserManager.create(email=slack_email)
assert upgraded.role == UserRole.BASIC
perms = UserManager.get_permissions(upgraded)
assert (
"basic" in perms
), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}"

View File

@@ -0,0 +1,54 @@
"""Integration tests for default group reconciliation on user reactivation.
Verifies that:
- A deactivated user retains default group membership after reactivation
- Reactivation via the admin API reconciles missing group membership
"""
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser
def _get_default_group_member_emails(
admin_user: DATestUser,
group_name: str,
) -> set[str]:
"""Get the set of emails of all members in a named default group."""
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
matched = [g for g in all_groups if g.is_default and g.name == group_name]
assert matched, f"Default group '{group_name}' not found"
return {u.email for u in matched[0].users}
def test_reactivated_user_retains_default_group(
reset: None, # noqa: ARG001
) -> None:
"""Deactivating and reactivating a user should preserve their
default group membership."""
admin_user: DATestUser = UserManager.create(name="admin_user")
basic_user: DATestUser = UserManager.create(name="basic_user")
# Verify user is in Basic group initially
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert basic_user.email in basic_emails
# Deactivate the user
UserManager.set_status(
user_to_set=basic_user,
target_status=False,
user_performing_action=admin_user,
)
# Reactivate the user
UserManager.set_status(
user_to_set=basic_user,
target_status=True,
user_performing_action=admin_user,
)
# Verify user is still in Basic group after reactivation
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
assert (
basic_user.email in basic_emails
), "Reactivated user should still be in Basic default group"

View File

@@ -0,0 +1,176 @@
"""
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
"""
from unittest.mock import MagicMock
import pytest
from onyx.auth.permissions import ALL_PERMISSIONS
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.permissions import require_permission
from onyx.auth.permissions import resolve_effective_permissions
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
# ---------------------------------------------------------------------------
# resolve_effective_permissions
# ---------------------------------------------------------------------------
class TestResolveEffectivePermissions:
def test_empty_set(self) -> None:
assert resolve_effective_permissions(set()) == set()
def test_basic_no_implications(self) -> None:
result = resolve_effective_permissions({"basic"})
assert result == {"basic"}
def test_single_implication(self) -> None:
result = resolve_effective_permissions({"add:agents"})
assert result == {"add:agents", "read:agents"}
def test_manage_agents_implies_add_and_read(self) -> None:
"""manage:agents directly maps to {add:agents, read:agents}."""
result = resolve_effective_permissions({"manage:agents"})
assert result == {"manage:agents", "add:agents", "read:agents"}
def test_manage_connectors_chain(self) -> None:
result = resolve_effective_permissions({"manage:connectors"})
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
def test_manage_document_sets(self) -> None:
result = resolve_effective_permissions({"manage:document_sets"})
assert result == {
"manage:document_sets",
"read:document_sets",
"read:connectors",
}
def test_manage_user_groups_implies_all_reads(self) -> None:
result = resolve_effective_permissions({"manage:user_groups"})
assert result == {
"manage:user_groups",
"read:connectors",
"read:document_sets",
"read:agents",
"read:users",
}
def test_admin_override(self) -> None:
result = resolve_effective_permissions({"admin"})
assert result == set(ALL_PERMISSIONS)
def test_admin_with_others(self) -> None:
result = resolve_effective_permissions({"admin", "basic"})
assert result == set(ALL_PERMISSIONS)
def test_multi_group_union(self) -> None:
result = resolve_effective_permissions(
{"add:agents", "manage:connectors", "basic"}
)
assert result == {
"basic",
"add:agents",
"read:agents",
"manage:connectors",
"add:connectors",
"read:connectors",
}
def test_toggle_permission_no_implications(self) -> None:
result = resolve_effective_permissions({"read:agent_analytics"})
assert result == {"read:agent_analytics"}
def test_all_permissions_for_admin(self) -> None:
result = resolve_effective_permissions({"admin"})
assert len(result) == len(ALL_PERMISSIONS)
# ---------------------------------------------------------------------------
# get_effective_permissions (expands implied at read time)
# ---------------------------------------------------------------------------
class TestGetEffectivePermissions:
def test_expands_implied_permissions(self) -> None:
"""Column stores only granted; get_effective_permissions expands implied."""
user = MagicMock()
user.effective_permissions = ["add:agents"]
result = get_effective_permissions(user)
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
def test_admin_expands_to_all(self) -> None:
user = MagicMock()
user.effective_permissions = ["admin"]
result = get_effective_permissions(user)
assert result == set(Permission)
def test_basic_stays_basic(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
result = get_effective_permissions(user)
assert result == {Permission.BASIC_ACCESS}
def test_empty_column(self) -> None:
user = MagicMock()
user.effective_permissions = []
result = get_effective_permissions(user)
assert result == set()
# ---------------------------------------------------------------------------
# require_permission (FastAPI dependency)
# ---------------------------------------------------------------------------
class TestRequirePermission:
@pytest.mark.asyncio
async def test_admin_bypass(self) -> None:
"""Admin stored in column should pass any permission check."""
user = MagicMock()
user.effective_permissions = ["admin"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_has_required_permission(self) -> None:
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_implied_permission_passes(self) -> None:
"""manage:connectors implies read:connectors at read time."""
user = MagicMock()
user.effective_permissions = ["manage:connectors"]
dep = require_permission(Permission.READ_CONNECTORS)
result = await dep(user=user)
assert result is user
@pytest.mark.asyncio
async def test_missing_permission_raises(self) -> None:
user = MagicMock()
user.effective_permissions = ["basic"]
dep = require_permission(Permission.MANAGE_CONNECTORS)
with pytest.raises(OnyxError) as exc_info:
await dep(user=user)
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
@pytest.mark.asyncio
async def test_empty_permissions_fails(self) -> None:
user = MagicMock()
user.effective_permissions = []
dep = require_permission(Permission.BASIC_ACCESS)
with pytest.raises(OnyxError):
await dep(user=user)

View File

@@ -0,0 +1,29 @@
"""
Unit tests for UserCreate schema dict methods.
Verifies that account_type is always included in create_update_dict
and create_update_dict_superuser.
"""
from onyx.auth.schemas import UserCreate
from onyx.db.enums import AccountType
def test_create_update_dict_includes_default_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_includes_explicit_account_type() -> None:
uc = UserCreate(
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
)
d = uc.create_update_dict()
assert d["account_type"] == AccountType.STANDARD
def test_create_update_dict_superuser_includes_account_type() -> None:
uc = UserCreate(email="a@b.com", password="secret123")
d = uc.create_update_dict_superuser()
assert d["account_type"] == AccountType.STANDARD

View File

@@ -0,0 +1,176 @@
"""
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
Covers:
1. Standard/service-account users get assigned to the correct default group
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
3. Missing default group raises RuntimeError
4. Already-in-group is a no-op
5. IntegrityError race condition is handled gracefully
6. The function never commits the session
"""
from unittest.mock import MagicMock
from uuid import uuid4
import pytest
from sqlalchemy.exc import IntegrityError
from onyx.db.enums import AccountType
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.users import assign_user_to_default_groups__no_commit
def _mock_user(
account_type: AccountType = AccountType.STANDARD,
email: str = "test@example.com",
) -> MagicMock:
user = MagicMock()
user.id = uuid4()
user.email = email
user.account_type = account_type
return user
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
group = MagicMock()
group.id = group_id
group.name = name
group.is_default = True
return group
def _make_query_chain(first_return: object = None) -> MagicMock:
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
chain = MagicMock()
chain.filter.return_value = chain
chain.first.return_value = first_return
return chain
def _setup_db_session(
group_result: object = None,
membership_result: object = None,
) -> MagicMock:
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
db_session = MagicMock()
group_chain = _make_query_chain(group_result)
membership_chain = _make_query_chain(membership_result)
def query_side_effect(model: type) -> MagicMock:
if model is UserGroup:
return group_chain
if model is User__UserGroup:
return membership_chain
return MagicMock()
db_session.query.side_effect = query_side_effect
return db_session
def test_standard_user_assigned_to_basic_group() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.STANDARD)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
db_session.add.assert_called_once()
added = db_session.add.call_args[0][0]
assert isinstance(added, User__UserGroup)
assert added.user_id == user.id
assert added.user_group_id == group.id
db_session.flush.assert_called_once()
def test_admin_user_assigned_to_admin_group() -> None:
group = _mock_group("Admin", group_id=2)
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.STANDARD)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
db_session.add.assert_called_once()
added = db_session.add.call_args[0][0]
assert isinstance(added, User__UserGroup)
assert added.user_group_id == group.id
@pytest.mark.parametrize(
"account_type",
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
)
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
db_session = MagicMock()
user = _mock_user(account_type)
assign_user_to_default_groups__no_commit(db_session, user)
db_session.query.assert_not_called()
db_session.add.assert_not_called()
def test_service_account_not_skipped() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user(AccountType.SERVICE_ACCOUNT)
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
db_session.add.assert_called_once()
def test_missing_default_group_raises_error() -> None:
db_session = _setup_db_session(group_result=None)
user = _mock_user()
with pytest.raises(RuntimeError, match="Default group .* not found"):
assign_user_to_default_groups__no_commit(db_session, user)
def test_already_in_group_is_noop() -> None:
group = _mock_group("Basic")
existing_membership = MagicMock()
db_session = _setup_db_session(
group_result=group, membership_result=existing_membership
)
user = _mock_user()
assign_user_to_default_groups__no_commit(db_session, user)
db_session.add.assert_not_called()
db_session.begin_nested.assert_not_called()
def test_integrity_error_race_condition_handled() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
user = _mock_user()
# Should not raise
assign_user_to_default_groups__no_commit(db_session, user)
savepoint.rollback.assert_called_once()
def test_no_commit_called_on_successful_assignment() -> None:
group = _mock_group("Basic")
db_session = _setup_db_session(group_result=group, membership_result=None)
savepoint = MagicMock()
db_session.begin_nested.return_value = savepoint
user = _mock_user()
assign_user_to_default_groups__no_commit(db_session, user)
db_session.commit.assert_not_called()

View File

@@ -0,0 +1,203 @@
import pytest
from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.string_filtering import (
MAX_DOCUMENT_ID_ENCODED_LENGTH,
)
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
SINGLE_TENANT_STATE = TenantState(
tenant_id=POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE, multitenant=False
)
MULTI_TENANT_STATE = TenantState(
tenant_id="tenant_abcdef12-3456-7890-abcd-ef1234567890", multitenant=True
)
EXPECTED_SHORT_TENANT = "abcdef12"
class TestGetOpensearchDocChunkIdSingleTenant:
def test_basic(self) -> None:
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, "my-doc-id", chunk_index=0
)
assert result == f"my-doc-id__{DEFAULT_MAX_CHUNK_SIZE}__0"
def test_custom_chunk_size(self) -> None:
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, "doc1", chunk_index=3, max_chunk_size=1024
)
assert result == "doc1__1024__3"
def test_special_chars_are_stripped(self) -> None:
"""Tests characters not matching [A-Za-z0-9_.-~] are removed."""
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, "doc/with?special#chars&more%stuff", chunk_index=0
)
assert "/" not in result
assert "?" not in result
assert "#" not in result
assert result == f"docwithspecialcharsmorestuff__{DEFAULT_MAX_CHUNK_SIZE}__0"
def test_short_doc_id_not_hashed(self) -> None:
"""
Tests that a short doc ID should appear directly in the result, not as a
hash.
"""
doc_id = "short-id"
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
assert "short-id" in result
def test_long_doc_id_is_hashed(self) -> None:
"""
Tests that a doc ID exceeding the max length should be replaced with a
blake2b hash.
"""
# Create a doc ID that will exceed max length after the suffix is
# appended.
doc_id = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
# The original doc ID should NOT appear in the result.
assert doc_id not in result
# The suffix should still be present.
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
def test_long_doc_id_hash_is_deterministic(self) -> None:
doc_id = "x" * MAX_DOCUMENT_ID_ENCODED_LENGTH
result1 = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, doc_id, chunk_index=5
)
result2 = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, doc_id, chunk_index=5
)
assert result1 == result2
def test_long_doc_id_different_inputs_produce_different_hashes(self) -> None:
doc_id_a = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
doc_id_b = "b" * MAX_DOCUMENT_ID_ENCODED_LENGTH
result_a = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, doc_id_a, chunk_index=0
)
result_b = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, doc_id_b, chunk_index=0
)
assert result_a != result_b
def test_result_never_exceeds_max_length(self) -> None:
"""
Tests that the final result should always be under
MAX_DOCUMENT_ID_ENCODED_LENGTH bytes.
"""
doc_id = "z" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
)
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
def test_no_tenant_prefix_in_single_tenant(self) -> None:
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, "mydoc", chunk_index=0
)
assert not result.startswith(SINGLE_TENANT_STATE.tenant_id)
class TestGetOpensearchDocChunkIdMultiTenant:
def test_includes_tenant_prefix(self) -> None:
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, "mydoc", chunk_index=0)
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
def test_format(self) -> None:
result = get_opensearch_doc_chunk_id(
MULTI_TENANT_STATE, "mydoc", chunk_index=2, max_chunk_size=256
)
assert result == f"{EXPECTED_SHORT_TENANT}__mydoc__256__2"
def test_long_doc_id_is_hashed_multitenant(self) -> None:
doc_id = "d" * MAX_DOCUMENT_ID_ENCODED_LENGTH
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
# Should still have tenant prefix.
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
# The original doc ID should NOT appear in the result.
assert doc_id not in result
# The suffix should still be present.
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
def test_result_never_exceeds_max_length_multitenant(self) -> None:
doc_id = "q" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
result = get_opensearch_doc_chunk_id(
MULTI_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
)
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
def test_different_tenants_produce_different_ids(self) -> None:
tenant_a = TenantState(
tenant_id="tenant_aaaaaaaa-0000-0000-0000-000000000000", multitenant=True
)
tenant_b = TenantState(
tenant_id="tenant_bbbbbbbb-0000-0000-0000-000000000000", multitenant=True
)
result_a = get_opensearch_doc_chunk_id(tenant_a, "same-doc", chunk_index=0)
result_b = get_opensearch_doc_chunk_id(tenant_b, "same-doc", chunk_index=0)
assert result_a != result_b
class TestGetOpensearchDocChunkIdEdgeCases:
def test_chunk_index_zero(self) -> None:
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "doc", chunk_index=0)
assert result.endswith("__0")
def test_large_chunk_index(self) -> None:
result = get_opensearch_doc_chunk_id(
SINGLE_TENANT_STATE, "doc", chunk_index=99999
)
assert result.endswith("__99999")
def test_doc_id_with_only_special_chars_raises(self) -> None:
"""
Tests that a doc ID that becomes empty after filtering should raise
ValueError.
"""
with pytest.raises(ValueError, match="empty after filtering"):
get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "###???///", chunk_index=0)
def test_doc_id_at_boundary_length(self) -> None:
"""
Tests that a doc ID right at the boundary should not be hashed.
"""
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
suffix_len = len(suffix.encode("utf-8"))
# Max doc ID length that won't trigger hashing (must be <
# max_encoded_length).
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - 1
doc_id = "a" * max_doc_len
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
assert doc_id in result
def test_doc_id_at_boundary_length_multitenant(self) -> None:
"""
Tests that a doc ID right at the boundary should not be hashed in
multitenant mode.
"""
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
suffix_len = len(suffix.encode("utf-8"))
prefix = f"{EXPECTED_SHORT_TENANT}__"
prefix_len = len(prefix.encode("utf-8"))
# Max doc ID length that won't trigger hashing (must be <
# max_encoded_length).
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - prefix_len - 1
doc_id = "a" * max_doc_len
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
assert doc_id in result
def test_doc_id_one_over_boundary_is_hashed(self) -> None:
"""
Tests that a doc ID one byte over the boundary should be hashed.
"""
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
suffix_len = len(suffix.encode("utf-8"))
# This length will trigger the >= check in filter_and_validate_document_id
doc_id = "a" * (MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len)
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
assert doc_id not in result

View File

@@ -0,0 +1,91 @@
"""Tests for FileStore.delete_file error_on_missing behavior."""
from unittest.mock import MagicMock
from unittest.mock import patch
import pytest
_S3_MODULE = "onyx.file_store.file_store"
_PG_MODULE = "onyx.file_store.postgres_file_store"
def _mock_db_session() -> MagicMock:
session = MagicMock()
session.__enter__ = MagicMock(return_value=session)
session.__exit__ = MagicMock(return_value=False)
return session
# ── S3BackedFileStore ────────────────────────────────────────────────
@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
def test_s3_delete_missing_file_raises_by_default(
_mock_get_record: MagicMock,
mock_ctx: MagicMock,
) -> None:
from onyx.file_store.file_store import S3BackedFileStore
mock_ctx.return_value = _mock_db_session()
store = S3BackedFileStore(bucket_name="b")
with pytest.raises(RuntimeError, match="does not exist"):
store.delete_file("nonexistent")
@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
@patch(f"{_S3_MODULE}.delete_filerecord_by_file_id")
def test_s3_delete_missing_file_silent_when_error_on_missing_false(
mock_delete_record: MagicMock,
_mock_get_record: MagicMock,
mock_ctx: MagicMock,
) -> None:
from onyx.file_store.file_store import S3BackedFileStore
mock_ctx.return_value = _mock_db_session()
store = S3BackedFileStore(bucket_name="b")
store.delete_file("nonexistent", error_on_missing=False)
mock_delete_record.assert_not_called()
# ── PostgresBackedFileStore ──────────────────────────────────────────
@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
def test_pg_delete_missing_file_raises_by_default(
_mock_get_content: MagicMock,
mock_ctx: MagicMock,
) -> None:
from onyx.file_store.postgres_file_store import PostgresBackedFileStore
mock_ctx.return_value = _mock_db_session()
store = PostgresBackedFileStore()
with pytest.raises(RuntimeError, match="does not exist"):
store.delete_file("nonexistent")
@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
@patch(f"{_PG_MODULE}.delete_file_content_by_file_id")
@patch(f"{_PG_MODULE}.delete_filerecord_by_file_id")
def test_pg_delete_missing_file_silent_when_error_on_missing_false(
mock_delete_record: MagicMock,
mock_delete_content: MagicMock,
_mock_get_content: MagicMock,
mock_ctx: MagicMock,
) -> None:
from onyx.file_store.postgres_file_store import PostgresBackedFileStore
mock_ctx.return_value = _mock_db_session()
store = PostgresBackedFileStore()
store.delete_file("nonexistent", error_on_missing=False)
mock_delete_record.assert_not_called()
mock_delete_content.assert_not_called()

View File

@@ -113,6 +113,7 @@ def make_db_group(**kwargs: Any) -> MagicMock:
group.name = kwargs.get("name", "Engineering")
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
group.is_up_to_date = kwargs.get("is_up_to_date", True)
group.is_default = kwargs.get("is_default", False)
return group

View File

@@ -3,6 +3,7 @@ from unittest.mock import MagicMock
from uuid import uuid4
from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.server.models import FullUserSnapshot
from onyx.server.models import UserGroupInfo
@@ -25,6 +26,7 @@ def _mock_user(
user.updated_at = updated_at or datetime.datetime(
2025, 6, 15, tzinfo=datetime.timezone.utc
)
user.account_type = AccountType.STANDARD
return user

View File

@@ -98,6 +98,7 @@ Useful hardening flags:
| `serve` | Serve the interactive chat TUI over SSH |
| `configure` | Configure server URL and API key |
| `validate-config` | Validate configuration and test connection |
| `install-skill` | Install the agent skill file into a project |
## Slash Commands (in TUI)

View File

@@ -7,6 +7,7 @@ import (
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/spf13/cobra"
)
@@ -16,16 +17,23 @@ func newAgentsCmd() *cobra.Command {
cmd := &cobra.Command{
Use: "agents",
Short: "List available agents",
Long: `List all visible agents configured on the Onyx server.
By default, output is a human-readable table with ID, name, and description.
Use --json for machine-readable output.`,
Example: ` onyx-cli agents
onyx-cli agents --json
onyx-cli agents --json | jq '.[].name'`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
}
client := api.NewClient(cfg)
agents, err := client.ListAgents(cmd.Context())
if err != nil {
return fmt.Errorf("failed to list agents: %w", err)
return fmt.Errorf("failed to list agents: %w\n Check your connection with: onyx-cli validate-config", err)
}
if agentsJSON {

View File

@@ -4,33 +4,65 @@ import (
"context"
"encoding/json"
"fmt"
"io"
"os"
"os/signal"
"strings"
"syscall"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/models"
"github.com/onyx-dot-app/onyx/cli/internal/overflow"
"github.com/spf13/cobra"
"golang.org/x/term"
)
const defaultMaxOutputBytes = 4096
func newAskCmd() *cobra.Command {
var (
askAgentID int
askJSON bool
askQuiet bool
askPrompt string
maxOutput int
)
cmd := &cobra.Command{
Use: "ask [question]",
Short: "Ask a one-shot question (non-interactive)",
Args: cobra.ExactArgs(1),
Long: `Send a one-shot question to an Onyx agent and print the response.
The question can be provided as a positional argument, via --prompt, or piped
through stdin. When stdin contains piped data, it is sent as context along
with the question from --prompt (or used as the question itself).
When stdout is not a TTY (e.g., called by a script or AI agent), output is
automatically truncated to --max-output bytes and the full response is saved
to a temp file. Set --max-output 0 to disable truncation.`,
Args: cobra.MaximumNArgs(1),
Example: ` onyx-cli ask "What connectors are available?"
onyx-cli ask --agent-id 3 "Summarize our Q4 revenue"
onyx-cli ask --json "List all users" | jq '.event.content'
cat error.log | onyx-cli ask --prompt "Find the root cause"
echo "what is onyx?" | onyx-cli ask`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
if !cfg.IsConfigured() {
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
}
if askJSON && askQuiet {
return exitcodes.New(exitcodes.BadRequest, "--json and --quiet cannot be used together")
}
question, err := resolveQuestion(args, askPrompt)
if err != nil {
return err
}
question := args[0]
agentID := cfg.DefaultAgentID
if cmd.Flags().Changed("agent-id") {
agentID = askAgentID
@@ -50,9 +82,23 @@ func newAskCmd() *cobra.Command {
nil,
)
// Determine truncation threshold.
isTTY := term.IsTerminal(int(os.Stdout.Fd()))
truncateAt := 0 // 0 means no truncation
if cmd.Flags().Changed("max-output") {
truncateAt = maxOutput
} else if !isTTY {
truncateAt = defaultMaxOutputBytes
}
var sessionID string
var lastErr error
gotStop := false
// Overflow writer: tees to stdout and optionally to a temp file.
// In quiet mode, buffer everything and print once at the end.
ow := &overflow.Writer{Limit: truncateAt, Quiet: askQuiet}
for event := range ch {
if e, ok := event.(models.SessionCreatedEvent); ok {
sessionID = e.ChatSessionID
@@ -82,22 +128,50 @@ func newAskCmd() *cobra.Command {
switch e := event.(type) {
case models.MessageDeltaEvent:
fmt.Print(e.Content)
ow.Write(e.Content)
case models.SearchStartEvent:
if isTTY && !askQuiet {
if e.IsInternetSearch {
fmt.Fprintf(os.Stderr, "\033[2mSearching the web...\033[0m\n")
} else {
fmt.Fprintf(os.Stderr, "\033[2mSearching documents...\033[0m\n")
}
}
case models.SearchQueriesEvent:
if isTTY && !askQuiet {
for _, q := range e.Queries {
fmt.Fprintf(os.Stderr, "\033[2m → %s\033[0m\n", q)
}
}
case models.SearchDocumentsEvent:
if isTTY && !askQuiet && len(e.Documents) > 0 {
fmt.Fprintf(os.Stderr, "\033[2mFound %d documents\033[0m\n", len(e.Documents))
}
case models.ReasoningStartEvent:
if isTTY && !askQuiet {
fmt.Fprintf(os.Stderr, "\033[2mThinking...\033[0m\n")
}
case models.ToolStartEvent:
if isTTY && !askQuiet && e.ToolName != "" {
fmt.Fprintf(os.Stderr, "\033[2mUsing %s...\033[0m\n", e.ToolName)
}
case models.ErrorEvent:
ow.Finish()
return fmt.Errorf("%s", e.Error)
case models.StopEvent:
fmt.Println()
ow.Finish()
return nil
}
}
if !askJSON {
ow.Finish()
}
if ctx.Err() != nil {
if sessionID != "" {
client.StopChatSession(context.Background(), sessionID)
}
if !askJSON {
fmt.Println()
}
return nil
}
@@ -105,20 +179,56 @@ func newAskCmd() *cobra.Command {
return lastErr
}
if !gotStop {
if !askJSON {
fmt.Println()
}
return fmt.Errorf("stream ended unexpectedly")
}
if !askJSON {
fmt.Println()
}
return nil
},
}
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
// Suppress cobra's default error/usage on RunE errors
cmd.Flags().BoolVarP(&askQuiet, "quiet", "q", false, "Buffer output and print once at end (no streaming)")
cmd.Flags().StringVar(&askPrompt, "prompt", "", "Question text (use with piped stdin context)")
cmd.Flags().IntVar(&maxOutput, "max-output", defaultMaxOutputBytes,
"Max bytes to print before truncating (0 to disable, auto-enabled for non-TTY)")
return cmd
}
// resolveQuestion builds the final question string from args, --prompt, and stdin.
func resolveQuestion(args []string, prompt string) (string, error) {
hasArg := len(args) > 0
hasPrompt := prompt != ""
hasStdin := !term.IsTerminal(int(os.Stdin.Fd()))
if hasArg && hasPrompt {
return "", exitcodes.New(exitcodes.BadRequest, "specify the question as an argument or --prompt, not both")
}
var stdinContent string
if hasStdin {
const maxStdinBytes = 10 * 1024 * 1024 // 10MB
data, err := io.ReadAll(io.LimitReader(os.Stdin, maxStdinBytes))
if err != nil {
return "", fmt.Errorf("failed to read stdin: %w", err)
}
stdinContent = strings.TrimSpace(string(data))
}
switch {
case hasArg && stdinContent != "":
// arg is the question, stdin is context
return args[0] + "\n\n" + stdinContent, nil
case hasArg:
return args[0], nil
case hasPrompt && stdinContent != "":
// --prompt is the question, stdin is context
return prompt + "\n\n" + stdinContent, nil
case hasPrompt:
return prompt, nil
case stdinContent != "":
return stdinContent, nil
default:
return "", exitcodes.New(exitcodes.BadRequest, "no question provided\n Usage: onyx-cli ask \"your question\"\n Or: echo \"context\" | onyx-cli ask --prompt \"your question\"")
}
}

View File

@@ -4,6 +4,7 @@ import (
tea "github.com/charmbracelet/bubbletea"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/onyx-dot-app/onyx/cli/internal/starprompt"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
)
@@ -12,6 +13,11 @@ func newChatCmd() *cobra.Command {
return &cobra.Command{
Use: "chat",
Short: "Launch the interactive chat TUI (default)",
Long: `Launch the interactive terminal UI for chatting with your Onyx agent.
This is the default command when no subcommand is specified. On first run,
an interactive setup wizard will guide you through configuration.`,
Example: ` onyx-cli chat
onyx-cli`,
RunE: func(cmd *cobra.Command, args []string) error {
cfg := config.Load()
@@ -24,6 +30,8 @@ func newChatCmd() *cobra.Command {
cfg = *result
}
starprompt.MaybePrompt()
m := tui.NewModel(cfg)
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
_, err := p.Run()

View File

@@ -1,19 +1,126 @@
package cmd
import (
"context"
"errors"
"fmt"
"io"
"os"
"strings"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
"github.com/spf13/cobra"
"golang.org/x/term"
)
func newConfigureCmd() *cobra.Command {
return &cobra.Command{
var (
serverURL string
apiKey string
apiKeyStdin bool
dryRun bool
)
cmd := &cobra.Command{
Use: "configure",
Short: "Configure server URL and API key",
Long: `Set up the Onyx CLI with your server URL and API key.
When --server-url and --api-key are both provided, the configuration is saved
non-interactively (useful for scripts and AI agents). Otherwise, an interactive
setup wizard is launched.
If --api-key is omitted but stdin has piped data, the API key is read from
stdin automatically. You can also use --api-key-stdin to make this explicit.
This avoids leaking the key in shell history.
Use --dry-run to test the connection without saving the configuration.`,
Example: ` onyx-cli configure
onyx-cli configure --server-url https://my-onyx.com --api-key sk-...
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com --api-key-stdin
onyx-cli configure --server-url https://my-onyx.com --api-key sk-... --dry-run`,
RunE: func(cmd *cobra.Command, args []string) error {
// Read API key from stdin if piped (implicit) or --api-key-stdin (explicit)
if apiKeyStdin && apiKey != "" {
return exitcodes.New(exitcodes.BadRequest, "--api-key and --api-key-stdin cannot be used together")
}
if (apiKey == "" && !term.IsTerminal(int(os.Stdin.Fd()))) || apiKeyStdin {
data, err := io.ReadAll(os.Stdin)
if err != nil {
return fmt.Errorf("failed to read API key from stdin: %w", err)
}
apiKey = strings.TrimSpace(string(data))
}
if serverURL != "" && apiKey != "" {
return configureNonInteractive(serverURL, apiKey, dryRun)
}
if dryRun {
return exitcodes.New(exitcodes.BadRequest, "--dry-run requires --server-url and --api-key")
}
if serverURL != "" || apiKey != "" {
return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n Run 'onyx-cli configure' without flags for interactive setup")
}
cfg := config.Load()
onboarding.Run(&cfg)
return nil
},
}
cmd.Flags().StringVar(&serverURL, "server-url", "", "Onyx server URL (e.g., https://cloud.onyx.app)")
cmd.Flags().StringVar(&apiKey, "api-key", "", "API key for authentication (or pipe via stdin)")
cmd.Flags().BoolVar(&apiKeyStdin, "api-key-stdin", false, "Read API key from stdin (explicit; also happens automatically when stdin is piped)")
cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Test connection without saving config (requires --server-url and --api-key)")
return cmd
}
func configureNonInteractive(serverURL, apiKey string, dryRun bool) error {
cfg := config.OnyxCliConfig{
ServerURL: serverURL,
APIKey: apiKey,
DefaultAgentID: 0,
}
// Preserve existing default agent ID from disk (not env overrides)
if existing := config.LoadFromDisk(); existing.DefaultAgentID != 0 {
cfg.DefaultAgentID = existing.DefaultAgentID
}
// Test connection
client := api.NewClient(cfg)
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
defer cancel()
if err := client.TestConnection(ctx); err != nil {
var authErr *api.AuthError
if errors.As(err, &authErr) {
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Check your API key", err)
}
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Check your server URL", err)
}
if dryRun {
fmt.Printf("Server: %s\n", serverURL)
fmt.Println("Status: connected and authenticated")
fmt.Println("Dry run: config was NOT saved")
return nil
}
if err := config.Save(cfg); err != nil {
return fmt.Errorf("could not save config: %w", err)
}
fmt.Printf("Config: %s\n", config.ConfigFilePath())
fmt.Printf("Server: %s\n", serverURL)
fmt.Println("Status: connected and authenticated")
return nil
}

176
cli/cmd/install_skill.go Normal file
View File

@@ -0,0 +1,176 @@
package cmd
import (
"fmt"
"os"
"path/filepath"
"github.com/onyx-dot-app/onyx/cli/internal/embedded"
"github.com/onyx-dot-app/onyx/cli/internal/fsutil"
"github.com/spf13/cobra"
)
// agentSkillDirs maps agent names to their skill directory paths (relative to
// the project or home root). "Universal" agents like Cursor and Codex read
// from .agents/skills directly, so they don't need their own entry here.
var agentSkillDirs = map[string]string{
"claude-code": filepath.Join(".claude", "skills"),
}
const (
canonicalDir = ".agents/skills"
skillName = "onyx-cli"
)
func newInstallSkillCmd() *cobra.Command {
var (
global bool
copyMode bool
agents []string
)
cmd := &cobra.Command{
Use: "install-skill",
Short: "Install the Onyx CLI agent skill file",
Long: `Install the bundled SKILL.md so that AI coding agents can discover and use
the Onyx CLI as a tool.
Files are written to the canonical .agents/skills/onyx-cli/ directory. For
agents that use their own skill directory (e.g. Claude Code uses .claude/skills/),
a symlink is created pointing back to the canonical copy.
By default the skill is installed at the project level (current directory).
Use --global to install under your home directory instead.
Use --copy to write independent copies instead of symlinks.
Use --agent to target specific agents (can be repeated).`,
Example: ` onyx-cli install-skill
onyx-cli install-skill --global
onyx-cli install-skill --agent claude-code
onyx-cli install-skill --copy`,
RunE: func(cmd *cobra.Command, args []string) error {
base, err := installBase(global)
if err != nil {
return err
}
// Write the canonical copy.
canonicalSkillDir := filepath.Join(base, canonicalDir, skillName)
dest := filepath.Join(canonicalSkillDir, "SKILL.md")
content := []byte(embedded.SkillMD)
status, err := fsutil.CompareFile(dest, content)
if err != nil {
return err
}
switch status {
case fsutil.StatusUpToDate:
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Up to date %s\n", dest)
case fsutil.StatusDiffers:
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: overwriting modified %s\n", dest)
if err := os.WriteFile(dest, content, 0o644); err != nil {
return fmt.Errorf("could not write skill file: %w", err)
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
default: // statusMissing
if err := os.MkdirAll(canonicalSkillDir, 0o755); err != nil {
return fmt.Errorf("could not create directory: %w", err)
}
if err := os.WriteFile(dest, content, 0o644); err != nil {
return fmt.Errorf("could not write skill file: %w", err)
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
}
// Determine which agents to link.
targets := agentSkillDirs
if len(agents) > 0 {
targets = make(map[string]string)
for _, a := range agents {
dir, ok := agentSkillDirs[a]
if !ok {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unknown agent %q (skipped) — known agents:", a)
for name := range agentSkillDirs {
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), " %s", name)
}
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
continue
}
targets[a] = dir
}
}
// Create symlinks (or copies) from agent-specific dirs to canonical.
for name, skillsDir := range targets {
agentSkillDir := filepath.Join(base, skillsDir, skillName)
if copyMode {
copyDest := filepath.Join(agentSkillDir, "SKILL.md")
if err := fsutil.EnsureDirForCopy(agentSkillDir); err != nil {
return fmt.Errorf("could not prepare %s directory: %w", name, err)
}
if err := os.MkdirAll(agentSkillDir, 0o755); err != nil {
return fmt.Errorf("could not create %s directory: %w", name, err)
}
if err := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); err != nil {
return fmt.Errorf("could not write %s skill file: %w", name, err)
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s\n", copyDest)
continue
}
// Compute relative symlink target. Symlinks resolve relative to
// the parent directory of the link, not the link itself.
rel, err := filepath.Rel(filepath.Dir(agentSkillDir), canonicalSkillDir)
if err != nil {
return fmt.Errorf("could not compute relative path for %s: %w", name, err)
}
if err := os.MkdirAll(filepath.Dir(agentSkillDir), 0o755); err != nil {
return fmt.Errorf("could not create %s directory: %w", name, err)
}
// Remove existing symlink/dir before creating.
_ = os.Remove(agentSkillDir)
if err := os.Symlink(rel, agentSkillDir); err != nil {
// Fall back to copy if symlink fails (e.g. Windows without dev mode).
copyDest := filepath.Join(agentSkillDir, "SKILL.md")
if mkErr := os.MkdirAll(agentSkillDir, 0o755); mkErr != nil {
return fmt.Errorf("could not create %s directory: %w", name, mkErr)
}
if wErr := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); wErr != nil {
return fmt.Errorf("could not write %s skill file: %w", name, wErr)
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s (symlink failed)\n", copyDest)
continue
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Linked %s -> %s\n", agentSkillDir, rel)
}
return nil
},
}
cmd.Flags().BoolVarP(&global, "global", "g", false, "Install to home directory instead of project")
cmd.Flags().BoolVar(&copyMode, "copy", false, "Copy files instead of symlinking")
cmd.Flags().StringSliceVarP(&agents, "agent", "a", nil, "Target specific agents (e.g. claude-code)")
return cmd
}
func installBase(global bool) (string, error) {
if global {
home, err := os.UserHomeDir()
if err != nil {
return "", fmt.Errorf("could not determine home directory: %w", err)
}
return home, nil
}
cwd, err := os.Getwd()
if err != nil {
return "", fmt.Errorf("could not determine working directory: %w", err)
}
return cwd, nil
}

View File

@@ -97,6 +97,7 @@ func Execute() error {
rootCmd.AddCommand(newConfigureCmd())
rootCmd.AddCommand(newValidateConfigCmd())
rootCmd.AddCommand(newServeCmd())
rootCmd.AddCommand(newInstallSkillCmd())
// Default command is chat, but intercept --version first
rootCmd.RunE = func(cmd *cobra.Command, args []string) error {

View File

@@ -23,6 +23,7 @@ import (
"github.com/charmbracelet/wish/ratelimiter"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/tui"
"github.com/spf13/cobra"
"golang.org/x/time/rate"
@@ -295,15 +296,15 @@ provided via the ONYX_API_KEY environment variable to skip the prompt:
The server URL is taken from the server operator's config. The server
auto-generates an Ed25519 host key on first run if the key file does not
already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
environment variable (the --host-key flag takes precedence).
Example:
onyx-cli serve --port 2222
ssh localhost -p 2222`,
environment variable (the --host-key flag takes precedence).`,
Example: ` onyx-cli serve --port 2222
ssh localhost -p 2222
onyx-cli serve --host 0.0.0.0 --port 2222
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
RunE: func(cmd *cobra.Command, args []string) error {
serverCfg := config.Load()
if serverCfg.ServerURL == "" {
return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
}
if !cmd.Flags().Changed("host-key") {
if v := os.Getenv(config.EnvSSHHostKey); v != "" {

View File

@@ -2,11 +2,13 @@ package cmd
import (
"context"
"errors"
"fmt"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/api"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
"github.com/onyx-dot-app/onyx/cli/internal/version"
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
@@ -16,17 +18,21 @@ func newValidateConfigCmd() *cobra.Command {
return &cobra.Command{
Use: "validate-config",
Short: "Validate configuration and test server connection",
Long: `Check that the CLI is configured, the server is reachable, and the API key
is valid. Also reports the server version and warns if it is below the
minimum required.`,
Example: ` onyx-cli validate-config`,
RunE: func(cmd *cobra.Command, args []string) error {
// Check config file
if !config.ConfigExists() {
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
}
cfg := config.Load()
// Check API key
if !cfg.IsConfigured() {
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
}
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
@@ -35,7 +41,11 @@ func newValidateConfigCmd() *cobra.Command {
// Test connection
client := api.NewClient(cfg)
if err := client.TestConnection(cmd.Context()); err != nil {
return fmt.Errorf("connection failed: %w", err)
var authErr *api.AuthError
if errors.As(err, &authErr) {
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Reconfigure with: onyx-cli configure", err)
}
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Reconfigure with: onyx-cli configure", err)
}
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")

View File

@@ -149,12 +149,12 @@ func (c *Client) TestConnection(ctx context.Context) error {
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
if isHTML || strings.Contains(respServer, "awselb") {
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
return &AuthError{Message: fmt.Sprintf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)}
}
if resp2.StatusCode == 401 {
return fmt.Errorf("invalid API key or token.\n %s", body)
return &AuthError{Message: fmt.Sprintf("invalid API key or token.\n %s", body)}
}
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
return &AuthError{Message: fmt.Sprintf("access denied — check that the API key is valid.\n %s", body)}
}
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)

View File

@@ -11,3 +11,12 @@ type OnyxAPIError struct {
func (e *OnyxAPIError) Error() string {
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
}
// AuthError is returned when authentication or authorization fails.
type AuthError struct {
Message string
}
func (e *AuthError) Error() string {
return e.Message
}

View File

@@ -59,8 +59,10 @@ func ConfigExists() bool {
return err == nil
}
// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
// LoadFromDisk reads config from the file only, without applying environment
// variable overrides. Use this when you need the persisted config values
// (e.g., to preserve them during a save operation).
func LoadFromDisk() OnyxCliConfig {
cfg := DefaultConfig()
data, err := os.ReadFile(ConfigFilePath())
@@ -70,6 +72,13 @@ func Load() OnyxCliConfig {
}
}
return cfg
}
// Load reads config from file and applies environment variable overrides.
func Load() OnyxCliConfig {
cfg := LoadFromDisk()
// Environment overrides
if v := os.Getenv(EnvServerURL); v != "" {
cfg.ServerURL = v

View File

@@ -0,0 +1,187 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---
# Onyx CLI — Agent Tool
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
## Prerequisites
### 1. Check if installed
```bash
which onyx-cli
```
### 2. Install (if needed)
**Primary — pip:**
```bash
pip install onyx-cli
```
**From source (Go):**
```bash
go build -o onyx-cli github.com/onyx-dot-app/onyx/cli && sudo mv onyx-cli /usr/local/bin/
```
### 3. Check if configured
```bash
onyx-cli validate-config
```
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
If unconfigured, you have two options:
**Option A — Interactive setup (requires user input):**
```bash
onyx-cli configure
```
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
**Option B — Environment variables (non-interactive, preferred for agents):**
```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```
Environment variables override the config file. If these are set, no config file is needed.
| Variable | Required | Description |
| ----------------- | -------- | -------------------------------------------------------- |
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
## Commands
### Validate configuration
```bash
onyx-cli validate-config
```
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
### List available agents
```bash
onyx-cli agents
```
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
```bash
onyx-cli agents --json
```
Use agent IDs with `ask --agent-id` to query a specific agent.
### Basic query (plain text output)
```bash
onyx-cli ask "What is our company's PTO policy?"
```
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
### JSON output (structured events)
```bash
onyx-cli ask --json "What authentication methods do we support?"
```
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
Each line is a JSON object with this envelope:
```json
{"type": "<event_type>", "event": { ... }}
```
| Event Type | Description |
| ------------------- | -------------------------------------------------------------------- |
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |
`citation_info` event shape:
```json
{
"type": "citation_info",
"event": {
"citation_number": 1,
"document_id": "abc123def456",
"placement": { "turn_index": 0, "tab_index": 0, "sub_turn_index": null }
}
}
```
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
### Specify an agent
```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```
Uses a specific Onyx agent/persona instead of the default.
### All flags
| Flag | Type | Description |
| ------------ | ---- | ---------------------------------------------- |
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |
## Statelessness
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use
Use `onyx-cli ask` when:
- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
Do NOT use when:
- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data
## Examples
```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"
# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"
# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```

View File

@@ -0,0 +1,7 @@
// Package embedded holds files that are compiled into the onyx-cli binary.
package embedded
import _ "embed"
//go:embed SKILL.md
var SkillMD string

View File

@@ -0,0 +1,33 @@
// Package exitcodes defines semantic exit codes for the Onyx CLI.
package exitcodes
import "fmt"
const (
Success = 0
General = 1
BadRequest = 2 // invalid args / command-line errors (convention)
NotConfigured = 3
AuthFailure = 4
Unreachable = 5
)
// ExitError wraps an error with a specific exit code.
type ExitError struct {
Code int
Err error
}
func (e *ExitError) Error() string {
return e.Err.Error()
}
// New creates an ExitError with the given code and message.
func New(code int, msg string) *ExitError {
return &ExitError{Code: code, Err: fmt.Errorf("%s", msg)}
}
// Newf creates an ExitError with a formatted message.
func Newf(code int, format string, args ...any) *ExitError {
return &ExitError{Code: code, Err: fmt.Errorf(format, args...)}
}

View File

@@ -0,0 +1,40 @@
package exitcodes
import (
"errors"
"fmt"
"testing"
)
func TestExitError_Error(t *testing.T) {
e := New(NotConfigured, "not configured")
if e.Error() != "not configured" {
t.Fatalf("expected 'not configured', got %q", e.Error())
}
if e.Code != NotConfigured {
t.Fatalf("expected code %d, got %d", NotConfigured, e.Code)
}
}
func TestExitError_Newf(t *testing.T) {
e := Newf(Unreachable, "cannot reach %s", "server")
if e.Error() != "cannot reach server" {
t.Fatalf("expected 'cannot reach server', got %q", e.Error())
}
if e.Code != Unreachable {
t.Fatalf("expected code %d, got %d", Unreachable, e.Code)
}
}
func TestExitError_ErrorsAs(t *testing.T) {
e := New(BadRequest, "bad input")
wrapped := fmt.Errorf("wrapper: %w", e)
var exitErr *ExitError
if !errors.As(wrapped, &exitErr) {
t.Fatal("errors.As should find ExitError")
}
if exitErr.Code != BadRequest {
t.Fatalf("expected code %d, got %d", BadRequest, exitErr.Code)
}
}

View File

@@ -0,0 +1,50 @@
// Package fsutil provides filesystem helper functions.
package fsutil
import (
"bytes"
"errors"
"fmt"
"os"
)
// FileStatus describes how an on-disk file compares to expected content.
type FileStatus int
const (
StatusMissing FileStatus = iota
StatusUpToDate // file exists with identical content
StatusDiffers // file exists with different content
)
// CompareFile checks whether the file at path matches the expected content.
func CompareFile(path string, expected []byte) (FileStatus, error) {
existing, err := os.ReadFile(path)
if err != nil {
if errors.Is(err, os.ErrNotExist) {
return StatusMissing, nil
}
return 0, fmt.Errorf("could not read %s: %w", path, err)
}
if bytes.Equal(existing, expected) {
return StatusUpToDate, nil
}
return StatusDiffers, nil
}
// EnsureDirForCopy makes sure path is a real directory, not a symlink or
// regular file. If a symlink or file exists at path it is removed so the
// caller can create a directory with independent content.
func EnsureDirForCopy(path string) error {
info, err := os.Lstat(path)
if err == nil {
if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
if err := os.Remove(path); err != nil {
return err
}
}
} else if !errors.Is(err, os.ErrNotExist) {
return err
}
return nil
}

View File

@@ -0,0 +1,116 @@
package fsutil
import (
"os"
"path/filepath"
"testing"
)
// TestCompareFile verifies that CompareFile correctly distinguishes between a
// missing file, a file with matching content, and a file with different content.
func TestCompareFile(t *testing.T) {
tmpDir := t.TempDir()
path := filepath.Join(tmpDir, "skill.md")
expected := []byte("expected content")
status, err := CompareFile(path, expected)
if err != nil {
t.Fatalf("CompareFile on missing file failed: %v", err)
}
if status != StatusMissing {
t.Fatalf("expected StatusMissing, got %v", status)
}
if err := os.WriteFile(path, expected, 0o644); err != nil {
t.Fatalf("write expected file failed: %v", err)
}
status, err = CompareFile(path, expected)
if err != nil {
t.Fatalf("CompareFile on matching file failed: %v", err)
}
if status != StatusUpToDate {
t.Fatalf("expected StatusUpToDate, got %v", status)
}
if err := os.WriteFile(path, []byte("different content"), 0o644); err != nil {
t.Fatalf("write different file failed: %v", err)
}
status, err = CompareFile(path, expected)
if err != nil {
t.Fatalf("CompareFile on different file failed: %v", err)
}
if status != StatusDiffers {
t.Fatalf("expected StatusDiffers, got %v", status)
}
}
// TestEnsureDirForCopy verifies that EnsureDirForCopy clears symlinks and
// regular files so --copy can write a real directory, while leaving existing
// directories and missing paths untouched.
func TestEnsureDirForCopy(t *testing.T) {
t.Run("removes symlink", func(t *testing.T) {
tmpDir := t.TempDir()
targetDir := filepath.Join(tmpDir, "target")
linkPath := filepath.Join(tmpDir, "link")
if err := os.MkdirAll(targetDir, 0o755); err != nil {
t.Fatalf("mkdir target failed: %v", err)
}
if err := os.Symlink(targetDir, linkPath); err != nil {
t.Fatalf("create symlink failed: %v", err)
}
if err := EnsureDirForCopy(linkPath); err != nil {
t.Fatalf("EnsureDirForCopy failed: %v", err)
}
if _, err := os.Lstat(linkPath); !os.IsNotExist(err) {
t.Fatalf("expected symlink path to be removed, got err=%v", err)
}
})
t.Run("removes regular file", func(t *testing.T) {
tmpDir := t.TempDir()
filePath := filepath.Join(tmpDir, "onyx-cli")
if err := os.WriteFile(filePath, []byte("x"), 0o644); err != nil {
t.Fatalf("write file failed: %v", err)
}
if err := EnsureDirForCopy(filePath); err != nil {
t.Fatalf("EnsureDirForCopy failed: %v", err)
}
if _, err := os.Lstat(filePath); !os.IsNotExist(err) {
t.Fatalf("expected file path to be removed, got err=%v", err)
}
})
t.Run("keeps existing directory", func(t *testing.T) {
tmpDir := t.TempDir()
dirPath := filepath.Join(tmpDir, "onyx-cli")
if err := os.MkdirAll(dirPath, 0o755); err != nil {
t.Fatalf("mkdir failed: %v", err)
}
if err := EnsureDirForCopy(dirPath); err != nil {
t.Fatalf("EnsureDirForCopy failed: %v", err)
}
info, err := os.Lstat(dirPath)
if err != nil {
t.Fatalf("lstat directory failed: %v", err)
}
if !info.IsDir() {
t.Fatalf("expected directory to remain, got mode %v", info.Mode())
}
})
t.Run("missing path is no-op", func(t *testing.T) {
tmpDir := t.TempDir()
missingPath := filepath.Join(tmpDir, "does-not-exist")
if err := EnsureDirForCopy(missingPath); err != nil {
t.Fatalf("EnsureDirForCopy failed: %v", err)
}
})
}

View File

@@ -0,0 +1,121 @@
// Package overflow provides a streaming writer that auto-truncates output
// for non-TTY callers (e.g., AI agents, scripts). Full content is saved to
// a temp file on disk; only the first N bytes are printed to stdout.
package overflow
import (
"fmt"
"os"
"strings"
log "github.com/sirupsen/logrus"
)
// Writer handles streaming output with optional truncation.
// When Limit > 0, it streams to a temp file on disk (not memory) and stops
// writing to stdout after Limit bytes. When Limit == 0, it writes directly
// to stdout. In Quiet mode, it buffers in memory and prints once at the end.
type Writer struct {
Limit int
Quiet bool
written int
totalBytes int
truncated bool
buf strings.Builder // used only in quiet mode
tmpFile *os.File // used only in truncation mode (Limit > 0)
}
// Write sends a chunk of content through the writer.
func (w *Writer) Write(s string) {
w.totalBytes += len(s)
// Quiet mode: buffer in memory, print nothing
if w.Quiet {
w.buf.WriteString(s)
return
}
if w.Limit <= 0 {
fmt.Print(s)
return
}
// Truncation mode: stream all content to temp file on disk
if w.tmpFile == nil {
f, err := os.CreateTemp("", "onyx-ask-*.txt")
if err != nil {
// Fall back to no-truncation if we can't create the file
fmt.Fprintf(os.Stderr, "warning: could not create temp file: %v\n", err)
w.Limit = 0
fmt.Print(s)
return
}
w.tmpFile = f
}
if _, err := w.tmpFile.WriteString(s); err != nil {
// Disk write failed — abandon truncation, stream directly to stdout
fmt.Fprintf(os.Stderr, "warning: temp file write failed: %v\n", err)
w.closeTmpFile(true)
w.Limit = 0
w.truncated = false
fmt.Print(s)
return
}
if w.truncated {
return
}
remaining := w.Limit - w.written
if len(s) <= remaining {
fmt.Print(s)
w.written += len(s)
} else {
if remaining > 0 {
fmt.Print(s[:remaining])
w.written += remaining
}
w.truncated = true
}
}
// Finish flushes remaining output. Call once after all Write calls are done.
func (w *Writer) Finish() {
// Quiet mode: print buffered content at once
if w.Quiet {
fmt.Println(w.buf.String())
return
}
if !w.truncated {
w.closeTmpFile(true) // clean up unused temp file
fmt.Println()
return
}
// Close the temp file so it's readable
tmpPath := w.tmpFile.Name()
w.closeTmpFile(false) // close but keep the file
fmt.Printf("\n\n--- response truncated (%d bytes total) ---\n", w.totalBytes)
fmt.Printf("Full response: %s\n", tmpPath)
fmt.Printf("Explore:\n")
fmt.Printf(" cat %s | grep \"<pattern>\"\n", tmpPath)
fmt.Printf(" cat %s | tail -50\n", tmpPath)
}
// closeTmpFile closes and optionally removes the temp file.
func (w *Writer) closeTmpFile(remove bool) {
if w.tmpFile == nil {
return
}
if err := w.tmpFile.Close(); err != nil {
log.Debugf("warning: failed to close temp file: %v", err)
}
if remove {
if err := os.Remove(w.tmpFile.Name()); err != nil {
log.Debugf("warning: failed to remove temp file: %v", err)
}
}
w.tmpFile = nil
}

View File

@@ -0,0 +1,95 @@
package overflow
import (
"os"
"testing"
)
func TestWriter_NoLimit(t *testing.T) {
w := &Writer{Limit: 0}
w.Write("hello world")
if w.truncated {
t.Fatal("should not be truncated with limit 0")
}
if w.totalBytes != 11 {
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
}
}
func TestWriter_UnderLimit(t *testing.T) {
w := &Writer{Limit: 100}
w.Write("hello")
w.Write(" world")
if w.truncated {
t.Fatal("should not be truncated when under limit")
}
if w.written != 11 {
t.Fatalf("expected 11 written bytes, got %d", w.written)
}
}
func TestWriter_OverLimit(t *testing.T) {
w := &Writer{Limit: 5}
w.Write("hello world") // 11 bytes, limit 5
if !w.truncated {
t.Fatal("should be truncated")
}
if w.written != 5 {
t.Fatalf("expected 5 written bytes, got %d", w.written)
}
if w.totalBytes != 11 {
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
}
if w.tmpFile == nil {
t.Fatal("temp file should have been created")
}
_ = w.tmpFile.Close()
data, _ := os.ReadFile(w.tmpFile.Name())
_ = os.Remove(w.tmpFile.Name())
if string(data) != "hello world" {
t.Fatalf("temp file should contain full content, got %q", string(data))
}
}
func TestWriter_MultipleChunks(t *testing.T) {
w := &Writer{Limit: 10}
w.Write("hello") // 5 bytes
w.Write(" ") // 6 bytes
w.Write("world") // 11 bytes, crosses limit
w.Write("!") // 12 bytes, already truncated
if !w.truncated {
t.Fatal("should be truncated")
}
if w.written != 10 {
t.Fatalf("expected 10 written bytes, got %d", w.written)
}
if w.totalBytes != 12 {
t.Fatalf("expected 12 total bytes, got %d", w.totalBytes)
}
if w.tmpFile == nil {
t.Fatal("temp file should have been created")
}
_ = w.tmpFile.Close()
data, _ := os.ReadFile(w.tmpFile.Name())
_ = os.Remove(w.tmpFile.Name())
if string(data) != "hello world!" {
t.Fatalf("temp file should contain full content, got %q", string(data))
}
}
func TestWriter_QuietMode(t *testing.T) {
w := &Writer{Limit: 0, Quiet: true}
w.Write("hello")
w.Write(" world")
if w.written != 0 {
t.Fatalf("quiet mode should not write to stdout, got %d written", w.written)
}
if w.totalBytes != 11 {
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
}
if w.buf.String() != "hello world" {
t.Fatalf("buffer should contain full content, got %q", w.buf.String())
}
}

View File

@@ -0,0 +1,83 @@
// Package starprompt implements a one-time GitHub star prompt shown before the TUI.
// Skipped when stdin/stdout is not a TTY, when gh CLI is not installed,
// or when the user has already been prompted. State is stored in the
// config directory so it shows at most once per user.
package starprompt
import (
"bufio"
"fmt"
"os"
"os/exec"
"path/filepath"
"strings"
"time"
"github.com/onyx-dot-app/onyx/cli/internal/config"
"golang.org/x/term"
)
const repo = "onyx-dot-app/onyx"
func statePath() string {
return filepath.Join(config.ConfigDir(), ".star-prompted")
}
func hasBeenPrompted() bool {
_, err := os.Stat(statePath())
return err == nil
}
func markPrompted() {
_ = os.MkdirAll(config.ConfigDir(), 0o755)
f, err := os.Create(statePath())
if err == nil {
_ = f.Close()
}
}
func isGHInstalled() bool {
_, err := exec.LookPath("gh")
return err == nil
}
// MaybePrompt shows a one-time star prompt if conditions are met.
// It is safe to call unconditionally — it no-ops when not appropriate.
func MaybePrompt() {
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
return
}
if hasBeenPrompted() {
return
}
if !isGHInstalled() {
return
}
// Mark before asking so Ctrl+C won't cause a re-prompt.
markPrompted()
fmt.Print("Enjoying Onyx? Star the repo on GitHub? [Y/n] ")
reader := bufio.NewReader(os.Stdin)
answer, _ := reader.ReadString('\n')
answer = strings.TrimSpace(strings.ToLower(answer))
if answer == "n" || answer == "no" {
return
}
cmd := exec.Command("gh", "api", "-X", "PUT", "/user/starred/"+repo)
cmd.Env = append(os.Environ(), "GH_PAGER=")
if devnull, err := os.Open(os.DevNull); err == nil {
defer func() { _ = devnull.Close() }()
cmd.Stdin = devnull
cmd.Stdout = devnull
cmd.Stderr = devnull
}
if err := cmd.Run(); err != nil {
fmt.Println("Star us at: https://github.com/" + repo)
} else {
fmt.Println("Thanks for the star!")
time.Sleep(500 * time.Millisecond)
}
}

View File

@@ -1,10 +1,12 @@
package main
import (
"errors"
"fmt"
"os"
"github.com/onyx-dot-app/onyx/cli/cmd"
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
)
var (
@@ -18,6 +20,10 @@ func main() {
if err := cmd.Execute(); err != nil {
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
var exitErr *exitcodes.ExitError
if errors.As(err, &exitErr) {
os.Exit(exitErr.Code)
}
os.Exit(1)
}
}

View File

@@ -1302,4 +1302,18 @@ echo ""
print_info "Refer to the README in the ${INSTALL_ROOT} directory for more information."
echo ""
print_info "For help or issues, contact: founders@onyx.app"
echo ""
echo ""
# --- GitHub star prompt (inspired by oh-my-codex) ---
# Only prompt in interactive mode and only if gh CLI is available.
# Uses the GitHub API directly (PUT /user/starred) like oh-my-codex.
if is_interactive && command -v gh &>/dev/null; then
prompt_yn_or_default "Enjoying Onyx? Star the repo on GitHub? [Y/n] " "Y"
if [[ ! "$REPLY" =~ ^[Nn] ]]; then
if GH_PAGER= gh api -X PUT /user/starred/onyx-dot-app/onyx < /dev/null >/dev/null 2>&1; then
print_success "Thanks for the star!"
else
print_info "Star us at: https://github.com/onyx-dot-app/onyx"
fi
fi
fi

View File

@@ -193,9 +193,9 @@ hover, active, and disabled states.
### Disabled (`core/disabled/`)
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
interactive descendants.
A pure CSS wrapper that applies disabled visuals (`opacity-50`, `cursor-not-allowed`,
`pointer-events: none`) to a single child element via Radix `Slot`. Has no React context —
Interactive primitives and buttons manage their own disabled state via a `disabled` prop.
### Hoverable (`core/animations/`)

View File

@@ -1,9 +1,5 @@
import "@opal/components/tooltip.css";
import {
Disabled,
Interactive,
type InteractiveStatelessProps,
} from "@opal/core";
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
import type {
ContainerSizeVariants,
ExtremaSizeVariants,
@@ -49,7 +45,7 @@ type ButtonProps = InteractiveStatelessProps &
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
/** Applies disabled styling and suppresses clicks. */
disabled?: boolean;
};
@@ -94,7 +90,11 @@ function Button({
) : null;
const button = (
<Interactive.Stateless type={type} {...interactiveProps}>
<Interactive.Stateless
type={type}
disabled={disabled}
{...interactiveProps}
>
<Interactive.Container
type={type}
border={interactiveProps.prominence === "secondary"}
@@ -118,28 +118,24 @@ function Button({
</Interactive.Stateless>
);
const result = tooltip ? (
<TooltipPrimitive.Root>
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
<TooltipPrimitive.Portal>
<TooltipPrimitive.Content
className="opal-tooltip"
side={tooltipSide}
sideOffset={4}
>
{tooltip}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>
) : (
button
);
if (disabled != null) {
return <Disabled disabled={disabled}>{result}</Disabled>;
if (tooltip) {
return (
<TooltipPrimitive.Root>
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
<TooltipPrimitive.Portal>
<TooltipPrimitive.Content
className="opal-tooltip"
side={tooltipSide}
sideOffset={4}
>
{tooltip}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>
);
}
return result;
return button;
}
export { Button, type ButtonProps };

View File

@@ -1,6 +1,5 @@
import {
Interactive,
useDisabled,
type InteractiveStatefulProps,
type InteractiveStatefulInteraction,
} from "@opal/core";
@@ -74,6 +73,9 @@ type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
/** Override the default rounding derived from `size`. */
roundingVariant?: InteractiveContainerRoundingVariant;
/** Applies disabled styling and suppresses clicks. */
disabled?: boolean;
};
// ---------------------------------------------------------------------------
@@ -92,10 +94,9 @@ function OpenButton({
roundingVariant: roundingVariantOverride,
interaction,
variant = "select-heavy",
disabled,
...statefulProps
}: OpenButtonProps) {
const { isDisabled } = useDisabled();
// Derive open state: explicit prop → Radix data-state (injected via Slot chain)
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
| string
@@ -119,6 +120,7 @@ function OpenButton({
<Interactive.Stateful
variant={variant}
interaction={resolvedInteraction}
disabled={disabled}
{...statefulProps}
>
<Interactive.Container
@@ -168,7 +170,7 @@ function OpenButton({
);
const resolvedTooltip =
tooltip ?? (foldable && isDisabled && children ? children : undefined);
tooltip ?? (foldable && disabled && children ? children : undefined);
if (!resolvedTooltip) return button;

View File

@@ -1,11 +1,7 @@
"use client";
import "@opal/components/buttons/select-button/styles.css";
import {
Interactive,
useDisabled,
type InteractiveStatefulProps,
} from "@opal/core";
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
import type {
ContainerSizeVariants,
ExtremaSizeVariants,
@@ -64,6 +60,9 @@ type SelectButtonProps = InteractiveStatefulProps &
/** Which side the tooltip appears on. */
tooltipSide?: TooltipSide;
/** Applies disabled styling and suppresses clicks. */
disabled?: boolean;
};
// ---------------------------------------------------------------------------
@@ -80,9 +79,9 @@ function SelectButton({
width,
tooltip,
tooltipSide = "top",
disabled,
...statefulProps
}: SelectButtonProps) {
const { isDisabled } = useDisabled();
const isLarge = size === "lg";
const labelEl = children ? (
@@ -96,7 +95,7 @@ function SelectButton({
) : null;
const button = (
<Interactive.Stateful {...statefulProps}>
<Interactive.Stateful disabled={disabled} {...statefulProps}>
<Interactive.Container
type={type}
heightVariant={size}
@@ -128,7 +127,7 @@ function SelectButton({
);
const resolvedTooltip =
tooltip ?? (foldable && isDisabled && children ? children : undefined);
tooltip ?? (foldable && disabled && children ? children : undefined);
if (!resolvedTooltip) return button;

View File

@@ -0,0 +1,64 @@
# SidebarTab
**Import:** `import { SidebarTab, type SidebarTabProps } from "@opal/components";`
A sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`. Designed for admin and app sidebars.
## Architecture
```
div.relative
└─ Interactive.Stateful <- variant (sidebar-heavy | sidebar-light), state, disabled
└─ Interactive.Container <- rounding, height, width
├─ Link? (absolute overlay for client-side navigation)
├─ rightChildren? (absolute, above Link for inline actions)
└─ ContentAction (icon + title + truncation spacer)
```
- **`sidebar-heavy`** (default) — muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
- **`sidebar-light`** — uniformly muted across all states (text-02/text-02)
- **Disabled** — both variants use text-02 foreground, transparent background, no hover/active states
- **Navigation** uses an absolutely positioned `<Link>` overlay rather than `href` on the Interactive element, so `rightChildren` can sit above it with `pointer-events-auto`.
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `variant` | `"sidebar-heavy" \| "sidebar-light"` | `"sidebar-heavy"` | Sidebar color variant |
| `selected` | `boolean` | `false` | Active/selected state |
| `icon` | `IconFunctionComponent` | — | Left icon |
| `children` | `ReactNode` | — | Label text or custom content |
| `disabled` | `boolean` | `false` | Disables the tab |
| `folded` | `boolean` | `false` | Collapses label, shows tooltip on hover |
| `nested` | `boolean` | `false` | Renders spacer instead of icon for indented items |
| `href` | `string` | — | Client-side navigation URL |
| `onClick` | `MouseEventHandler` | — | Click handler |
| `type` | `ButtonType` | — | HTML button type |
| `rightChildren` | `ReactNode` | — | Actions rendered on the right side |
## Usage
```tsx
import { SidebarTab } from "@opal/components";
import { SvgSettings, SvgLock } from "@opal/icons";
// Active tab
<SidebarTab icon={SvgSettings} href="/admin/settings" selected>
Settings
</SidebarTab>
// Muted variant
<SidebarTab icon={SvgSettings} variant="sidebar-light">
Exit Admin Panel
</SidebarTab>
// Disabled enterprise-only tab
<SidebarTab icon={SvgLock} disabled>
Groups
</SidebarTab>
// Folded sidebar (icon only, tooltip on hover)
<SidebarTab icon={SvgSettings} href="/admin/settings" folded>
Settings
</SidebarTab>
```

View File

@@ -0,0 +1,94 @@
import type { Meta, StoryObj } from "@storybook/react";
import { SidebarTab } from "@opal/components/buttons/sidebar-tab/components";
import {
SvgSettings,
SvgUsers,
SvgLock,
SvgArrowUpCircle,
SvgTrash,
} from "@opal/icons";
import { Button } from "@opal/components";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
const meta: Meta<typeof SidebarTab> = {
title: "opal/components/SidebarTab",
component: SidebarTab,
tags: ["autodocs"],
decorators: [
(Story) => (
<TooltipPrimitive.Provider>
<div style={{ width: 260, background: "var(--background-neutral-01)" }}>
<Story />
</div>
</TooltipPrimitive.Provider>
),
],
};
export default meta;
type Story = StoryObj<typeof SidebarTab>;
export const Default: Story = {
args: {
icon: SvgSettings,
children: "Settings",
},
};
export const Selected: Story = {
args: {
icon: SvgSettings,
children: "Settings",
selected: true,
},
};
export const Light: Story = {
args: {
icon: SvgSettings,
children: "Settings",
variant: "sidebar-light",
},
};
export const Disabled: Story = {
args: {
icon: SvgLock,
children: "Enterprise Only",
disabled: true,
},
};
export const WithRightChildren: Story = {
args: {
icon: SvgUsers,
children: "Users",
rightChildren: (
<Button
icon={SvgTrash}
size="xs"
prominence="tertiary"
variant="danger"
/>
),
},
};
export const SidebarExample: Story = {
render: () => (
<div className="flex flex-col">
<SidebarTab icon={SvgSettings} selected>
LLM Models
</SidebarTab>
<SidebarTab icon={SvgSettings}>Web Search</SidebarTab>
<SidebarTab icon={SvgUsers}>Users</SidebarTab>
<SidebarTab icon={SvgLock} disabled>
Groups
</SidebarTab>
<SidebarTab icon={SvgLock} disabled>
SCIM
</SidebarTab>
<SidebarTab icon={SvgArrowUpCircle}>Upgrade Plan</SidebarTab>
</div>
),
};

View File

@@ -3,32 +3,66 @@
import React from "react";
import type { ButtonType, IconFunctionComponent } from "@opal/types";
import type { Route } from "next";
import { Interactive } from "@opal/core";
import { Interactive, type InteractiveStatefulVariant } from "@opal/core";
import { ContentAction } from "@opal/layouts";
import { Text } from "@opal/components";
import Link from "next/link";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import "@opal/components/tooltip.css";
export interface SidebarTabProps {
// Button states:
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
interface SidebarTabProps {
/** Collapses the label, showing only the icon. */
folded?: boolean;
/** Marks this tab as the currently active/selected item. */
selected?: boolean;
lowlight?: boolean;
/**
* Sidebar color variant.
* @default "sidebar-heavy"
*/
variant?: Extract<
InteractiveStatefulVariant,
"sidebar-light" | "sidebar-heavy"
>;
/** Renders an empty spacer in place of the icon for nested items. */
nested?: boolean;
// Button properties:
/** Disables the tab — applies muted colors and suppresses clicks. */
disabled?: boolean;
onClick?: React.MouseEventHandler<HTMLElement>;
href?: string;
type?: ButtonType;
icon?: IconFunctionComponent;
children?: React.ReactNode;
/** Content rendered on the right side (e.g. action buttons). */
rightChildren?: React.ReactNode;
}
export default function SidebarTab({
// ---------------------------------------------------------------------------
// SidebarTab
// ---------------------------------------------------------------------------
/**
* Sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`.
*
* Uses `sidebar-heavy` (default) or `sidebar-light` (via `variant`) variants
* for color styling. Supports an overlay `Link` for client-side navigation,
* `rightChildren` for inline actions, and folded mode with an auto-tooltip.
*/
function SidebarTab({
folded,
selected,
lowlight,
variant = "sidebar-heavy",
nested,
disabled,
onClick,
href,
@@ -45,11 +79,8 @@ export default function SidebarTab({
)) as IconFunctionComponent)
: null);
// NOTE (@raunakab)
//
// The `rightChildren` node NEEDS to be absolutely positioned since it needs to live on top of the absolutely positioned `Link`.
// However, having the `rightChildren` be absolutely positioned means that it cannot appropriately truncate the title.
// Therefore, we add a dummy node solely for the truncation effects that we obtain.
// The `rightChildren` node is absolutely positioned to sit on top of the
// overlay Link. A zero-width spacer reserves truncation space for the title.
const truncationSpacer = rightChildren && (
<div className="w-0 group-hover/SidebarTab:w-6" />
);
@@ -57,8 +88,9 @@ export default function SidebarTab({
const content = (
<div className="relative">
<Interactive.Stateful
variant={lowlight ? "sidebar-light" : "sidebar-heavy"}
variant={variant}
state={selected ? "selected" : "empty"}
disabled={disabled}
onClick={onClick}
type="button"
group="group/SidebarTab"
@@ -69,7 +101,7 @@ export default function SidebarTab({
widthVariant="full"
type={type}
>
{href && (
{href && !disabled && (
<Link
href={href as Route}
scroll={false}
@@ -95,19 +127,14 @@ export default function SidebarTab({
rightChildren={truncationSpacer}
/>
) : (
<div className="flex flex-row items-center gap-2 flex-1">
<div className="flex flex-row items-center gap-2 w-full">
{Icon && (
<div className="flex items-center justify-center p-0.5">
<Icon className="h-[1rem] w-[1rem] text-text-03" />
</div>
)}
{children}
{
// NOTE (@raunakab)
//
// Adding the `truncationSpacer` here for the same reason as above.
truncationSpacer
}
{truncationSpacer}
</div>
)}
</Interactive.Container>
@@ -116,7 +143,23 @@ export default function SidebarTab({
);
if (typeof children !== "string") return content;
if (folded)
return <SimpleTooltip tooltip={children}>{content}</SimpleTooltip>;
if (folded) {
return (
<TooltipPrimitive.Root>
<TooltipPrimitive.Trigger asChild>{content}</TooltipPrimitive.Trigger>
<TooltipPrimitive.Portal>
<TooltipPrimitive.Content
className="opal-tooltip"
side="right"
sideOffset={4}
>
{children}
</TooltipPrimitive.Content>
</TooltipPrimitive.Portal>
</TooltipPrimitive.Root>
);
}
return content;
}
export { SidebarTab, type SidebarTabProps };

View File

@@ -33,6 +33,12 @@ export {
type LineItemButtonProps,
} from "@opal/components/buttons/line-item-button/components";
/* SidebarTab */
export {
SidebarTab,
type SidebarTabProps,
} from "@opal/components/buttons/sidebar-tab/components";
/* Text */
export {
Text,

View File

@@ -586,7 +586,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
// Data / Display cell
return (
<TableCell key={cell.id}>
<TableCell
key={cell.id}
data-column-id={cell.column.id}
>
{flexRender(
cell.column.columnDef.cell,
cell.getContext()

View File

@@ -1,33 +1,7 @@
"use client";
import "@opal/core/disabled/styles.css";
import React, { createContext, useContext } from "react";
import React from "react";
import { Slot } from "@radix-ui/react-slot";
// ---------------------------------------------------------------------------
// Context
// ---------------------------------------------------------------------------
interface DisabledContextValue {
isDisabled: boolean;
allowClick: boolean;
}
const DisabledContext = createContext<DisabledContextValue>({
isDisabled: false,
allowClick: false,
});
/**
* Returns the current disabled state from the nearest `<Disabled>` ancestor.
*
* Used internally by `Interactive.Stateless` and `Interactive.Stateful` to
* derive `data-disabled` and `aria-disabled` attributes automatically.
*/
function useDisabled(): DisabledContextValue {
return useContext(DisabledContext);
}
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
@@ -56,8 +30,8 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
// ---------------------------------------------------------------------------
/**
* Wrapper component that propagates disabled state via context and applies
* baseline disabled CSS (opacity, cursor, pointer-events) to its child.
* Wrapper component that applies baseline disabled CSS (opacity, cursor,
* pointer-events) to its child element.
*
* Uses Radix `Slot` — merges props onto the single child element without
* adding any DOM node. Works correctly inside Radix `asChild` chains.
@@ -65,7 +39,7 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
* @example
* ```tsx
* <Disabled disabled={!canSubmit}>
* <Button onClick={handleSubmit}>Save</Button>
* <div>...</div>
* </Disabled>
* ```
*/
@@ -77,20 +51,16 @@ function Disabled({
...rest
}: DisabledProps) {
return (
<DisabledContext.Provider
value={{ isDisabled: !!disabled, allowClick: !!allowClick }}
<Slot
ref={ref}
{...rest}
aria-disabled={disabled || undefined}
data-opal-disabled={disabled || undefined}
data-allow-click={disabled && allowClick ? "" : undefined}
>
<Slot
ref={ref}
{...rest}
aria-disabled={disabled || undefined}
data-opal-disabled={disabled || undefined}
data-allow-click={disabled && allowClick ? "" : undefined}
>
{children}
</Slot>
</DisabledContext.Provider>
{children}
</Slot>
);
}
export { Disabled, useDisabled, type DisabledProps, type DisabledContextValue };
export { Disabled, type DisabledProps };

View File

@@ -1,10 +1,5 @@
/* Disabled */
export {
Disabled,
useDisabled,
type DisabledProps,
type DisabledContextValue,
} from "@opal/core/disabled/components";
export { Disabled, type DisabledProps } from "@opal/core/disabled/components";
/* Animations (formerly Hoverable) */
export {

View File

@@ -10,7 +10,6 @@ import {
widthVariants,
type ExtremaSizeVariants,
} from "@opal/shared";
import { useDisabled } from "@opal/core/disabled/components";
// ---------------------------------------------------------------------------
// Types
@@ -102,7 +101,6 @@ function InteractiveContainer({
widthVariant = "fit",
...props
}: InteractiveContainerProps) {
const { allowClick } = useDisabled();
const {
className: slotClassName,
style: slotStyle,
@@ -148,8 +146,7 @@ function InteractiveContainer({
if (type) {
const ariaDisabled = (rest as Record<string, unknown>)["aria-disabled"];
const nativeDisabled =
(type === "submit" || !allowClick) &&
(ariaDisabled === true || ariaDisabled === "true" || undefined);
ariaDisabled === true || ariaDisabled === "true" || undefined;
return (
<button
ref={ref as React.Ref<HTMLButtonElement>}

View File

@@ -1,7 +1,6 @@
import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
// ---------------------------------------------------------------------------
@@ -29,6 +28,11 @@ interface InteractiveSimpleProps
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
*/
target?: string;
/**
* Applies disabled cursor and suppresses clicks.
*/
disabled?: boolean;
}
// ---------------------------------------------------------------------------
@@ -38,8 +42,8 @@ interface InteractiveSimpleProps
/**
* Minimal interactive surface primitive.
*
* Provides cursor styling, click handling, disabled integration, and
* optional link/group support — but **no color or background styling**.
* Provides cursor styling, click handling, and optional link/group
* support — but **no color or background styling**.
*
* Use this for elements that need interactivity (click, cursor, disabled)
* without participating in the Interactive color system.
@@ -59,9 +63,10 @@ function InteractiveSimple({
group,
href,
target,
disabled,
...props
}: InteractiveSimpleProps) {
const { isDisabled, allowClick } = useDisabled();
const isDisabled = !!disabled;
const classes = cn(
"cursor-pointer select-none",
@@ -88,7 +93,7 @@ function InteractiveSimple({
{...linkAttrs}
{...slotProps}
onClick={
isDisabled && !allowClick
isDisabled
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined

View File

@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateful/styles.css";
import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
import type { ButtonType, WithoutStyles } from "@opal/types";
@@ -87,6 +86,11 @@ interface InteractiveStatefulProps
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
*/
target?: string;
/**
* Applies variant-specific disabled colors and suppresses clicks.
*/
disabled?: boolean;
}
// ---------------------------------------------------------------------------
@@ -100,8 +104,7 @@ interface InteractiveStatefulProps
* (empty/filled/selected). Applies variant/state color styling via CSS
* data-attributes and merges onto a single child element via Radix `Slot`.
*
* Disabled state is consumed from the nearest `<Disabled>` ancestor via
* context — there is no `disabled` prop on this component.
* Disabled state is controlled via the `disabled` prop.
*/
function InteractiveStateful({
ref,
@@ -112,9 +115,10 @@ function InteractiveStateful({
type,
href,
target,
disabled,
...props
}: InteractiveStatefulProps) {
const { isDisabled, allowClick } = useDisabled();
const isDisabled = !!disabled;
// onClick/href are always passed directly — Stateful is the outermost Slot,
// so Radix Slot-injected handlers don't bypass this guard.
@@ -150,7 +154,7 @@ function InteractiveStateful({
{...linkAttrs}
{...slotProps}
onClick={
isDisabled && !allowClick
isDisabled
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined

View File

@@ -550,6 +550,14 @@
) {
@apply bg-background-tint-03;
}
/* ---------------------------------------------------------------------------
Sidebar-Heavy — Disabled (all states)
--------------------------------------------------------------------------- */
.interactive[data-interactive-variant="sidebar-heavy"][data-disabled] {
@apply bg-transparent opacity-50;
--interactive-foreground: var(--text-03);
--interactive-foreground-icon: var(--text-03);
}
/* ===========================================================================
Sidebar-Light
@@ -607,3 +615,11 @@
) {
@apply bg-background-tint-03;
}
/* ---------------------------------------------------------------------------
Sidebar-Light — Disabled (all states)
--------------------------------------------------------------------------- */
.interactive[data-interactive-variant="sidebar-light"][data-disabled] {
@apply bg-transparent opacity-50;
--interactive-foreground: var(--text-03);
--interactive-foreground-icon: var(--text-03);
}

View File

@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateless/styles.css";
import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
import type { ButtonType, WithoutStyles } from "@opal/types";
@@ -70,6 +69,11 @@ interface InteractiveStatelessProps
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
*/
target?: string;
/**
* Applies variant-specific disabled colors and suppresses clicks.
*/
disabled?: boolean;
}
// ---------------------------------------------------------------------------
@@ -84,8 +88,7 @@ interface InteractiveStatelessProps
* color styling via CSS data-attributes and merges onto a single child
* element via Radix `Slot`.
*
* Disabled state is consumed from the nearest `<Disabled>` ancestor via
* context — there is no `disabled` prop on this component.
* Disabled state is controlled via the `disabled` prop.
*/
function InteractiveStateless({
ref,
@@ -96,9 +99,10 @@ function InteractiveStateless({
type,
href,
target,
disabled,
...props
}: InteractiveStatelessProps) {
const { isDisabled, allowClick } = useDisabled();
const isDisabled = !!disabled;
// onClick/href are always passed directly — Stateless is the outermost Slot,
// so Radix Slot-injected handlers don't bypass this guard.
@@ -134,7 +138,7 @@ function InteractiveStateless({
{...linkAttrs}
{...slotProps}
onClick={
isDisabled && !allowClick
isDisabled
? href
? (e: React.MouseEvent) => e.preventDefault()
: undefined

View File

@@ -8,7 +8,6 @@ import * as InputLayouts from "@/layouts/input-layouts";
import Card from "@/refresh-components/cards/Card";
import Button from "@/refresh-components/buttons/Button";
import { Button as OpalButton } from "@opal/components";
import { Disabled } from "@opal/core";
import Text from "@/refresh-components/texts/Text";
import Message from "@/refresh-components/messages/Message";
import InfoBlock from "@/refresh-components/messages/InfoBlock";
@@ -246,15 +245,14 @@ function SubscriptionCard({
to make changes.
</Text>
) : disabled ? (
<Disabled disabled={isReconnecting}>
<OpalButton
prominence="secondary"
onClick={handleReconnect}
rightIcon={SvgArrowRight}
>
{isReconnecting ? "Connecting..." : "Connect to Stripe"}
</OpalButton>
</Disabled>
<OpalButton
disabled={isReconnecting}
prominence="secondary"
onClick={handleReconnect}
rightIcon={SvgArrowRight}
>
{isReconnecting ? "Connecting..." : "Connect to Stripe"}
</OpalButton>
) : (
<OpalButton onClick={handleManagePlan} rightIcon={SvgExternalLink}>
Manage Plan
@@ -377,11 +375,13 @@ function SeatsCard({
sizePreset="main-content"
variant="section"
/>
<Disabled disabled={isSubmitting}>
<OpalButton prominence="secondary" onClick={handleCancel}>
Cancel
</OpalButton>
</Disabled>
<OpalButton
disabled={isSubmitting}
prominence="secondary"
onClick={handleCancel}
>
Cancel
</OpalButton>
</Section>
<div className="billing-content-area">
@@ -463,15 +463,14 @@ function SeatsCard({
No changes to your billing.
</Text>
)}
<Disabled
<OpalButton
disabled={
isSubmitting || newSeatCount === totalSeats || isBelowMinimum
}
onClick={handleConfirm}
>
<OpalButton onClick={handleConfirm}>
{isSubmitting ? "Saving..." : "Confirm Change"}
</OpalButton>
</Disabled>
{isSubmitting ? "Saving..." : "Confirm Change"}
</OpalButton>
</Section>
</Card>
);
@@ -509,15 +508,14 @@ function SeatsCard({
View Users
</OpalButton>
{!hideUpdateSeats && (
<Disabled disabled={isLoadingUsers || disabled || !billing}>
<OpalButton
prominence="secondary"
onClick={handleStartEdit}
icon={SvgPlus}
>
Update Seats
</OpalButton>
</Disabled>
<OpalButton
disabled={isLoadingUsers || disabled || !billing}
prominence="secondary"
onClick={handleStartEdit}
icon={SvgPlus}
>
Update Seats
</OpalButton>
)}
</Section>
</Section>

Some files were not shown because too many files have changed in this diff Show More