mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-04-04 06:22:44 +00:00)

Compare commits: multi-mode...jamison/wo (23 commits)

b59f8cf453, 456ecc7b9a, fdc2bc9ee2, 1c3f371549, a120add37b, 757e4e979b, cbcdfee56e, b06700314b, 01f573cdcb, d4a96d70f3, 5b000c2173, d62af28e40, 593678a14f, e6f7c2b45c, f77128d929, 1d4ca769e7, e002f6c195, 10d696262f, 608e151443, 41d1a33093, f396ebbdbb, 67c8df002e, 722f7de335
@@ -1,186 +0,0 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
## Commands

### Validate configuration

```bash
onyx-cli validate-config
```

Checks that the config file exists and an API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.

### List available agents

```bash
onyx-cli agents
```

Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:

```bash
onyx-cli agents --json
```

Use agent IDs with `ask --agent-id` to query a specific agent.

### Basic query (plain text output)

```bash
onyx-cli ask "What is our company's PTO policy?"
```

Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.

### JSON output (structured events)

```bash
onyx-cli ask --json "What authentication methods do we support?"
```

Outputs JSON-encoded parsed stream events, one object per line. Key event objects include message deltas, stop, errors, search-start, and citation payloads.

Each line is a JSON object with this envelope:

```json
{"type": "<event_type>", "event": { ... }}
```

| Event Type | Description |
|------------|-------------|
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |

`citation_info` event shape:

```json
{
  "type": "citation_info",
  "event": {
    "citation_number": 1,
    "document_id": "abc123def456",
    "placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
  }
}
```

`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
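One way to consume the `--json` stream is to fold the NDJSON lines into a final answer plus a citation list. The sketch below is based only on the event shapes documented above; in particular, the assumption that a `message_delta` carries its token in `event.content` is inferred from the table, not verified against the CLI.

```python
import json


def fold_events(lines):
    """Assemble the plain-text answer and citations from NDJSON event lines.

    Assumes the {"type": ..., "event": {...}} envelope documented above,
    and that message_delta tokens live at event["content"] (an assumption).
    """
    answer_parts: list[str] = []
    citations: list[dict] = []
    for line in lines:
        obj = json.loads(line)
        etype, event = obj["type"], obj.get("event") or {}
        if etype == "message_delta":
            # Concatenate all content tokens for the full answer
            answer_parts.append(event.get("content", ""))
        elif etype == "citation_info":
            citations.append(event)
        elif etype == "error":
            raise RuntimeError(event.get("error", "unknown onyx-cli error"))
        elif etype == "stop":
            break
        # search_tool_start and other events are progress-only; ignore them
    return "".join(answer_parts), citations
```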
### Specify an agent

```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```

Uses a specific Onyx agent/persona instead of the default.

### All flags

| Flag | Type | Description |
|------|------|-------------|
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |

## Statelessness

Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need a multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
## When to Use

Use `onyx-cli ask` when:

- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources

Do NOT use when:

- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data

## Examples

```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"

# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"

# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"

# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```
1 .cursor/skills/onyx-cli/SKILL.md (symbolic link)

@@ -0,0 +1 @@
../../../cli/internal/embedded/SKILL.md
108 backend/alembic/versions/03d085c5c38d_backfill_account_type.py (new file)

@@ -0,0 +1,108 @@

```python
"""backfill_account_type

Revision ID: 03d085c5c38d
Revises: 977e834c1427
Create Date: 2026-03-25 16:00:00.000000

"""

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "03d085c5c38d"
down_revision = "977e834c1427"
branch_labels = None
depends_on = None


_STANDARD = "STANDARD"
_BOT = "BOT"
_EXT_PERM_USER = "EXT_PERM_USER"
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
_ANONYMOUS = "ANONYMOUS"

# Well-known anonymous user UUID
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"

# Email pattern for API key virtual users
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"

# Reflect the table structure for use in DML
user_table = sa.table(
    "user",
    sa.column("id", sa.Uuid),
    sa.column("email", sa.String),
    sa.column("role", sa.String),
    sa.column("account_type", sa.String),
)


def upgrade() -> None:
    # ------------------------------------------------------------------
    # Step 1: Backfill account_type from role.
    # Order matters — most-specific matches first so the final catch-all
    # only touches rows that haven't been classified yet.
    # ------------------------------------------------------------------

    # 1a. API key virtual users → SERVICE_ACCOUNT
    op.execute(
        sa.update(user_table)
        .where(
            user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
            user_table.c.account_type.is_(None),
        )
        .values(account_type=_SERVICE_ACCOUNT)
    )

    # 1b. Anonymous user → ANONYMOUS
    op.execute(
        sa.update(user_table)
        .where(
            user_table.c.id == ANONYMOUS_USER_ID,
            user_table.c.account_type.is_(None),
        )
        .values(account_type=_ANONYMOUS)
    )

    # 1c. SLACK_USER role → BOT
    op.execute(
        sa.update(user_table)
        .where(
            user_table.c.role == "SLACK_USER",
            user_table.c.account_type.is_(None),
        )
        .values(account_type=_BOT)
    )

    # 1d. EXT_PERM_USER role → EXT_PERM_USER
    op.execute(
        sa.update(user_table)
        .where(
            user_table.c.role == "EXT_PERM_USER",
            user_table.c.account_type.is_(None),
        )
        .values(account_type=_EXT_PERM_USER)
    )

    # 1e. Everything else → STANDARD
    op.execute(
        sa.update(user_table)
        .where(user_table.c.account_type.is_(None))
        .values(account_type=_STANDARD)
    )

    # ------------------------------------------------------------------
    # Step 2: Set account_type to NOT NULL now that every row is filled.
    # ------------------------------------------------------------------
    op.alter_column(
        "user",
        "account_type",
        nullable=False,
        server_default="STANDARD",
    )


def downgrade() -> None:
    op.alter_column("user", "account_type", nullable=True, server_default=None)
    op.execute(sa.update(user_table).values(account_type=None))
```
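The ordered backfill above amounts to a first-match classification. A pure-Python restatement of the same rules for illustration (the constant names mirror the migration; this function is not code from the repo):

```python
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"


def classify_account_type(user_id: str, email: str, role: str) -> str:
    """First matching rule wins, mirroring steps 1a-1e of the migration above."""
    if email.upper().startswith("API_KEY__"):  # 1a. API key virtual users (ILIKE is case-insensitive)
        return "SERVICE_ACCOUNT"
    if user_id == ANONYMOUS_USER_ID:  # 1b. well-known anonymous user
        return "ANONYMOUS"
    if role == "SLACK_USER":  # 1c.
        return "BOT"
    if role == "EXT_PERM_USER":  # 1d.
        return "EXT_PERM_USER"
    return "STANDARD"  # 1e. catch-all
```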
@@ -0,0 +1,104 @@

```python
"""add_effective_permissions

Adds a JSONB column `effective_permissions` to the user table to store
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
permissions are expanded at read time, not stored.

Backfill: joins user__user_group → permission_grant to collect each
user's granted permissions into a JSON array. Users without group
memberships keep the default [].

Revision ID: 503883791c39
Revises: b4b7e1028dfd
Create Date: 2026-03-30 14:49:22.261748

"""

from collections.abc import Sequence

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql


# revision identifiers, used by Alembic.
revision = "503883791c39"
down_revision = "b4b7e1028dfd"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None


user_table = sa.table(
    "user",
    sa.column("id", sa.Uuid),
    sa.column("effective_permissions", postgresql.JSONB),
)

user_user_group = sa.table(
    "user__user_group",
    sa.column("user_id", sa.Uuid),
    sa.column("user_group_id", sa.Integer),
)

permission_grant = sa.table(
    "permission_grant",
    sa.column("group_id", sa.Integer),
    sa.column("permission", sa.String),
    sa.column("is_deleted", sa.Boolean),
)


def upgrade() -> None:
    op.add_column(
        "user",
        sa.Column(
            "effective_permissions",
            postgresql.JSONB(),
            nullable=False,
            server_default=sa.text("'[]'::jsonb"),
        ),
    )

    conn = op.get_bind()

    # Deduplicated permissions per user
    deduped = (
        sa.select(
            user_user_group.c.user_id,
            permission_grant.c.permission,
        )
        .select_from(
            user_user_group.join(
                permission_grant,
                sa.and_(
                    permission_grant.c.group_id == user_user_group.c.user_group_id,
                    permission_grant.c.is_deleted == sa.false(),
                ),
            )
        )
        .distinct()
        .subquery("deduped")
    )

    # Aggregate into JSONB array per user (order is not guaranteed;
    # consumers read this as a set so ordering does not matter)
    perms_per_user = (
        sa.select(
            deduped.c.user_id,
            sa.func.jsonb_agg(
                deduped.c.permission,
                type_=postgresql.JSONB,
            ).label("perms"),
        )
        .group_by(deduped.c.user_id)
        .subquery("sub")
    )

    conn.execute(
        user_table.update()
        .where(user_table.c.id == perms_per_user.c.user_id)
        .values(effective_permissions=perms_per_user.c.perms)
    )


def downgrade() -> None:
    op.drop_column("user", "effective_permissions")
```
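The join-dedupe-aggregate in the backfill above is equivalent to collecting each user's live (non-deleted) grants into a set and emitting it as a JSON array. A small in-memory restatement, illustrative only:

```python
from collections import defaultdict


def backfill_effective_permissions(memberships, grants):
    """memberships: (user_id, group_id) pairs; grants: (group_id, permission, is_deleted).

    Returns user_id -> deduplicated permission list. The migration's jsonb_agg
    output order is unspecified; consumers read it as a set, so we sort here
    only to make results deterministic.
    """
    live_by_group = defaultdict(set)
    for group_id, permission, is_deleted in grants:
        if not is_deleted:  # mirrors permission_grant.is_deleted == false
            live_by_group[group_id].add(permission)
    perms = defaultdict(set)
    for user_id, group_id in memberships:
        perms[user_id] |= live_by_group[group_id]
    # Users with no memberships are absent here; in the DB they keep the
    # column's server default of []
    return {uid: sorted(ps) for uid, ps in perms.items()}
```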
139 backend/alembic/versions/977e834c1427_seed_default_groups.py (new file)

@@ -0,0 +1,139 @@

```python
"""seed_default_groups

Revision ID: 977e834c1427
Revises: 8188861f4e92
Create Date: 2026-03-25 14:59:41.313091

"""

from typing import Any

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert


# revision identifiers, used by Alembic.
revision = "977e834c1427"
down_revision = "8188861f4e92"
branch_labels = None
depends_on = None

# (group_name, permission_value)
DEFAULT_GROUPS = [
    ("Admin", "admin"),
    ("Basic", "basic"),
]

CUSTOM_SUFFIX = "(Custom)"

MAX_RENAME_ATTEMPTS = 100

# Reflect table structures for use in DML
user_group_table = sa.table(
    "user_group",
    sa.column("id", sa.Integer),
    sa.column("name", sa.String),
    sa.column("is_up_to_date", sa.Boolean),
    sa.column("is_up_for_deletion", sa.Boolean),
    sa.column("is_default", sa.Boolean),
)

permission_grant_table = sa.table(
    "permission_grant",
    sa.column("group_id", sa.Integer),
    sa.column("permission", sa.String),
    sa.column("grant_source", sa.String),
)

user__user_group_table = sa.table(
    "user__user_group",
    sa.column("user_group_id", sa.Integer),
    sa.column("user_id", sa.Uuid),
)


def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
    """Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
    candidate = f"{base} {CUSTOM_SUFFIX}"
    attempt = 1
    while attempt <= MAX_RENAME_ATTEMPTS:
        exists: Any = conn.execute(
            sa.select(sa.literal(1))
            .select_from(user_group_table)
            .where(user_group_table.c.name == candidate)
            .limit(1)
        ).fetchone()
        if exists is None:
            return candidate
        attempt += 1
        candidate = f"{base} (Custom {attempt})"
    raise RuntimeError(
        f"Could not find an available name for group '{base}' "
        f"after {MAX_RENAME_ATTEMPTS} attempts"
    )


def upgrade() -> None:
    conn = op.get_bind()

    for group_name, permission_value in DEFAULT_GROUPS:
        # Step 1: Rename ALL existing groups that clash with the canonical name.
        conflicting = conn.execute(
            sa.select(user_group_table.c.id, user_group_table.c.name).where(
                user_group_table.c.name == group_name
            )
        ).fetchall()

        for row_id, row_name in conflicting:
            new_name = _find_available_name(conn, row_name)
            op.execute(
                sa.update(user_group_table)
                .where(user_group_table.c.id == row_id)
                .values(name=new_name, is_up_to_date=False)
            )

        # Step 2: Create a fresh default group.
        result = conn.execute(
            user_group_table.insert()
            .values(
                name=group_name,
                is_up_to_date=True,
                is_up_for_deletion=False,
                is_default=True,
            )
            .returning(user_group_table.c.id)
        ).fetchone()
        assert result is not None
        group_id = result[0]

        # Step 3: Upsert permission grant.
        op.execute(
            pg_insert(permission_grant_table)
            .values(
                group_id=group_id,
                permission=permission_value,
                grant_source="SYSTEM",
            )
            .on_conflict_do_nothing(index_elements=["group_id", "permission"])
        )


def downgrade() -> None:
    # Remove the default groups created by this migration.
    # First remove user-group memberships that reference default groups
    # to avoid FK violations, then delete the groups themselves.
    default_group_ids = sa.select(user_group_table.c.id).where(
        user_group_table.c.is_default == True  # noqa: E712
    )
    conn = op.get_bind()
    conn.execute(
        sa.delete(user__user_group_table).where(
            user__user_group_table.c.user_group_id.in_(default_group_ids)
        )
    )
    conn.execute(
        sa.delete(user_group_table).where(
            user_group_table.c.is_default == True  # noqa: E712
        )
    )
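The rename loop in `_find_available_name` tries 'Admin (Custom)', then 'Admin (Custom 2)', 'Admin (Custom 3)', and so on. A side-effect-free sketch of the same naming scheme against an in-memory set of taken names (illustration, not repo code):

```python
def find_available_name(taken: set[str], base: str, max_attempts: int = 100) -> str:
    """Mirror the migration's candidate sequence: '<base> (Custom)', then '<base> (Custom N)'."""
    candidate = f"{base} (Custom)"
    attempt = 1
    while attempt <= max_attempts:
        if candidate not in taken:
            return candidate
        attempt += 1
        candidate = f"{base} (Custom {attempt})"
    raise RuntimeError(f"Could not find an available name for group '{base}'")
```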
@@ -0,0 +1,84 @@

```python
"""grant_basic_to_existing_groups

Grants the "basic" permission to all existing groups that don't already
have it. Every group should have at least "basic" so that its members
get basic access when effective_permissions is backfilled.

Revision ID: b4b7e1028dfd
Revises: b7bcc991d722
Create Date: 2026-03-30 16:15:17.093498

"""

from collections.abc import Sequence

from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "b4b7e1028dfd"
down_revision = "b7bcc991d722"
branch_labels: str | None = None
depends_on: str | Sequence[str] | None = None


user_group = sa.table(
    "user_group",
    sa.column("id", sa.Integer),
    sa.column("is_default", sa.Boolean),
)

permission_grant = sa.table(
    "permission_grant",
    sa.column("group_id", sa.Integer),
    sa.column("permission", sa.String),
    sa.column("grant_source", sa.String),
    sa.column("is_deleted", sa.Boolean),
)


def upgrade() -> None:
    conn = op.get_bind()

    already_has_basic = (
        sa.select(sa.literal(1))
        .select_from(permission_grant)
        .where(
            permission_grant.c.group_id == user_group.c.id,
            permission_grant.c.permission == "basic",
        )
        .exists()
    )

    groups_needing_basic = sa.select(
        user_group.c.id,
        sa.literal("basic").label("permission"),
        sa.literal("SYSTEM").label("grant_source"),
        sa.literal(False).label("is_deleted"),
    ).where(
        user_group.c.is_default == sa.false(),
        ~already_has_basic,
    )

    conn.execute(
        permission_grant.insert().from_select(
            ["group_id", "permission", "grant_source", "is_deleted"],
            groups_needing_basic,
        )
    )


def downgrade() -> None:
    conn = op.get_bind()

    non_default_group_ids = sa.select(user_group.c.id).where(
        user_group.c.is_default == sa.false()
    )

    conn.execute(
        permission_grant.delete().where(
            permission_grant.c.permission == "basic",
            permission_grant.c.grant_source == "SYSTEM",
            permission_grant.c.group_id.in_(non_default_group_ids),
        )
    )
```
@@ -0,0 +1,125 @@

```python
"""assign_users_to_default_groups

Revision ID: b7bcc991d722
Revises: 03d085c5c38d
Create Date: 2026-03-25 16:30:39.529301

"""

from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects.postgresql import insert as pg_insert


# revision identifiers, used by Alembic.
revision = "b7bcc991d722"
down_revision = "03d085c5c38d"
branch_labels = None
depends_on = None

# The no-auth placeholder user must NOT be assigned to default groups.
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
# user when the first real user registers; group membership rows would cause
# an FK violation on that DELETE.
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"

# Reflect table structures for use in DML
user_group_table = sa.table(
    "user_group",
    sa.column("id", sa.Integer),
    sa.column("name", sa.String),
    sa.column("is_default", sa.Boolean),
)

user_table = sa.table(
    "user",
    sa.column("id", sa.Uuid),
    sa.column("role", sa.String),
    sa.column("account_type", sa.String),
    sa.column("is_active", sa.Boolean),
)

user__user_group_table = sa.table(
    "user__user_group",
    sa.column("user_group_id", sa.Integer),
    sa.column("user_id", sa.Uuid),
)


def upgrade() -> None:
    conn = op.get_bind()

    # Look up default group IDs
    admin_row = conn.execute(
        sa.select(user_group_table.c.id).where(
            user_group_table.c.name == "Admin",
            user_group_table.c.is_default == True,  # noqa: E712
        )
    ).fetchone()

    basic_row = conn.execute(
        sa.select(user_group_table.c.id).where(
            user_group_table.c.name == "Basic",
            user_group_table.c.is_default == True,  # noqa: E712
        )
    ).fetchone()

    if admin_row is None:
        raise RuntimeError(
            "Default 'Admin' group not found. "
            "Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
        )

    if basic_row is None:
        raise RuntimeError(
            "Default 'Basic' group not found. "
            "Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
        )

    # Users with role=admin → Admin group
    # Include inactive users so reactivation doesn't require reconciliation.
    # Exclude non-human account types (mirrors assign_user_to_default_groups logic).
    admin_users = sa.select(
        sa.literal(admin_row[0]).label("user_group_id"),
        user_table.c.id.label("user_id"),
    ).where(
        user_table.c.role == "ADMIN",
        user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
        user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
    )
    op.execute(
        pg_insert(user__user_group_table)
        .from_select(["user_group_id", "user_id"], admin_users)
        .on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
    )

    # STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
    # Include inactive users so reactivation doesn't require reconciliation.
    basic_users = sa.select(
        sa.literal(basic_row[0]).label("user_group_id"),
        user_table.c.id.label("user_id"),
    ).where(
        user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
        user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
        sa.or_(
            sa.and_(
                user_table.c.account_type == "STANDARD",
                user_table.c.role != "ADMIN",
            ),
            sa.and_(
                user_table.c.account_type == "SERVICE_ACCOUNT",
                user_table.c.role == "BASIC",
            ),
        ),
    )
    op.execute(
        pg_insert(user__user_group_table)
        .from_select(["user_group_id", "user_id"], basic_users)
        .on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
    )


def downgrade() -> None:
    # Group memberships are left in place — removing them risks
    # deleting memberships that existed before this migration.
    pass
```
@@ -1,20 +1,14 @@

```python
from datetime import datetime
from datetime import timezone
from uuid import UUID

from celery import shared_task
from celery import Task

from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
from ee.onyx.background.task_name_builders import name_chat_ttl_task
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import TaskStatus
from onyx.db.tasks import mark_task_as_finished_with_id
from onyx.db.tasks import register_task
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
```

@@ -29,59 +23,42 @@ logger = setup_logger()

```diff
     trail=False,
 )
 def perform_ttl_management_task(
-    self: Task, retention_limit_days: int, *, tenant_id: str
+    self: Task, retention_limit_days: int, *, tenant_id: str  # noqa: ARG001
 ) -> None:
     task_id = self.request.id
     if not task_id:
         raise RuntimeError("No task id defined for this task; cannot identify it")

     start_time = datetime.now(tz=timezone.utc)

     user_id: UUID | None = None
     session_id: UUID | None = None
     try:
         with get_session_with_current_tenant() as db_session:
             # we generally want to move off this, but keeping for now
             register_task(
                 db_session=db_session,
                 task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
                 task_id=task_id,
                 status=TaskStatus.STARTED,
                 start_time=start_time,
             )

             old_chat_sessions = get_chat_sessions_older_than(
                 retention_limit_days, db_session
             )

         for user_id, session_id in old_chat_sessions:
             # one session per delete so that we don't blow up if a deletion fails.
-            with get_session_with_current_tenant() as db_session:
-                delete_chat_session(
-                    user_id,
-                    session_id,
-                    db_session,
-                    include_deleted=True,
-                    hard_delete=True,
+            try:
+                with get_session_with_current_tenant() as db_session:
+                    delete_chat_session(
+                        user_id,
+                        session_id,
+                        db_session,
+                        include_deleted=True,
+                        hard_delete=True,
+                    )
+            except Exception:
+                logger.exception(
+                    "Failed to delete chat session "
+                    f"user_id={user_id} session_id={session_id}, "
+                    "continuing with remaining sessions"
                 )

         with get_session_with_current_tenant() as db_session:
             mark_task_as_finished_with_id(
                 db_session=db_session,
                 task_id=task_id,
                 success=True,
             )

     except Exception:
         logger.exception(
             f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
         )
         with get_session_with_current_tenant() as db_session:
             mark_task_as_finished_with_id(
                 db_session=db_session,
                 task_id=task_id,
                 success=False,
             )
         raise
```
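The TTL-task change above wraps each per-session delete in its own try/except, so one failing deletion no longer aborts the whole sweep. The pattern in isolation, as a generic sketch (not the repo's code):

```python
from collections.abc import Callable, Iterable


def delete_all_isolated(items: Iterable, delete_one: Callable, log=print) -> list:
    """Apply delete_one to each item; log and continue on failure instead of aborting."""
    failed = []
    for item in items:
        try:
            delete_one(item)
        except Exception:
            # Mirror the task's behavior: record the failure, keep going
            log(f"Failed to delete {item!r}, continuing with remaining items")
            failed.append(item)
    return failed
```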
@@ -36,13 +36,16 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -280,7 +283,9 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -521,6 +526,22 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.enums import AccessType
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import Credential
from onyx.db.models import Credential__UserGroup
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__UserGroup
from onyx.db.models import FederatedConnector__DocumentSet
from onyx.db.models import LLMProvider__UserGroup
from onyx.db.models import PermissionGrant
from onyx.db.models import Persona
from onyx.db.models import Persona__UserGroup
from onyx.db.models import TokenRateLimit__UserGroup
@@ -36,6 +39,7 @@ from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.models import UserGroup__ConnectorCredentialPair
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import fetch_user_by_id
from onyx.utils.logger import setup_logger

@@ -255,6 +259,7 @@ def fetch_user_groups(
    db_session: Session,
    only_up_to_date: bool = True,
    eager_load_for_snapshot: bool = False,
    include_default: bool = True,
) -> Sequence[UserGroup]:
    """
    Fetches user groups from the database.
@@ -269,6 +274,7 @@ def fetch_user_groups(
            to include only up to date user groups. Defaults to `True`.
        eager_load_for_snapshot: If True, adds eager loading for all relationships
            needed by UserGroup.from_model snapshot creation.
        include_default: If False, excludes system default groups (is_default=True).

    Returns:
        Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
@@ -276,6 +282,8 @@ def fetch_user_groups(
    stmt = select(UserGroup)
    if only_up_to_date:
        stmt = stmt.where(UserGroup.is_up_to_date == True)  # noqa: E712
    if not include_default:
        stmt = stmt.where(UserGroup.is_default == False)  # noqa: E712
    if eager_load_for_snapshot:
        stmt = _add_user_group_snapshot_eager_loads(stmt)
    return db_session.scalars(stmt).unique().all()
@@ -286,6 +294,7 @@ def fetch_user_groups_for_user(
    user_id: UUID,
    only_curator_groups: bool = False,
    eager_load_for_snapshot: bool = False,
    include_default: bool = True,
) -> Sequence[UserGroup]:
    stmt = (
        select(UserGroup)
@@ -295,6 +304,8 @@ def fetch_user_groups_for_user(
    )
    if only_curator_groups:
        stmt = stmt.where(User__UserGroup.is_curator == True)  # noqa: E712
    if not include_default:
        stmt = stmt.where(UserGroup.is_default == False)  # noqa: E712
    if eager_load_for_snapshot:
        stmt = _add_user_group_snapshot_eager_loads(stmt)
    return db_session.scalars(stmt).unique().all()
@@ -478,6 +489,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
    db_session.add(db_user_group)
    db_session.flush()  # give the group an ID

    # Every group gets the "basic" permission by default
    db_session.add(
        PermissionGrant(
            group_id=db_user_group.id,
            permission=Permission.BASIC_ACCESS,
            grant_source=GrantSource.SYSTEM,
        )
    )
    db_session.flush()

    _add_user__user_group_relationships__no_commit(
        db_session=db_session,
        user_group_id=db_user_group.id,
@@ -489,6 +510,8 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
        cc_pair_ids=user_group.cc_pair_ids,
    )

    recompute_user_permissions__no_commit(user_group.user_ids, db_session)

    db_session.commit()
    return db_user_group

@@ -796,6 +819,10 @@ def update_user_group(
    # update "time_updated" to now
    db_user_group.time_last_modified_by_user = func.now()

    recompute_user_permissions__no_commit(
        list(set(added_user_ids) | set(removed_user_ids)), db_session
    )

    db_session.commit()
    return db_user_group

@@ -835,6 +862,19 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->

    _check_user_group_is_modifiable(db_user_group)

    # Collect affected user IDs before cleanup deletes the relationships
    affected_user_ids: list[UUID] = [
        uid
        for uid in db_session.execute(
            select(User__UserGroup.user_id).where(
                User__UserGroup.user_group_id == user_group_id
            )
        )
        .scalars()
        .all()
        if uid is not None
    ]

    _mark_user_group__cc_pair_relationships_outdated__no_commit(
        db_session=db_session, user_group_id=user_group_id
    )
@@ -863,6 +903,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
        db_session=db_session, user_group_id=user_group_id
    )

    # Recompute permissions for affected users now that their
    # membership in this group has been removed
    recompute_user_permissions__no_commit(affected_user_ids, db_session)

    db_user_group.is_up_to_date = False
    db_user_group.is_up_for_deletion = True
    db_session.commit()

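The deletion path in `prepare_user_group_for_deletion` captures group membership before the relationship rows are removed, then recomputes permissions for exactly those users. A minimal in-memory sketch of that ordering (toy data structures, not the real SQLAlchemy models):

```python
# Toy illustration of the ordering used in prepare_user_group_for_deletion:
# capture affected user IDs first, delete the membership rows, then
# recompute permissions only for the captured users.
memberships = {("alice", 1), ("bob", 1), ("carol", 2)}  # (user_id, group_id)
group_perms = {1: {"read_docs"}, 2: {"read_docs", "manage_docs"}}


def recompute(user: str) -> set[str]:
    # Union of permissions from the user's remaining groups.
    perms: set[str] = set()
    for u, g in memberships:
        if u == user:
            perms |= group_perms[g]
    return perms


def delete_group(group_id: int) -> set[str]:
    global memberships
    # 1. Collect affected users BEFORE deleting the relationships.
    affected = {u for u, g in memberships if g == group_id}
    # 2. Delete the membership rows.
    memberships = {(u, g) for u, g in memberships if g != group_id}
    # 3. Recompute permissions for exactly the affected users.
    for user in affected:
        recompute(user)
    return affected


affected = delete_group(1)
# alice and bob lost group 1; carol is untouched.
```

Collecting the IDs after the delete would return an empty set, which is why the real code does the `select` before any of the `__no_commit` cleanup helpers run.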
@@ -52,16 +52,25 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
from onyx.db.engine.sql_engine import get_session
from onyx.db.enums import AccountType
from onyx.db.enums import GrantSource
from onyx.db.enums import Permission
from onyx.db.models import ScimToken
from onyx.db.models import ScimUserMapping
from onyx.db.models import User
from onyx.db.models import UserGroup
from onyx.db.models import UserRole
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()

# Group names reserved for system default groups (seeded by migration).
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})


class ScimJSONResponse(JSONResponse):
    """JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
@@ -486,6 +495,7 @@ def create_user(
        email=email,
        hashed_password=_pw_helper.hash(_pw_helper.generate()),
        role=UserRole.BASIC,
        account_type=AccountType.STANDARD,
        is_active=user_resource.active,
        is_verified=True,
        personal_name=personal_name,
@@ -506,13 +516,25 @@ def create_user(
            scim_username=scim_username,
            fields=fields,
        )
        dal.commit()
    except IntegrityError:
        dal.rollback()
        return _scim_error_response(
            409, f"User with email {email} already has a SCIM mapping"
        )

    # Assign user to default group BEFORE commit so everything is atomic.
    # If this fails, the entire user creation rolls back and IdP can retry.
    try:
        assign_user_to_default_groups__no_commit(db_session, user)
    except Exception:
        dal.rollback()
        logger.exception(f"Failed to assign SCIM user {email} to default groups")
        return _scim_error_response(
            500, f"Failed to assign user {email} to default group"
        )

    dal.commit()

    return _scim_resource_response(
        provider.build_user_resource(
            user,
@@ -542,7 +564,8 @@ def replace_user(
    user = result

    # Handle activation (need seat check) / deactivation
    if user_resource.active and not user.is_active:
    is_reactivation = user_resource.active and not user.is_active
    if is_reactivation:
        seat_error = _check_seat_availability(dal)
        if seat_error:
            return _scim_error_response(403, seat_error)
@@ -556,6 +579,12 @@ def replace_user(
        personal_name=personal_name,
    )

    # Reconcile default-group membership on reactivation
    if is_reactivation:
        assign_user_to_default_groups__no_commit(
            db_session, user, is_admin=(user.role == UserRole.ADMIN)
        )

    new_external_id = user_resource.externalId
    scim_username = user_resource.userName.strip()
    fields = _fields_from_resource(user_resource)
@@ -621,6 +650,7 @@ def patch_user(
        return _scim_error_response(e.status, e.detail)

    # Apply changes back to the DB model
    is_reactivation = patched.active and not user.is_active
    if patched.active != user.is_active:
        if patched.active:
            seat_error = _check_seat_availability(dal)
@@ -649,6 +679,12 @@ def patch_user(
        personal_name=personal_name,
    )

    # Reconcile default-group membership on reactivation
    if is_reactivation:
        assign_user_to_default_groups__no_commit(
            db_session, user, is_admin=(user.role == UserRole.ADMIN)
        )

    # Build updated fields by merging PATCH enterprise data with current values
    cf = current_fields or ScimMappingFields()
    fields = ScimMappingFields(
@@ -857,6 +893,11 @@ def create_group(
    dal = ScimDAL(db_session)
    dal.update_token_last_used(_token.id)

    if group_resource.displayName in _RESERVED_GROUP_NAMES:
        return _scim_error_response(
            409, f"'{group_resource.displayName}' is a reserved group name."
        )

    if dal.get_group_by_name(group_resource.displayName):
        return _scim_error_response(
            409, f"Group with name '{group_resource.displayName}' already exists"
@@ -879,8 +920,18 @@ def create_group(
            409, f"Group with name '{group_resource.displayName}' already exists"
        )

    # Every group gets the "basic" permission by default.
    dal.add_permission_grant_to_group(
        group_id=db_group.id,
        permission=Permission.BASIC_ACCESS,
        grant_source=GrantSource.SYSTEM,
    )

    dal.upsert_group_members(db_group.id, member_uuids)

    # Recompute permissions for initial members.
    recompute_user_permissions__no_commit(member_uuids, db_session)

    external_id = group_resource.externalId
    if external_id:
        dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
@@ -911,14 +962,36 @@ def replace_group(
        return result
    group = result

    if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
        return _scim_error_response(
            409, f"'{group.name}' is a reserved group name and cannot be renamed."
        )

    if (
        group_resource.displayName in _RESERVED_GROUP_NAMES
        and group_resource.displayName != group.name
    ):
        return _scim_error_response(
            409, f"'{group_resource.displayName}' is a reserved group name."
        )

    member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
    if err:
        return _scim_error_response(400, err)

    # Capture old member IDs before replacing so we can recompute their
    # permissions after they are removed from the group.
    old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}

    dal.update_group(group, name=group_resource.displayName)
    dal.replace_group_members(group.id, member_uuids)
    dal.sync_group_external_id(group.id, group_resource.externalId)

    # Recompute permissions for current members (batch) and removed members.
    recompute_permissions_for_group__no_commit(group.id, db_session)
    removed_ids = list(old_member_ids - set(member_uuids))
    recompute_user_permissions__no_commit(removed_ids, db_session)

    dal.commit()

    members = dal.get_group_members(group.id)
@@ -961,8 +1034,19 @@ def patch_group(
        return _scim_error_response(e.status, e.detail)

    new_name = patched.displayName if patched.displayName != group.name else None

    if group.name in _RESERVED_GROUP_NAMES and new_name:
        return _scim_error_response(
            409, f"'{group.name}' is a reserved group name and cannot be renamed."
        )

    if new_name and new_name in _RESERVED_GROUP_NAMES:
        return _scim_error_response(409, f"'{new_name}' is a reserved group name.")

    dal.update_group(group, name=new_name)

    affected_uuids: list[UUID] = []

    if added_ids:
        add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
        if add_uuids:
@@ -973,10 +1057,15 @@ def patch_group(
                f"Member(s) not found: {', '.join(str(u) for u in missing)}",
            )
            dal.upsert_group_members(group.id, add_uuids)
            affected_uuids.extend(add_uuids)

    if removed_ids:
        remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
        dal.remove_group_members(group.id, remove_uuids)
        affected_uuids.extend(remove_uuids)

    # Recompute permissions for all users whose group membership changed.
    recompute_user_permissions__no_commit(affected_uuids, db_session)

    dal.sync_group_external_id(group.id, patched.externalId)
    dal.commit()
@@ -1002,11 +1091,21 @@ def delete_group(
        return result
    group = result

    if group.name in _RESERVED_GROUP_NAMES:
        return _scim_error_response(409, f"'{group.name}' is a reserved group name.")

    # Capture member IDs before deletion so we can recompute their permissions.
    affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]

    mapping = dal.get_group_mapping_by_group_id(group.id)
    if mapping:
        dal.delete_group_mapping(mapping.id)

    dal.delete_group_with_members(group)

    # Recompute permissions for users who lost this group membership.
    recompute_user_permissions__no_commit(affected_user_ids, db_session)

    dal.commit()

    return Response(status_code=204)

@@ -43,12 +43,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)

@router.get("/admin/user-group")
def list_user_groups(
    include_default: bool = False,
    user: User = Depends(current_curator_or_admin_user),
    db_session: Session = Depends(get_session),
) -> list[UserGroup]:
    if user.role == UserRole.ADMIN:
        user_groups = fetch_user_groups(
            db_session, only_up_to_date=False, eager_load_for_snapshot=True
            db_session,
            only_up_to_date=False,
            eager_load_for_snapshot=True,
            include_default=include_default,
        )
    else:
        user_groups = fetch_user_groups_for_user(
@@ -56,27 +60,50 @@ def list_user_groups(
            user_id=user.id,
            only_curator_groups=user.role == UserRole.CURATOR,
            eager_load_for_snapshot=True,
            include_default=include_default,
        )
    return [UserGroup.from_model(user_group) for user_group in user_groups]


@router.get("/user-groups/minimal")
def list_minimal_user_groups(
    include_default: bool = False,
    user: User = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[MinimalUserGroupSnapshot]:
    if user.role == UserRole.ADMIN:
        user_groups = fetch_user_groups(db_session, only_up_to_date=False)
        user_groups = fetch_user_groups(
            db_session,
            only_up_to_date=False,
            include_default=include_default,
        )
    else:
        user_groups = fetch_user_groups_for_user(
            db_session=db_session,
            user_id=user.id,
            include_default=include_default,
        )
    return [
        MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
    ]


@router.get("/admin/user-group/{user_group_id}/permissions")
def get_user_group_permissions(
    user_group_id: int,
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[str]:
    group = fetch_user_group(db_session, user_group_id)
    if group is None:
        raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
    return [
        grant.permission.value
        for grant in group.permission_grants
        if not grant.is_deleted
    ]


@router.post("/admin/user-group")
def create_user_group(
    user_group: UserGroupCreate,
@@ -100,6 +127,9 @@ def rename_user_group_endpoint(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> UserGroup:
    group = fetch_user_group(db_session, rename_request.id)
    if group and group.is_default:
        raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
    try:
        return UserGroup.from_model(
            rename_user_group(
@@ -185,6 +215,9 @@ def delete_user_group(
    _: User = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    group = fetch_user_group(db_session, user_group_id)
    if group and group.is_default:
        raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
    try:
        prepare_user_group_for_deletion(db_session, user_group_id)
    except ValueError as e:

@@ -22,6 +22,7 @@ class UserGroup(BaseModel):
    personas: list[PersonaSnapshot]
    is_up_to_date: bool
    is_up_for_deletion: bool
    is_default: bool

    @classmethod
    def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
@@ -74,18 +75,21 @@ class UserGroup(BaseModel):
            ],
            is_up_to_date=user_group_model.is_up_to_date,
            is_up_for_deletion=user_group_model.is_up_for_deletion,
            is_default=user_group_model.is_default,
        )


class MinimalUserGroupSnapshot(BaseModel):
    id: int
    name: str
    is_default: bool

    @classmethod
    def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
        return cls(
            id=user_group_model.id,
            name=user_group_model.name,
            is_default=user_group_model.is_default,
        )

backend/onyx/auth/permissions.py (new file, 110 lines)
@@ -0,0 +1,110 @@
"""
Permission resolution for group-based authorization.

Granted permissions are stored as a JSONB column on the User table and
loaded for free with every auth query. Implied permissions are expanded
at read time — only directly granted permissions are persisted.
"""

from collections.abc import Callable
from collections.abc import Coroutine
from typing import Any

from fastapi import Depends

from onyx.auth.users import current_user
from onyx.db.enums import Permission
from onyx.db.models import User
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError
from onyx.utils.logger import setup_logger

logger = setup_logger()

ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)

# Implication map: granted permission -> set of permissions it implies.
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
    Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
    Permission.MANAGE_AGENTS.value: {
        Permission.ADD_AGENTS.value,
        Permission.READ_AGENTS.value,
    },
    Permission.MANAGE_DOCUMENT_SETS.value: {
        Permission.READ_DOCUMENT_SETS.value,
        Permission.READ_CONNECTORS.value,
    },
    Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
    Permission.MANAGE_CONNECTORS.value: {
        Permission.ADD_CONNECTORS.value,
        Permission.READ_CONNECTORS.value,
    },
    Permission.MANAGE_USER_GROUPS.value: {
        Permission.READ_CONNECTORS.value,
        Permission.READ_DOCUMENT_SETS.value,
        Permission.READ_AGENTS.value,
        Permission.READ_USERS.value,
    },
}


def resolve_effective_permissions(granted: set[str]) -> set[str]:
    """Expand granted permissions with their implied permissions.

    If "admin" is present, returns all 19 permissions.
    """
    if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
        return set(ALL_PERMISSIONS)

    effective = set(granted)
    changed = True
    while changed:
        changed = False
        for perm in list(effective):
            implied = IMPLIED_PERMISSIONS.get(perm)
            if implied and not implied.issubset(effective):
                effective |= implied
                changed = True
    return effective


def get_effective_permissions(user: User) -> set[Permission]:
    """Read granted permissions from the column and expand implied permissions."""
    granted: set[Permission] = set()
    for p in user.effective_permissions:
        try:
            granted.add(Permission(p))
        except ValueError:
            logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
    if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
        return set(Permission)
    expanded = resolve_effective_permissions({p.value for p in granted})
    return {Permission(p) for p in expanded}


def require_permission(
    required: Permission,
) -> Callable[..., Coroutine[Any, Any, User]]:
    """FastAPI dependency factory for permission-based access control.

    Usage:
        @router.get("/endpoint")
        def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
            ...
    """

    async def dependency(user: User = Depends(current_user)) -> User:
        effective = get_effective_permissions(user)

        if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
            return user

        if required not in effective:
            raise OnyxError(
                OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
                "You do not have the required permissions for this action.",
            )

        return user

    return dependency

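The `resolve_effective_permissions` loop above is a fixed-point computation over `IMPLIED_PERMISSIONS`: it keeps adding implied permissions until nothing new appears, which also handles chains (manage implies add, add implies read). The core loop can be exercised standalone; a minimal sketch using plain strings instead of the `Permission` enum (the permission names here are illustrative only):

```python
# Standalone sketch of the implication-expansion fixed point used by
# resolve_effective_permissions. Names are illustrative strings, not the
# real Permission enum values.
IMPLIED = {
    "manage_connectors": {"add_connectors"},
    "add_connectors": {"read_connectors"},
}


def expand(granted: set[str]) -> set[str]:
    effective = set(granted)
    changed = True
    while changed:
        changed = False
        for perm in list(effective):
            implied = IMPLIED.get(perm)
            # Only re-loop when this pass actually added something new.
            if implied and not implied.issubset(effective):
                effective |= implied
                changed = True
    return effective


print(sorted(expand({"manage_connectors"})))
# → ['add_connectors', 'manage_connectors', 'read_connectors']
```

Note the two-hop chain: "manage_connectors" only directly implies "add_connectors", so "read_connectors" is picked up on a later pass of the `while changed` loop rather than in the first sweep.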
@@ -5,6 +5,8 @@ from typing import Any
from fastapi_users import schemas
from typing_extensions import override

from onyx.db.enums import AccountType


class UserRole(str, Enum):
    """
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):

class UserCreate(schemas.BaseUserCreate):
    role: UserRole = UserRole.BASIC
    account_type: AccountType = AccountType.STANDARD
    tenant_id: str | None = None
    # Captcha token for cloud signup protection (optional, only used when captcha is enabled)
    # Excluded from create_update_dict so it never reaches the DB layer
@@ -50,19 +53,19 @@ class UserCreate(schemas.BaseUserCreate):
    def create_update_dict(self) -> dict[str, Any]:
        d = super().create_update_dict()
        d.pop("captcha_token", None)
        # Force STANDARD for self-registration; only trusted paths
        # (SCIM, API key creation) supply a different account_type directly.
        d["account_type"] = AccountType.STANDARD
        return d

    @override
    def create_update_dict_superuser(self) -> dict[str, Any]:
        d = super().create_update_dict_superuser()
        d.pop("captcha_token", None)
        d.setdefault("account_type", self.account_type)
        return d


class UserUpdateWithRole(schemas.BaseUserUpdate):
    role: UserRole


class UserUpdate(schemas.BaseUserUpdate):
    """
    Role updates are not allowed through the user update endpoint for security reasons

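The `UserCreate` overrides split trust levels: the self-registration path (`create_update_dict`) always forces the STANDARD account type, while the superuser path (`create_update_dict_superuser`) honors the caller-supplied value. A stripped-down sketch of that split using plain classes (not the real `fastapi_users` schemas):

```python
# Sketch: self-registration always gets "standard"; the trusted
# (superuser) path keeps whatever account_type was supplied.
class UserCreateSketch:
    def __init__(self, email: str, account_type: str = "standard") -> None:
        self.email = email
        self.account_type = account_type

    def create_update_dict(self) -> dict:
        # Self-registration: force the default regardless of input.
        return {"email": self.email, "account_type": "standard"}

    def create_update_dict_superuser(self) -> dict:
        # Trusted path (e.g. SCIM): honor the supplied account_type.
        d = {"email": self.email}
        d.setdefault("account_type", self.account_type)
        return d


uc = UserCreateSketch("a@b.com", account_type="anonymous")
assert uc.create_update_dict()["account_type"] == "standard"
assert uc.create_update_dict_superuser()["account_type"] == "anonymous"
```

This keeps a hostile signup request from escalating its own account type while still letting provisioning code (which goes through the superuser dict) create non-standard accounts.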
@@ -80,7 +80,6 @@ from onyx.auth.pat import get_hashed_pat_from_request
from onyx.auth.schemas import AuthBackend
from onyx.auth.schemas import UserCreate
from onyx.auth.schemas import UserRole
from onyx.auth.schemas import UserUpdateWithRole
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import AUTH_TYPE
@@ -120,11 +119,13 @@ from onyx.db.engine.async_sql_engine import get_async_session
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.enums import AccountType
from onyx.db.models import AccessToken
from onyx.db.models import OAuthAccount
from onyx.db.models import Persona
from onyx.db.models import User
from onyx.db.pat import fetch_user_for_pat
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.db.users import get_user_by_email
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import log_onyx_error
@@ -500,18 +501,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            user = user_by_session

            if (
                user.role.is_web_login()
                user.account_type.is_web_login()
                or not isinstance(user_create, UserCreate)
                or not user_create.role.is_web_login()
                or not user_create.account_type.is_web_login()
            ):
                raise exceptions.UserAlreadyExists()

            user_update = UserUpdateWithRole(
                password=user_create.password,
                is_verified=user_create.is_verified,
                role=user_create.role,
            )
            user = await self.update(user_update, user)
            # Cache id before expire — accessing attrs on an expired
            # object triggers a sync lazy-load which raises MissingGreenlet
            # in this async context.
            user_id = user.id
            self._upgrade_user_to_standard__sync(user_id, user_create)
            # Expire so the async session re-fetches the row updated by
            # the sync session above.
            self.user_db.session.expire(user)
            user = await self.user_db.get(user_id)  # type: ignore[assignment]
        except exceptions.UserAlreadyExists:
            user = await self.get_by_email(user_create.email)

@@ -525,18 +529,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):

        # Handle case where user has used product outside of web and is now creating an account through web
        if (
            user.role.is_web_login()
            user.account_type.is_web_login()
            or not isinstance(user_create, UserCreate)
            or not user_create.role.is_web_login()
            or not user_create.account_type.is_web_login()
        ):
            raise exceptions.UserAlreadyExists()

        user_update = UserUpdateWithRole(
            password=user_create.password,
            is_verified=user_create.is_verified,
            role=user_create.role,
        )
        user = await self.update(user_update, user)
        # Cache id before expire — accessing attrs on an expired
        # object triggers a sync lazy-load which raises MissingGreenlet
        # in this async context.
        user_id = user.id
        self._upgrade_user_to_standard__sync(user_id, user_create)
        # Expire so the async session re-fetches the row updated by
        # the sync session above.
        self.user_db.session.expire(user)
        user = await self.user_db.get(user_id)  # type: ignore[assignment]
    if user_created:
        await self._assign_default_pinned_assistants(user, db_session)
        remove_user_from_invited_users(user_create.email)
@@ -573,6 +580,38 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            )
            user.pinned_assistants = default_persona_ids

    def _upgrade_user_to_standard__sync(
        self,
        user_id: uuid.UUID,
        user_create: UserCreate,
    ) -> None:
        """Upgrade a non-web user to STANDARD and assign default groups atomically.

        All writes happen in a single sync transaction so neither the field
        update nor the group assignment is visible without the other.
        """
        with get_session_with_current_tenant() as sync_db:
            sync_user = sync_db.query(User).filter(User.id == user_id).first()  # type: ignore[arg-type]
            if sync_user:
                sync_user.hashed_password = self.password_helper.hash(
                    user_create.password
                )
                sync_user.is_verified = user_create.is_verified or False
                sync_user.role = user_create.role
                sync_user.account_type = AccountType.STANDARD
                assign_user_to_default_groups__no_commit(
                    sync_db,
                    sync_user,
                    is_admin=(user_create.role == UserRole.ADMIN),
                )
                sync_db.commit()
            else:
                logger.warning(
                    "User %s not found in sync session during upgrade to standard; "
                    "skipping upgrade",
                    user_id,
                )

    async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
        # Validate password according to configurable security policy (defined via environment variables)
        if len(password) < PASSWORD_MIN_LENGTH:
@@ -694,6 +733,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            "email": account_email,
            "hashed_password": self.password_helper.hash(password),
            "is_verified": is_verified_by_default,
            "account_type": AccountType.STANDARD,
        }

        user = await self.user_db.create(user_dict)
@@ -726,7 +766,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
        )

        # Handle case where user has used product outside of web and is now creating an account through web
        if not user.role.is_web_login():
        if not user.account_type.is_web_login():
            # We must use the existing user in the session if it matches
            # the user we just got by email/oauth. Note that this only applies
            # to multi-tenant, due to the overwriting of the user_db
@@ -743,14 +783,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            with get_session_with_current_tenant() as sync_db:
                enforce_seat_limit(sync_db)

            await self.user_db.update(
                user,
                {
                    "is_verified": is_verified_by_default,
                    "role": UserRole.BASIC,
                    **({"is_active": True} if not user.is_active else {}),
                },
            )
            # Upgrade the user and assign default groups in a single
            # transaction so neither change is visible without the other.
            was_inactive = not user.is_active
            with get_session_with_current_tenant() as sync_db:
                sync_user = sync_db.query(User).filter(User.id == user.id).first()  # type: ignore[arg-type]
                if sync_user:
                    sync_user.is_verified = is_verified_by_default
                    sync_user.role = UserRole.BASIC
                    sync_user.account_type = AccountType.STANDARD
                    if was_inactive:
                        sync_user.is_active = True
                    assign_user_to_default_groups__no_commit(sync_db, sync_user)
                    sync_db.commit()

            # Refresh the async user object so downstream code
            # (e.g. oidc_expiry check) sees the updated fields.
            self.user_db.session.expire(user)
            user = await self.user_db.get(user.id)
            assert user is not None

        # this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
        # otherwise, the oidc expiry will always be old, and the user will never be able to login
@@ -836,6 +887,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
                event=MilestoneRecordType.TENANT_CREATED,
            )

            # Assign user to the appropriate default group (Admin or Basic).
            # Must happen inside the try block while tenant context is active,
            # otherwise get_session_with_current_tenant() targets the wrong schema.
            is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
            with get_session_with_current_tenant() as db_session:
                assign_user_to_default_groups__no_commit(
                    db_session, user, is_admin=is_admin
                )
                db_session.commit()

        finally:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)

@@ -975,7 +1036,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
            self.password_helper.hash(credentials.password)
            return None

        if not user.role.is_web_login():
        if not user.account_type.is_web_login():
            raise BasicAuthenticationError(
                detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
            )
@@ -1471,7 +1532,7 @@ async def _get_or_create_user_from_jwt(
        if not user.is_active:
            logger.warning("Inactive user %s attempted JWT login; skipping", email)
            return None
        if not user.role.is_web_login():
        if not user.account_type.is_web_login():
            raise exceptions.UserNotExists()
    except exceptions.UserNotExists:
        logger.info("Provisioning user %s from JWT login", email)
@@ -1492,7 +1553,7 @@ async def _get_or_create_user_from_jwt(
                email,
            )
            return None
        if not user.role.is_web_login():
        if not user.account_type.is_web_login():
            logger.warning(
                "Non-web-login user %s attempted JWT login during provisioning race; skipping",
                email,
@@ -1554,6 +1615,7 @@ def get_anonymous_user() -> User:
        is_verified=True,
        is_superuser=False,
        role=UserRole.LIMITED,
        account_type=AccountType.ANONYMOUS,
        use_memories=False,
        enable_memory_tool=False,
    )

@@ -36,6 +36,7 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
|
||||
from onyx.db.opensearch_migration import get_vespa_visit_state
|
||||
from onyx.db.opensearch_migration import is_migration_completed
|
||||
from onyx.db.opensearch_migration import (
|
||||
mark_migration_completed_time_if_not_set_with_commit,
|
||||
)
|
||||
@@ -106,14 +107,19 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
acquired; effectively a no-op. True if the task completed
|
||||
successfully. False if the task errored.
|
||||
"""
|
||||
# 1. Check if we should run the task.
|
||||
# 1.a. If OpenSearch indexing is disabled, we don't run the task.
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
"OpenSearch migration is not enabled, skipping chunk migration task."
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
|
||||
task_start_time = time.monotonic()
|
||||
|
||||
# 1.b. Only one instance per tenant of this task may run concurrently at
|
||||
# once. If we fail to acquire a lock, we assume it is because another task
|
||||
# has one and we exit.
|
||||
r = get_redis_client()
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
@@ -136,10 +142,11 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
# 2. Prepare to migrate.
|
||||
total_chunks_migrated_this_task = 0
|
||||
total_chunks_errored_this_task = 0
|
||||
try:
|
||||
# Double check that tenant info is correct.
|
||||
# 2.a. Double-check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
err_str = (
|
||||
f"Tenant ID mismatch in the OpenSearch migration task: "
|
||||
@@ -148,16 +155,62 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 2.b. Immediately check to see if this tenant is done, to save
|
||||
# having to do any other work. This function does not require a
|
||||
# migration record to necessarily exist.
|
||||
if is_migration_completed(db_session):
|
||||
return True
|
||||
|
||||
# 2.c. Try to insert the OpenSearchTenantMigrationRecord table if it
|
||||
# does not exist.
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
|
||||
# 2.d. Get search settings.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
# We reconstruct this mapping for every task invocation because
|
||||
# a document may have been added in the time between two tasks.
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
# 2.f. Get the current migration state.
|
||||
continuation_token_map, total_chunks_migrated = get_vespa_visit_state(
|
||||
db_session
|
||||
)
|
||||
# 2.f.1. Double-check that the migration state does not imply
|
||||
# completion. Really we should never have to enter this block as we
|
||||
# would expect is_migration_completed to return True, but in the
|
||||
# strange event that the migration is complete but the migration
|
||||
# completed time was never stamped, we do so here.
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client:
|
||||
# 2.g. Create the OpenSearch and Vespa document indexes.
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
@@ -171,22 +224,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
# We reconstruct this mapping for every task invocation because a
|
||||
# document may have been added in the time between two tasks.
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
# 2.h. Get the approximate chunk count in Vespa as of this time to
|
||||
# update the migration record.
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
# This failure should not be blocking.
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
@@ -195,25 +240,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
# 3. Do the actual migration in batches until we run out of time.
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
(
|
||||
continuation_token_map,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
# 3.a. Get the next batch of raw chunks from Vespa.
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
@@ -226,6 +258,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
)
|
||||
|
||||
# 3.b. Transform the raw chunks to OpenSearch chunks in memory.
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
raw_vespa_chunks,
|
||||
@@ -240,6 +273,7 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
"errored."
|
||||
)
|
||||
|
||||
# 3.c. Index the OpenSearch chunks into OpenSearch.
|
||||
index_opensearch_chunks_start_time = time.monotonic()
|
||||
opensearch_document_index.index_raw_chunks(
|
||||
chunks=opensearch_document_chunks
|
||||
@@ -251,12 +285,38 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
|
||||
total_chunks_migrated_this_task += len(opensearch_document_chunks)
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 3.d. Update the migration state.
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
# 3.e. Get the current migration state. Even thought we
|
||||
# technically have it in-memory since we just wrote it, we
|
||||
# want to reference the DB as the source of truth at all
|
||||
# times.
|
||||
continuation_token_map, total_chunks_migrated = (
|
||||
get_vespa_visit_state(db_session)
|
||||
)
|
||||
# 3.e.1. Check if the migration is done.
|
||||
if is_continuation_token_done_for_all_slices(
|
||||
continuation_token_map
|
||||
):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
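The restructured task above runs migration batches only while two conditions hold: the soft time limit has not elapsed and the Redis lock is still owned. A minimal sketch of that loop shape, with toy in-memory stand-ins for the lock and the batch source (all names here are illustrative, not Onyx's API):

```python
import time


def run_batches(get_batch, process, owns_lock, soft_limit_s):
    """Process batches until the source is exhausted, the soft time
    limit elapses, or the lock is lost. Returns the number processed."""
    start = time.monotonic()
    processed = 0
    while time.monotonic() - start < soft_limit_s and owns_lock():
        batch = get_batch()
        if batch is None:  # source exhausted -> migration complete
            break
        process(batch)
        processed += len(batch)
    return processed


# Toy usage: three batches of two items each, lock always held.
batches = iter([[1, 2], [3, 4], [5, 6]])
done: list[int] = []
n = run_batches(
    get_batch=lambda: next(batches, None),
    process=done.extend,
    owns_lock=lambda: True,
    soft_limit_s=5.0,
)
```

The real task additionally re-reads the continuation-token state from the database each iteration so the DB stays the source of truth across workers.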
@@ -1,6 +1,7 @@
import uuid

from fastapi_users.password import PasswordHelper
from sqlalchemy import delete
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import joinedload
@@ -10,14 +11,23 @@ from onyx.auth.api_key import ApiKeyDescriptor
from onyx.auth.api_key import build_displayable_api_key
from onyx.auth.api_key import generate_api_key
from onyx.auth.api_key import hash_api_key
from onyx.auth.schemas import UserRole
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
from onyx.db.enums import AccountType
from onyx.db.models import ApiKey
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.server.api_key.models import APIKeyArgs
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id

logger = setup_logger()


def get_api_key_email_pattern() -> str:
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
@@ -85,6 +95,7 @@ def insert_api_key(
is_superuser=False,
is_verified=True,
role=api_key_args.role,
account_type=AccountType.SERVICE_ACCOUNT,
)
db_session.add(api_key_user_row)

@@ -97,7 +108,18 @@ def insert_api_key(
)
db_session.add(api_key_row)

# Assign the API key virtual user to the appropriate default group
# before commit so everything is atomic.
# LIMITED role service accounts should have no group membership.
if api_key_args.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
api_key_user_row,
is_admin=(api_key_args.role == UserRole.ADMIN),
)

db_session.commit()

return ApiKeyDescriptor(
api_key_id=api_key_row.id,
api_key_role=api_key_user_row.role,
@@ -124,7 +146,33 @@ def update_api_key(

email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))

old_role = api_key_user.role
api_key_user.role = api_key_args.role

# Reconcile default-group membership when the role changes.
if old_role != api_key_args.role:
# Remove from all default groups first.
delete_stmt = delete(User__UserGroup).where(
User__UserGroup.user_id == api_key_user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
db_session.execute(delete_stmt)

# Re-assign to the correct default group (skip for LIMITED).
if api_key_args.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
api_key_user,
is_admin=(api_key_args.role == UserRole.ADMIN),
)
else:
# No group assigned for LIMITED, but we still need to recompute
# since we just removed the old default-group membership above.
recompute_user_permissions__no_commit(api_key_user.id, db_session)

db_session.commit()

return ApiKeyDescriptor(

@@ -199,7 +199,7 @@ def delete_messages_and_files_from_chat_session(
for _, files in messages_with_files:
file_store = get_default_file_store()
for file_info in files or []:
file_store.delete_file(file_id=file_info.get("id"))
file_store.delete_file(file_id=file_info.get("id"), error_on_missing=False)

# Delete ChatMessage records - CASCADE constraints will automatically handle:
# - ChatMessage__StandardAnswer relationship records

@@ -13,19 +13,26 @@ class AccountType(str, PyEnum):
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
"""

STANDARD = "standard"
BOT = "bot"
EXT_PERM_USER = "ext_perm_user"
SERVICE_ACCOUNT = "service_account"
ANONYMOUS = "anonymous"
STANDARD = "STANDARD"
BOT = "BOT"
EXT_PERM_USER = "EXT_PERM_USER"
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
ANONYMOUS = "ANONYMOUS"

def is_web_login(self) -> bool:
"""Whether this account type supports interactive web login."""
return self not in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
)


class GrantSource(str, PyEnum):
"""How a permission grant was created."""

USER = "user"
SCIM = "scim"
SYSTEM = "system"
USER = "USER"
SCIM = "SCIM"
SYSTEM = "SYSTEM"


class IndexingStatus(str, PyEnum):

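The hunk above uppercases the stored enum values and adds an `is_web_login()` predicate. The same pattern can be sketched standalone with the stdlib `Enum` (mirroring the diff; `PyEnum` in the source is just `enum.Enum`):

```python
from enum import Enum


class AccountType(str, Enum):
    STANDARD = "STANDARD"
    BOT = "BOT"
    EXT_PERM_USER = "EXT_PERM_USER"
    SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
    ANONYMOUS = "ANONYMOUS"

    def is_web_login(self) -> bool:
        # Only bot and externally-permissioned accounts are barred
        # from interactive web login, per the diff.
        return self not in (AccountType.BOT, AccountType.EXT_PERM_USER)
```

Note that by this definition `SERVICE_ACCOUNT` and `ANONYMOUS` still report `is_web_login() == True`; only `BOT` and `EXT_PERM_USER` are excluded.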
@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
role: Mapped[UserRole] = mapped_column(
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
)
account_type: Mapped[AccountType | None] = mapped_column(
Enum(AccountType, native_enum=False), nullable=True
account_type: Mapped[AccountType] = mapped_column(
Enum(AccountType, native_enum=False),
nullable=False,
default=AccountType.STANDARD,
server_default="STANDARD",
)

"""
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
postgresql.JSONB(), nullable=True, default=None
)

effective_permissions: Mapped[list[str]] = mapped_column(
postgresql.JSONB(),
nullable=False,
default=list,
server_default=text("'[]'::jsonb"),
)

oidc_expiry: Mapped[datetime.datetime] = mapped_column(
TIMESTAMPAware(timezone=True), nullable=True
)
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
)
permission: Mapped[Permission] = mapped_column(
Enum(Permission, native_enum=False), nullable=False
Enum(
Permission,
native_enum=False,
values_callable=lambda x: [e.value for e in x],
),
nullable=False,
)
grant_source: Mapped[GrantSource] = mapped_column(
Enum(GrantSource, native_enum=False), nullable=False

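The `values_callable` addition matters because SQLAlchemy's `Enum` type persists the Python enum member *names* by default, and `values_callable` switches it to the `.value` strings. With the uppercase values introduced above, name and value coincide for `AccountType`, but for `Permission` they evidently differ. The name/value distinction in plain Python (the `Permission` member below is hypothetical, for illustration only):

```python
from enum import Enum


class Permission(Enum):
    # Hypothetical member whose value differs from its name.
    MANAGE_USERS = "manage:users"


member = Permission.MANAGE_USERS
# Default SQLAlchemy Enum storage persists the NAME...
stored_by_default = member.name  # "MANAGE_USERS"
# ...while values_callable=lambda x: [e.value for e in x]
# switches the persisted strings to the VALUES.
stored_with_values_callable = member.value  # "manage:users"
```

This is why the diff passes `values_callable` only for `Permission`: its stored strings must match the `.value` side of the enum.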
@@ -324,6 +324,15 @@ def mark_migration_completed_time_if_not_set_with_commit(
db_session.commit()


def is_migration_completed(db_session: Session) -> bool:
"""Returns True if the migration is completed.

Can be run even if the migration record does not exist.
"""
record = db_session.query(OpenSearchTenantMigrationRecord).first()
return record is not None and record.migration_completed_at is not None


def build_sanitized_to_original_doc_id_mapping(
db_session: Session,
) -> dict[str, str]:

backend/onyx/db/permissions.py (new file, 95 lines)
@@ -0,0 +1,95 @@
"""
DB operations for recomputing user effective_permissions.

These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
that query PermissionGrant rows and update the User.effective_permissions
JSONB column. Keeping them here avoids circular imports when called from
other onyx/db/ modules such as users.py.
"""

from collections import defaultdict
from uuid import UUID

from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.db.models import PermissionGrant
from onyx.db.models import User
from onyx.db.models import User__UserGroup


def recompute_user_permissions__no_commit(
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
) -> None:
"""Recompute granted permissions for one or more users.

Accepts a single UUID or a list. Uses a single query regardless of
how many users are passed, avoiding N+1 issues.

Stores only directly granted permissions — implication expansion
happens at read time via get_effective_permissions().

Does NOT commit — caller must commit the session.
"""
if isinstance(user_ids, (UUID, str)):
uid_list = [user_ids]
else:
uid_list = list(user_ids)

if not uid_list:
return

# Single query to fetch ALL permissions for these users across ALL their
# groups (a user may belong to multiple groups with different grants).
rows = db_session.execute(
select(User__UserGroup.user_id, PermissionGrant.permission)
.join(
PermissionGrant,
PermissionGrant.group_id == User__UserGroup.user_group_id,
)
.where(
User__UserGroup.user_id.in_(uid_list),
PermissionGrant.is_deleted.is_(False),
)
).all()

# Group permissions by user; users with no grants get an empty set.
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
for uid in uid_list:
perms_by_user[uid]  # ensure every user has an entry
for uid, perm in rows:
perms_by_user[uid].add(perm.value)

for uid, perms in perms_by_user.items():
db_session.execute(
update(User)
.where(User.id == uid)  # type: ignore[arg-type]
.values(effective_permissions=sorted(perms))
)


def recompute_permissions_for_group__no_commit(
group_id: int, db_session: Session
) -> None:
"""Recompute granted permissions for all users in a group.

Does NOT commit — caller must commit the session.
"""
user_ids: list[UUID] = [
uid
for uid in db_session.execute(
select(User__UserGroup.user_id).where(
User__UserGroup.user_group_id == group_id,
User__UserGroup.user_id.isnot(None),
)
)
.scalars()
.all()
if uid is not None
]

if not user_ids:
return

recompute_user_permissions__no_commit(user_ids, db_session)
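`recompute_user_permissions__no_commit` groups one joined query's rows by user while guaranteeing that users with *no* grants still get an empty entry, so their stale permission lists are overwritten rather than skipped. The grouping idiom in isolation, with plain tuples standing in for the SQL result rows:

```python
from collections import defaultdict


def group_perms(user_ids, rows):
    """rows: (user_id, permission) pairs from a single joined query."""
    perms_by_user: dict[str, set[str]] = defaultdict(set)
    for uid in user_ids:
        perms_by_user[uid]  # touch the key so no-grant users get set()
    for uid, perm in rows:
        perms_by_user[uid].add(perm)
    # sorted() gives a deterministic, JSONB-friendly list per user.
    return {uid: sorted(p) for uid, p in perms_by_user.items()}


result = group_perms(
    ["alice", "bob"],
    [("alice", "read"), ("alice", "write")],
)
```

Here `bob` has no rows but still appears with an empty list, which is exactly why the function pre-touches every requested key.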
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
from sqlalchemy import select
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.configs.app_configs import INSTANCE_TYPE
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NotificationType
from onyx.configs.constants import ONYX_UTM_SOURCE
from onyx.db.enums import AccountType
from onyx.db.models import User
from onyx.db.notification import batch_create_notifications
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
db_session.scalars(
select(User.id).where(  # type: ignore
User.is_active == True,  # noqa: E712
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False),  # type: ignore[attr-defined]
)
).all()

@@ -9,12 +9,17 @@ from sqlalchemy import update
from sqlalchemy.orm import Session

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.db.enums import DefaultAppMode
from onyx.db.enums import ThemePreference
from onyx.db.models import AccessToken
from onyx.db.models import Assistant__UserSpecificConfig
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.permissions import recompute_user_permissions__no_commit
from onyx.db.users import assign_user_to_default_groups__no_commit
from onyx.server.manage.models import MemoryItem
from onyx.server.manage.models import UserSpecificAssistantPreference
from onyx.utils.logger import setup_logger
@@ -23,13 +28,53 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()


_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
UserRole.SLACK_USER: AccountType.BOT,
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
}


def update_user_role(
user: User,
new_role: UserRole,
db_session: Session,
) -> None:
"""Update a user's role in the database."""
"""Update a user's role in the database.
Dual-writes account_type to keep it in sync with role and
reconciles default-group membership (Admin / Basic)."""
old_role = user.role
user.role = new_role
# Note: setting account_type to BOT or EXT_PERM_USER causes
# assign_user_to_default_groups__no_commit to early-return, which is
# intentional — these account types should not be in default groups.
if new_role in _ROLE_TO_ACCOUNT_TYPE:
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
# Upgrading from a non-web-login account type to a web role
user.account_type = AccountType.STANDARD

# Reconcile default-group membership when the role changes.
if old_role != new_role:
# Remove from all default groups first.
db_session.execute(
delete(User__UserGroup).where(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id.in_(
select(UserGroup.id).where(UserGroup.is_default.is_(True))
),
)
)

# Re-assign to the correct default group (skip for LIMITED).
if new_role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session,
user,
is_admin=(new_role == UserRole.ADMIN),
)

recompute_user_permissions__no_commit(user.id, db_session)

db_session.commit()


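`update_user_role` now dual-writes `account_type` from the incoming role. The mapping-with-fallback rule can be extracted into a pure function to see its three branches (string stand-ins for the real enums; this is an illustrative restatement, not Onyx code):

```python
_ROLE_TO_ACCOUNT_TYPE = {
    "SLACK_USER": "BOT",
    "EXT_PERM_USER": "EXT_PERM_USER",
}


def next_account_type(new_role: str, current_account_type: str) -> str:
    """Mirror of the dual-write rules in update_user_role."""
    if new_role in _ROLE_TO_ACCOUNT_TYPE:
        # The role implies a fixed, non-web-login account type.
        return _ROLE_TO_ACCOUNT_TYPE[new_role]
    if current_account_type in ("BOT", "EXT_PERM_USER"):
        # Upgrading to a web role normalizes the account type.
        return "STANDARD"
    # Otherwise the existing account type is left untouched,
    # e.g. SERVICE_ACCOUNT stays SERVICE_ACCOUNT.
    return current_account_type
```

The third branch is the subtle one: a role change never downgrades a `SERVICE_ACCOUNT` or `STANDARD` account type; only the two special roles force one.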
@@ -47,8 +92,16 @@ def activate_user(
user: User,
db_session: Session,
) -> None:
"""Activate a user by setting is_active to True."""
"""Activate a user by setting is_active to True.

Also reconciles default-group membership — the user may have been
created while inactive or deactivated before the backfill migration.
"""
user.is_active = True
if user.role != UserRole.LIMITED:
assign_user_to_default_groups__no_commit(
db_session, user, is_admin=(user.role == UserRole.ADMIN)
)
db_session.add(user)
db_session.commit()


@@ -17,8 +17,9 @@ from sqlalchemy.sql.expression import or_
from onyx.auth.invited_users import remove_user_from_invited_users
from onyx.auth.schemas import UserRole
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
from onyx.db.enums import AccountType
from onyx.db.models import DocumentSet
from onyx.db.models import DocumentSet__User
from onyx.db.models import Persona
@@ -27,11 +28,17 @@ from onyx.db.models import SamlAccount
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop

logger = setup_logger()


def validate_user_role_update(
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
requested_role: UserRole,
current_role: UserRole,
current_account_type: AccountType,
explicit_override: bool = False,
) -> None:
"""
Validate that a user role update is valid.
@@ -41,19 +48,18 @@
- requested role is a slack user
- requested role is an external permissioned user
- requested role is a limited user
- current role is a slack user
- current role is an external permissioned user
- current account type is BOT (slack user)
- current account type is EXT_PERM_USER
- current role is a limited user
"""

if current_role == UserRole.SLACK_USER:
if current_account_type == AccountType.BOT:
raise HTTPException(
status_code=400,
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
)

if current_role == UserRole.EXT_PERM_USER:
# This shouldn't happen, but just in case
if current_account_type == AccountType.EXT_PERM_USER:
raise HTTPException(
status_code=400,
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
@@ -298,6 +304,7 @@ def _generate_slack_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.SLACK_USER,
account_type=AccountType.BOT,
)


@@ -306,8 +313,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
user = get_user_by_email(email, db_session)
if user is not None:
# If the user is an external permissioned user, we update it to a slack user
if user.role == UserRole.EXT_PERM_USER:
if user.account_type == AccountType.EXT_PERM_USER:
user.role = UserRole.SLACK_USER
user.account_type = AccountType.BOT
db_session.commit()
return user

@@ -344,6 +352,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
email=email,
hashed_password=hashed_pass,
role=UserRole.EXT_PERM_USER,
account_type=AccountType.EXT_PERM_USER,
)


@@ -375,6 +384,81 @@ def batch_add_ext_perm_user_if_not_exists(
return all_users


def assign_user_to_default_groups__no_commit(
db_session: Session,
user: User,
is_admin: bool = False,
) -> None:
"""Assign a newly created user to the appropriate default group.

Does NOT commit — callers must commit the session themselves so that
group assignment can be part of the same transaction as user creation.

Args:
is_admin: If True, assign to Admin default group; otherwise Basic.
Callers determine this from their own context (e.g. user_count,
admin email list, explicit choice). Defaults to False (Basic).
"""
if user.account_type in (
AccountType.BOT,
AccountType.EXT_PERM_USER,
AccountType.ANONYMOUS,
):
return

target_group_name = "Admin" if is_admin else "Basic"

default_group = (
db_session.query(UserGroup)
.filter(
UserGroup.name == target_group_name,
UserGroup.is_default.is_(True),
)
.first()
)

if default_group is None:
raise RuntimeError(
f"Default group '{target_group_name}' not found. "
f"Cannot assign user {user.email} to a group. "
f"Ensure the seed_default_groups migration has run."
)

# Check if the user is already in the group
existing = (
db_session.query(User__UserGroup)
.filter(
User__UserGroup.user_id == user.id,
User__UserGroup.user_group_id == default_group.id,
)
.first()
)
if existing is not None:
return

savepoint = db_session.begin_nested()
try:
db_session.add(
User__UserGroup(
user_id=user.id,
user_group_id=default_group.id,
)
)
db_session.flush()
except IntegrityError:
# Race condition: another transaction inserted this membership
# between our SELECT and INSERT. The savepoint isolates the failure
# so the outer transaction (user creation) stays intact.
savepoint.rollback()
return

from onyx.db.permissions import recompute_user_permissions__no_commit

recompute_user_permissions__no_commit(user.id, db_session)

logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")


def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -421,13 +505,14 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
stmt = (
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -435,7 +520,11 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
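The `begin_nested()` / rollback pattern used in `assign_user_to_default_groups__no_commit` (a SQL SAVEPOINT under the hood) can be sketched with stdlib `sqlite3`; the table and names below are illustrative, not Onyx's real schema. The savepoint absorbs the duplicate-insert failure so the outer transaction's earlier work survives:

```python
import sqlite3

# Illustrative membership table with the same uniqueness constraint
conn = sqlite3.connect(":memory:", isolation_level=None)  # manual transaction control
conn.execute(
    "CREATE TABLE user__user_group (user_id INTEGER, user_group_id INTEGER, "
    "UNIQUE(user_id, user_group_id))"
)

conn.execute("BEGIN")
conn.execute("INSERT INTO user__user_group VALUES (1, 10)")  # outer transaction's work

conn.execute("SAVEPOINT add_membership")
try:
    # Duplicate insert, as if a concurrent worker already added the membership
    conn.execute("INSERT INTO user__user_group VALUES (1, 10)")
    conn.execute("RELEASE SAVEPOINT add_membership")
except sqlite3.IntegrityError:
    # Roll back only to the savepoint; the outer INSERT is untouched
    conn.execute("ROLLBACK TO SAVEPOINT add_membership")

conn.execute("COMMIT")
rows = conn.execute("SELECT COUNT(*) FROM user__user_group").fetchone()[0]
print(rows)  # 1
```

Without the savepoint, the `IntegrityError` would poison the whole session and the caller's pending user-creation work would be rolled back with it.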
@@ -1,3 +1,4 @@
+import hashlib
 from datetime import datetime
 from datetime import timezone
 from typing import Any

@@ -20,9 +21,13 @@ from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
 from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
 from onyx.document_index.opensearch.constants import EF_SEARCH
 from onyx.document_index.opensearch.constants import M
+from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError
 from onyx.document_index.opensearch.string_filtering import (
     filter_and_validate_document_id,
 )
+from onyx.document_index.opensearch.string_filtering import (
+    MAX_DOCUMENT_ID_ENCODED_LENGTH,
+)
 from onyx.utils.tenant import get_tenant_id_short_string
 from shared_configs.configs import MULTI_TENANT
 from shared_configs.contextvars import get_current_tenant_id

@@ -75,17 +80,50 @@ def get_opensearch_doc_chunk_id(

     This will be the string used to identify the chunk in OpenSearch. Any direct
     chunk queries should use this function.

+    If the document ID is too long, a hash of the ID is used instead.
     """
-    sanitized_document_id = filter_and_validate_document_id(document_id)
-    opensearch_doc_chunk_id = (
-        f"{sanitized_document_id}__{max_chunk_size}__{chunk_index}"
-    )
+    opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}"
+    encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8"))
+    max_encoded_permissible_doc_id_length: int = (
+        MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length
+    )
+    opensearch_doc_chunk_id_tenant_prefix: str = ""
     if tenant_state.multitenant:
+        short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id)
         # Use tenant ID because in multitenant mode each tenant has its own
         # Documents table, so there is a very small chance that doc IDs are not
         # actually unique across all tenants.
-        short_tenant_id = get_tenant_id_short_string(tenant_state.tenant_id)
-        opensearch_doc_chunk_id = f"{short_tenant_id}__{opensearch_doc_chunk_id}"
+        opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__"
+        encoded_prefix_length: int = len(
+            opensearch_doc_chunk_id_tenant_prefix.encode("utf-8")
+        )
+        max_encoded_permissible_doc_id_length -= encoded_prefix_length
+
+    try:
+        sanitized_document_id: str = filter_and_validate_document_id(
+            document_id, max_encoded_length=max_encoded_permissible_doc_id_length
+        )
+    except DocumentIDTooLongError:
+        # If the document ID is too long, use a hash instead.
+        # We use blake2b because it is faster and equally secure as SHA256, and
+        # accepts digest_size which controls the number of bytes returned in the
+        # hash.
+        # digest_size is the size of the returned hash in bytes. Since we're
+        # decoding the hash bytes as a hex string, the digest_size should be
+        # half the max target size of the hash string.
+        # Subtract 1 because filter_and_validate_document_id compares on >= on
+        # max_encoded_length.
+        # 64 is the max digest_size blake2b returns.
+        digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64)
+        sanitized_document_id = hashlib.blake2b(
+            document_id.encode("utf-8"), digest_size=digest_size
+        ).hexdigest()
+
+    opensearch_doc_chunk_id: str = (
+        f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}"
+    )
+
+    # Do one more validation to ensure we haven't exceeded the max length.
+    opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id)
     return opensearch_doc_chunk_id
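The length-budget logic above can be sketched in isolation. This is a simplified stand-in, not the repo's function: `make_doc_id` and the 64-byte cap are hypothetical (the real code budgets against a 512-byte limit and also strips disallowed characters), but the blake2b fallback mirrors the diff, including halving the budget because `hexdigest()` emits two characters per digest byte:

```python
import hashlib

MAX_ENCODED_LEN = 64  # illustrative cap; the real code uses 512


def make_doc_id(document_id: str, suffix: str) -> str:
    # Budget for the ID is whatever the cap leaves after the fixed suffix
    budget = MAX_ENCODED_LEN - len(suffix.encode("utf-8"))
    if len(document_id.encode("utf-8")) < budget:
        return document_id + suffix
    # blake2b lets us pick the digest size; hex encoding doubles the byte
    # count, so halve the budget (minus 1 for the >= comparison), capped at
    # blake2b's 64-byte maximum
    digest_size = min((budget - 1) // 2, 64)
    hashed_id = hashlib.blake2b(
        document_id.encode("utf-8"), digest_size=digest_size
    ).hexdigest()
    return hashed_id + suffix


short = make_doc_id("doc-123", "__512__0")   # fits: passed through
hashed = make_doc_id("x" * 500, "__512__0")  # too long: hashed
```

The key property is that any input, however long, yields an ID under the cap while short IDs stay human-readable.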
@@ -1,7 +1,15 @@
 import re

+MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512

-def filter_and_validate_document_id(document_id: str) -> str:

+class DocumentIDTooLongError(ValueError):
+    """Raised when a document ID is too long for OpenSearch after filtering."""
+
+
+def filter_and_validate_document_id(
+    document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH
+) -> str:
     """
     Filters and validates a document ID such that it can be used as an ID in
     OpenSearch.

@@ -19,9 +27,13 @@ def filter_and_validate_document_id(document_id: str) -> str:

     Args:
         document_id: The document ID to filter and validate.
+        max_encoded_length: The maximum length of the document ID after
+            filtering in bytes. Compared with >= for extra resilience, so
+            encoded values of this length will fail.

     Raises:
-        ValueError: If the document ID is empty or too long after filtering.
+        DocumentIDTooLongError: If the document ID is too long after filtering.
+        ValueError: If the document ID is empty after filtering.

     Returns:
         str: The filtered document ID.

@@ -29,6 +41,8 @@ def filter_and_validate_document_id(document_id: str) -> str:
     filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
     if not filtered_document_id:
         raise ValueError(f"Document ID {document_id} is empty after filtering.")
-    if len(filtered_document_id.encode("utf-8")) >= 512:
-        raise ValueError(f"Document ID {document_id} is too long after filtering.")
+    if len(filtered_document_id.encode("utf-8")) >= max_encoded_length:
+        raise DocumentIDTooLongError(
+            f"Document ID {document_id} is too long after filtering."
+        )
     return filtered_document_id
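The character allowlist is the same in both versions of the function and is easy to demonstrate on its own. A minimal sketch (`filter_document_id` is a simplified stand-in without the length and emptiness checks):

```python
import re


def filter_document_id(document_id: str) -> str:
    # Keep only URL-safe characters, mirroring the allowlist in the diff:
    # letters, digits, underscore, dot, hyphen, tilde
    return re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)


cleaned = filter_document_id("https://example.com/docs?id=42")
print(cleaned)  # httpsexample.comdocsid42
```

Note that filtering is lossy and can collide (`a/b` and `ab` both become `ab`), which is part of why the caller re-validates the final assembled chunk ID.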
@@ -136,12 +136,14 @@ class FileStore(ABC):
     """

     @abstractmethod
-    def delete_file(self, file_id: str) -> None:
+    def delete_file(self, file_id: str, error_on_missing: bool = True) -> None:
         """
         Delete a file by its ID.

         Parameters:
-        - file_name: Name of file to delete
+        - file_id: ID of file to delete
+        - error_on_missing: If False, silently return when the file record
+          does not exist instead of raising.
         """

     @abstractmethod

@@ -452,12 +454,23 @@ class S3BackedFileStore(FileStore):
             logger.warning(f"Error getting file size for {file_id}: {e}")
             return None

-    def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
+    def delete_file(
+        self,
+        file_id: str,
+        error_on_missing: bool = True,
+        db_session: Session | None = None,
+    ) -> None:
         with get_session_with_current_tenant_if_none(db_session) as db_session:
             try:
-                file_record = get_filerecord_by_file_id(
+                file_record = get_filerecord_by_file_id_optional(
                     file_id=file_id, db_session=db_session
                 )
+                if file_record is None:
+                    if error_on_missing:
+                        raise RuntimeError(
+                            f"File by id {file_id} does not exist or was deleted"
+                        )
+                    return
                 if not file_record.bucket_name:
                     logger.error(
                         f"File record {file_id} with key {file_record.object_key} "

@@ -222,12 +222,23 @@ class PostgresBackedFileStore(FileStore):
             logger.warning(f"Error getting file size for {file_id}: {e}")
             return None

-    def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
+    def delete_file(
+        self,
+        file_id: str,
+        error_on_missing: bool = True,
+        db_session: Session | None = None,
+    ) -> None:
         with get_session_with_current_tenant_if_none(db_session) as session:
             try:
-                file_content = get_file_content_by_file_id(
+                file_content = get_file_content_by_file_id_optional(
                     file_id=file_id, db_session=session
                 )
+                if file_content is None:
+                    if error_on_missing:
+                        raise RuntimeError(
+                            f"File content for file_id {file_id} does not exist or was deleted"
+                        )
+                    return
                 raw_conn = _get_raw_connection(session)

                 try:
@@ -3,6 +3,8 @@
 from datetime import datetime
 from typing import Any

+import httpx
+
 from onyx.configs.constants import DocumentSource
 from onyx.mcp_server.api import mcp_server
 from onyx.mcp_server.utils import get_http_client

@@ -15,6 +17,21 @@ from onyx.utils.variable_functionality import global_version
 logger = setup_logger()


+def _extract_error_detail(response: httpx.Response) -> str:
+    """Extract a human-readable error message from a failed backend response.
+
+    The backend returns OnyxError responses as
+    ``{"error_code": "...", "detail": "..."}``.
+    """
+    try:
+        body = response.json()
+        if detail := body.get("detail"):
+            return str(detail)
+    except Exception:
+        pass
+    return f"Request failed with status {response.status_code}"
+
+
 @mcp_server.tool()
 async def search_indexed_documents(
     query: str,

@@ -158,7 +175,14 @@ async def search_indexed_documents(
             json=search_request,
             headers=auth_headers,
         )
-        response.raise_for_status()
+        if not response.is_success:
+            error_detail = _extract_error_detail(response)
+            return {
+                "documents": [],
+                "total_results": 0,
+                "query": query,
+                "error": error_detail,
+            }
         result = response.json()

         # Check for error in response

@@ -234,7 +258,13 @@ async def search_web(
             json=request_payload,
             headers={"Authorization": f"Bearer {access_token.token}"},
         )
-        response.raise_for_status()
+        if not response.is_success:
+            error_detail = _extract_error_detail(response)
+            return {
+                "error": error_detail,
+                "results": [],
+                "query": query,
+            }
         response_payload = response.json()
         results = response_payload.get("results", [])
         return {

@@ -280,7 +310,12 @@ async def open_urls(
             json={"urls": urls},
             headers={"Authorization": f"Bearer {access_token.token}"},
         )
-        response.raise_for_status()
+        if not response.is_success:
+            error_detail = _extract_error_detail(response)
+            return {
+                "error": error_detail,
+                "results": [],
+            }
         response_payload = response.json()
         results = response_payload.get("results", [])
         return {
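The detail-extraction logic in `_extract_error_detail` can be restated as a pure function decoupled from `httpx`, which makes the fallback behavior easy to see. `extract_error_detail` below is a hypothetical stand-in taking the status code and raw body text directly:

```python
import json


def extract_error_detail(status_code: int, body_text: str) -> str:
    # Backend errors look like {"error_code": "...", "detail": "..."}
    # per the helper in the diff; anything else falls through to a
    # generic message keyed on the status code.
    try:
        body = json.loads(body_text)
        if detail := body.get("detail"):
            return str(detail)
    except Exception:
        pass
    return f"Request failed with status {status_code}"


msg_json = extract_error_detail(
    400, '{"error_code": "INVALID_INPUT", "detail": "No provider"}'
)
msg_html = extract_error_detail(502, "<html>bad gateway</html>")
print(msg_json)  # No provider
print(msg_html)  # Request failed with status 502
```

Catching broad `Exception` is deliberate here: a non-JSON body, a JSON array, or a missing key should all degrade to the generic message rather than crash the tool call.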
@@ -3,10 +3,10 @@ import datetime
 from slack_sdk import WebClient
 from slack_sdk.errors import SlackApiError

-from onyx.auth.schemas import UserRole
 from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
 from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
 from onyx.db.engine.sql_engine import get_session_with_current_tenant
+from onyx.db.enums import AccountType
 from onyx.db.models import SlackChannelConfig
 from onyx.db.user_preferences import activate_user
 from onyx.db.users import add_slack_user_if_not_exists

@@ -247,7 +247,7 @@ def handle_message(

         elif (
             not existing_user.is_active
-            and existing_user.role == UserRole.SLACK_USER
+            and existing_user.account_type == AccountType.BOT
         ):
             check_seat_fn = fetch_ee_implementation_or_noop(
                 "onyx.db.license",
@@ -1,6 +1,5 @@
 from fastapi import APIRouter
 from fastapi import Depends
-from fastapi import HTTPException
 from sqlalchemy.orm import Session

 from onyx.auth.users import current_user

@@ -9,6 +8,8 @@ from onyx.db.engine.sql_engine import get_session
 from onyx.db.models import User
 from onyx.db.web_search import fetch_active_web_content_provider
 from onyx.db.web_search import fetch_active_web_search_provider
+from onyx.error_handling.error_codes import OnyxErrorCode
+from onyx.error_handling.exceptions import OnyxError
 from onyx.server.features.web_search.models import OpenUrlsToolRequest
 from onyx.server.features.web_search.models import OpenUrlsToolResponse
 from onyx.server.features.web_search.models import WebSearchToolRequest

@@ -61,9 +62,10 @@ def _get_active_search_provider(
 ) -> tuple[WebSearchProviderView, WebSearchProvider]:
     provider_model = fetch_active_web_search_provider(db_session)
     if provider_model is None:
-        raise HTTPException(
-            status_code=400,
-            detail="No web search provider configured.",
+        raise OnyxError(
+            OnyxErrorCode.INVALID_INPUT,
+            "No web search provider configured. Please configure one in "
+            "Admin > Web Search settings.",
         )

     provider_view = WebSearchProviderView(

@@ -76,9 +78,10 @@ def _get_active_search_provider(
     )

     if provider_model.api_key is None:
-        raise HTTPException(
-            status_code=400,
-            detail="Web search provider requires an API key.",
+        raise OnyxError(
+            OnyxErrorCode.INVALID_INPUT,
+            "Web search provider requires an API key. Please configure one in "
+            "Admin > Web Search settings.",
         )

     try:

@@ -88,7 +91,7 @@ def _get_active_search_provider(
             config=provider_model.config or {},
         )
     except ValueError as exc:
-        raise HTTPException(status_code=400, detail=str(exc)) from exc
+        raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc

     return provider_view, provider

@@ -110,9 +113,9 @@ def _get_active_content_provider(

     if provider_model.api_key is None:
         # TODO - this is not a great error, in fact, this key should not be nullable.
-        raise HTTPException(
-            status_code=400,
-            detail="Web content provider requires an API key.",
+        raise OnyxError(
+            OnyxErrorCode.INVALID_INPUT,
+            "Web content provider requires an API key.",
         )

     try:

@@ -125,12 +128,12 @@ def _get_active_content_provider(
             config=config,
         )
     except ValueError as exc:
-        raise HTTPException(status_code=400, detail=str(exc)) from exc
+        raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc

     if provider is None:
-        raise HTTPException(
-            status_code=400,
-            detail="Unable to initialize the configured web content provider.",
+        raise OnyxError(
+            OnyxErrorCode.INVALID_INPUT,
+            "Unable to initialize the configured web content provider.",
         )

     provider_view = WebContentProviderView(

@@ -154,12 +157,13 @@ def _run_web_search(
     for query in request.queries:
         try:
             search_results = provider.search(query)
-        except HTTPException:
+        except OnyxError:
             raise
         except Exception as exc:
             logger.exception("Web search provider failed for query '%s'", query)
-            raise HTTPException(
-                status_code=502, detail="Web search provider failed to execute query."
+            raise OnyxError(
+                OnyxErrorCode.BAD_GATEWAY,
+                "Web search provider failed to execute query.",
             ) from exc

     filtered_results = filter_web_search_results_with_no_title_or_snippet(

@@ -192,12 +196,13 @@ def _open_urls(
         docs = filter_web_contents_with_no_title_or_content(
             list(provider.contents(urls))
         )
-    except HTTPException:
+    except OnyxError:
         raise
     except Exception as exc:
         logger.exception("Web content provider failed to fetch URLs")
-        raise HTTPException(
-            status_code=502, detail="Web content provider failed to fetch URLs."
+        raise OnyxError(
+            OnyxErrorCode.BAD_GATEWAY,
+            "Web content provider failed to fetch URLs.",
         ) from exc

     results: list[LlmOpenUrlResult] = []
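The migration above swaps `HTTPException` for a domain `OnyxError` carrying a machine-readable code plus a human-readable detail, which is the same `{"error_code": ..., "detail": ...}` payload the MCP tools parse. A minimal sketch of that shape (these class definitions are illustrative; the real `OnyxError` and `OnyxErrorCode` in `onyx.error_handling` almost certainly differ):

```python
from enum import Enum


class OnyxErrorCode(Enum):
    # Hypothetical mirror of the codes used in the diff
    INVALID_INPUT = "INVALID_INPUT"
    BAD_GATEWAY = "BAD_GATEWAY"


class OnyxError(Exception):
    def __init__(self, code: OnyxErrorCode, detail: str) -> None:
        super().__init__(detail)
        self.code = code
        self.detail = detail


def to_response(err: OnyxError) -> dict:
    # Shape matches the payload that _extract_error_detail parses on the client
    return {"error_code": err.code.value, "detail": err.detail}


try:
    raise OnyxError(
        OnyxErrorCode.INVALID_INPUT, "No web search provider configured."
    )
except OnyxError as e:
    payload = to_response(e)
print(payload)
```

Keeping the status mapping (400, 502, ...) in one exception handler instead of at every raise site is the main point of the change.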
@@ -27,6 +27,7 @@ from onyx.auth.email_utils import send_user_email_invite
 from onyx.auth.invited_users import get_invited_users
 from onyx.auth.invited_users import remove_user_from_invited_users
 from onyx.auth.invited_users import write_invited_users
+from onyx.auth.permissions import get_effective_permissions
 from onyx.auth.schemas import UserRole
 from onyx.auth.users import anonymous_user_enabled
 from onyx.auth.users import current_admin_user

@@ -50,6 +51,7 @@ from onyx.configs.constants import PUBLIC_API_TAGS
 from onyx.db.api_key import is_api_key_email_address
 from onyx.db.auth import get_live_users_count
 from onyx.db.engine.sql_engine import get_session
+from onyx.db.enums import AccountType
 from onyx.db.enums import UserFileStatus
 from onyx.db.models import User
 from onyx.db.models import UserFile

@@ -142,6 +144,7 @@ def set_user_role(
     validate_user_role_update(
         requested_role=requested_role,
         current_role=current_role,
+        current_account_type=user_to_update.account_type,
         explicit_override=user_role_update_request.explicit_override,
     )

@@ -327,8 +330,8 @@ def list_all_users(
         if (include_api_keys or not is_api_key_email_address(user.email))
     ]

-    slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
-    accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
+    slack_users = [user for user in users if user.account_type == AccountType.BOT]
+    accepted_users = [user for user in users if user.account_type != AccountType.BOT]

     accepted_emails = {user.email for user in accepted_users}
     slack_users_emails = {user.email for user in slack_users}

@@ -671,7 +674,7 @@ def list_all_users_basic_info(
     return [
         MinimalUserSnapshot(id=user.id, email=user.email)
         for user in users
-        if user.role != UserRole.SLACK_USER
+        if user.account_type != AccountType.BOT
         and (include_api_keys or not is_api_key_email_address(user.email))
     ]

@@ -774,6 +777,13 @@ def _get_token_created_at(
     return get_current_token_creation_postgres(user, db_session)


+@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
+def get_current_user_permissions(
+    user: User = Depends(current_user),
+) -> list[str]:
+    return sorted(p.value for p in get_effective_permissions(user))
+
+
 @router.get("/me", tags=PUBLIC_API_TAGS)
 def verify_user_logged_in(
     request: Request,

@@ -7,6 +7,7 @@ from uuid import UUID
 from pydantic import BaseModel

 from onyx.auth.schemas import UserRole
+from onyx.db.enums import AccountType
 from onyx.db.models import User


@@ -41,6 +42,7 @@ class FullUserSnapshot(BaseModel):
     id: UUID
     email: str
     role: UserRole
+    account_type: AccountType
     is_active: bool
     password_configured: bool
     personal_name: str | None

@@ -60,6 +62,7 @@ class FullUserSnapshot(BaseModel):
         id=user.id,
         email=user.email,
         role=user.role,
+        account_type=user.account_type,
         is_active=user.is_active,
         password_configured=user.password_configured,
         personal_name=user.personal_name,

@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
     try:
         user = await user_manager.get_by_email(email)
         # If user has a non-authenticated role, treat as non-existent
-        if not user.role.is_web_login():
+        if not user.account_type.is_web_login():
             raise exceptions.UserNotExists()
         return user
     except exceptions.UserNotExists:
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session

 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.db.engine.sql_engine import SqlEngine
+from onyx.db.enums import AccountType
 from onyx.db.models import User
 from onyx.db.models import UserRole
 from onyx.file_store.file_store import get_default_file_store

@@ -52,7 +53,12 @@ def tenant_context() -> Generator[None, None, None]:
         CURRENT_TENANT_ID_CONTEXTVAR.reset(token)


-def create_test_user(db_session: Session, email_prefix: str) -> User:
+def create_test_user(
+    db_session: Session,
+    email_prefix: str,
+    role: UserRole = UserRole.BASIC,
+    account_type: AccountType = AccountType.STANDARD,
+) -> User:
     """Helper to create a test user with a unique email"""
     # Use UUID to ensure unique email addresses
     unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"

@@ -68,7 +74,8 @@ def create_test_user(db_session: Session, email_prefix: str) -> User:
         is_active=True,
         is_superuser=False,
         is_verified=True,
-        role=UserRole.EXT_PERM_USER,
+        role=role,
+        account_type=account_type,
     )
     db_session.add(user)
     db_session.commit()

@@ -13,16 +13,29 @@ from onyx.access.utils import build_ext_group_name_for_onyx
 from onyx.configs.constants import DocumentSource
 from onyx.connectors.models import InputType
 from onyx.db.enums import AccessType
+from onyx.db.enums import AccountType
 from onyx.db.enums import ConnectorCredentialPairStatus
 from onyx.db.models import Connector
 from onyx.db.models import ConnectorCredentialPair
 from onyx.db.models import Credential
 from onyx.db.models import PublicExternalUserGroup
 from onyx.db.models import User
 from onyx.db.models import User__ExternalUserGroupId
+from onyx.db.models import UserRole
 from tests.external_dependency_unit.conftest import create_test_user
 from tests.external_dependency_unit.constants import TEST_TENANT_ID


+def _create_ext_perm_user(db_session: Session, name: str) -> User:
+    """Create an external-permission user for group sync tests."""
+    return create_test_user(
+        db_session,
+        name,
+        role=UserRole.EXT_PERM_USER,
+        account_type=AccountType.EXT_PERM_USER,
+    )
+
+
 def _create_test_connector_credential_pair(
     db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
 ) -> ConnectorCredentialPair:

@@ -100,9 +113,9 @@ class TestPerformExternalGroupSync:
     def test_initial_group_sync(self, db_session: Session) -> None:
         """Test syncing external groups for the first time (initial sync)"""
         # Create test data
-        user1 = create_test_user(db_session, "user1")
-        user2 = create_test_user(db_session, "user2")
-        user3 = create_test_user(db_session, "user3")
+        user1 = _create_ext_perm_user(db_session, "user1")
+        user2 = _create_ext_perm_user(db_session, "user2")
+        user3 = _create_ext_perm_user(db_session, "user3")
         cc_pair = _create_test_connector_credential_pair(db_session)

         # Mock external groups data as a generator that yields the expected groups

@@ -175,9 +188,9 @@ class TestPerformExternalGroupSync:
     def test_update_existing_groups(self, db_session: Session) -> None:
         """Test updating existing groups (adding/removing users)"""
         # Create test data
-        user1 = create_test_user(db_session, "user1")
-        user2 = create_test_user(db_session, "user2")
-        user3 = create_test_user(db_session, "user3")
+        user1 = _create_ext_perm_user(db_session, "user1")
+        user2 = _create_ext_perm_user(db_session, "user2")
+        user3 = _create_ext_perm_user(db_session, "user3")
         cc_pair = _create_test_connector_credential_pair(db_session)

         # Initial sync with original groups

@@ -272,8 +285,8 @@ class TestPerformExternalGroupSync:
     def test_remove_groups(self, db_session: Session) -> None:
         """Test removing groups (groups that no longer exist in external system)"""
         # Create test data
-        user1 = create_test_user(db_session, "user1")
-        user2 = create_test_user(db_session, "user2")
+        user1 = _create_ext_perm_user(db_session, "user1")
+        user2 = _create_ext_perm_user(db_session, "user2")
         cc_pair = _create_test_connector_credential_pair(db_session)

         # Initial sync with multiple groups

@@ -357,7 +370,7 @@ class TestPerformExternalGroupSync:
     def test_empty_group_sync(self, db_session: Session) -> None:
         """Test syncing when no groups are returned (all groups removed)"""
         # Create test data
-        user1 = create_test_user(db_session, "user1")
+        user1 = _create_ext_perm_user(db_session, "user1")
         cc_pair = _create_test_connector_credential_pair(db_session)

         # Initial sync with groups

@@ -413,7 +426,7 @@ class TestPerformExternalGroupSync:
         # Create many test users
         users = []
         for i in range(150):  # More than the batch size of 100
-            users.append(create_test_user(db_session, f"user{i}"))
+            users.append(_create_ext_perm_user(db_session, f"user{i}"))

         cc_pair = _create_test_connector_credential_pair(db_session)

@@ -452,8 +465,8 @@ class TestPerformExternalGroupSync:
     def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
         """Test syncing a mix of regular and public groups"""
         # Create test data
-        user1 = create_test_user(db_session, "user1")
-        user2 = create_test_user(db_session, "user2")
+        user1 = _create_ext_perm_user(db_session, "user1")
+        user2 = _create_ext_perm_user(db_session, "user2")
         cc_pair = _create_test_connector_credential_pair(db_session)

         def mixed_group_sync_func(

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session

 from onyx.db.engine.sql_engine import get_session_with_current_tenant
 from onyx.db.engine.sql_engine import SqlEngine
+from onyx.db.enums import AccountType
 from onyx.db.enums import BuildSessionStatus
 from onyx.db.models import BuildSession
 from onyx.db.models import User

@@ -52,6 +53,7 @@ def test_user(db_session: Session, tenant_context: None) -> User:  # noqa: ARG00
         is_superuser=False,
         is_verified=True,
         role=UserRole.EXT_PERM_USER,
+        account_type=AccountType.EXT_PERM_USER,
     )
     db_session.add(user)
     db_session.commit()

@@ -0,0 +1,51 @@
+"""
+Tests that account_type is correctly set when creating users through
+the internal DB functions: add_slack_user_if_not_exists and
+batch_add_ext_perm_user_if_not_exists.
+
+These functions are called by background workers (Slack bot, permission sync)
+and are not exposed via API endpoints, so they must be tested directly.
+"""
+
+from sqlalchemy.orm import Session
+
+from onyx.db.enums import AccountType
+from onyx.db.models import UserRole
+from onyx.db.users import add_slack_user_if_not_exists
+from onyx.db.users import batch_add_ext_perm_user_if_not_exists
+
+
+def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
+    """add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
+    user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
+
+    assert user.role == UserRole.SLACK_USER
+    assert user.account_type == AccountType.BOT
+
+
+def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
+    """batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
+    users = batch_add_ext_perm_user_if_not_exists(
+        db_session, ["extperm_acct_type@test.com"]
+    )
+
+    assert len(users) == 1
+    user = users[0]
+    assert user.role == UserRole.EXT_PERM_USER
+    assert user.account_type == AccountType.EXT_PERM_USER
+
+
+def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
+    db_session: Session,
+) -> None:
+    """When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
+    email = "ext_to_slack_acct_type@test.com"
+
+    # Create as ext_perm user first
+    batch_add_ext_perm_user_if_not_exists(db_session, [email])
+
+    # Now "upgrade" via slack path
+    user = add_slack_user_if_not_exists(db_session, email)
+
+    assert user.role == UserRole.SLACK_USER
+    assert user.account_type == AccountType.BOT
@@ -8,6 +8,7 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -46,6 +47,7 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -126,6 +126,15 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,13 +104,30 @@ class UserGroupManager:
        )
        response.raise_for_status()

    @staticmethod
    def get_permissions(
        user_group: DATestUserGroup,
        user_performing_action: DATestUser,
    ) -> list[str]:
        response = requests.get(
            f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
            headers=user_performing_action.headers,
        )
        response.raise_for_status()
        return response.json()

    @staticmethod
    def get_all(
        user_performing_action: DATestUser,
        include_default: bool = False,
    ) -> list[UserGroup]:
        params: dict[str, str] = {}
        if include_default:
            params["include_default"] = "true"
        response = requests.get(
            f"{API_SERVER_URL}/manage/admin/user-group",
            headers=user_performing_action.headers,
            params=params,
        )
        response.raise_for_status()
        return [UserGroup(**ug) for ug in response.json()]

@@ -1,9 +1,13 @@
from uuid import UUID

import requests

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.api_key import APIKeyManager
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestAPIKey
from tests.integration.common_utils.test_models import DATestUser

@@ -33,3 +37,120 @@ def test_limited(reset: None) -> None: # noqa: ARG001
        headers=api_key.headers,
    )
    assert response.status_code == 403


def _get_service_account_account_type(
    admin_user: DATestUser,
    api_key_user_id: UUID,
) -> AccountType:
    """Fetch the account_type of a service account user via the user listing API."""
    response = requests.get(
        f"{API_SERVER_URL}/manage/users",
        headers=admin_user.headers,
        params={"include_api_keys": "true"},
    )
    response.raise_for_status()
    data = response.json()
    user_id_str = str(api_key_user_id)
    for user in data["accepted"]:
        if user["id"] == user_id_str:
            return AccountType(user["account_type"])
    raise AssertionError(
        f"Service account user {user_id_str} not found in user listing"
    )


def _get_default_group_user_ids(
    admin_user: DATestUser,
) -> tuple[set[str], set[str]]:
    """Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
    all_groups = UserGroupManager.get_all(
        user_performing_action=admin_user,
        include_default=True,
    )
    admin_group = next(
        (g for g in all_groups if g.name == "Admin" and g.is_default), None
    )
    basic_group = next(
        (g for g in all_groups if g.name == "Basic" and g.is_default), None
    )
    assert admin_group is not None, "Admin default group not found"
    assert basic_group is not None, "Basic default group not found"

    admin_ids = {str(u.id) for u in admin_group.users}
    basic_ids = {str(u.id) for u in basic_group.users}
    return admin_ids, basic_ids


def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
    """LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
    admin_user: DATestUser = UserManager.create(name="admin_user")

    api_key: DATestAPIKey = APIKeyManager.create(
        api_key_role=UserRole.LIMITED,
        user_performing_action=admin_user,
    )

    # Verify account_type
    account_type = _get_service_account_account_type(admin_user, api_key.user_id)
    assert (
        account_type == AccountType.SERVICE_ACCOUNT
    ), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"

    # Verify no group membership
    admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
    user_id_str = str(api_key.user_id)
    assert (
        user_id_str not in admin_ids
    ), "LIMITED API key should NOT be in Admin default group"
    assert (
        user_id_str not in basic_ids
    ), "LIMITED API key should NOT be in Basic default group"


def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
    """BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
    admin_user: DATestUser = UserManager.create(name="admin_user")

    api_key: DATestAPIKey = APIKeyManager.create(
        api_key_role=UserRole.BASIC,
        user_performing_action=admin_user,
    )

    # Verify account_type
    account_type = _get_service_account_account_type(admin_user, api_key.user_id)
    assert (
        account_type == AccountType.SERVICE_ACCOUNT
    ), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"

    # Verify Basic group membership
    admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
    user_id_str = str(api_key.user_id)
    assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
    assert (
        user_id_str not in admin_ids
    ), "BASIC API key should NOT be in Admin default group"


def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
    """ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
    admin_user: DATestUser = UserManager.create(name="admin_user")

    api_key: DATestAPIKey = APIKeyManager.create(
        api_key_role=UserRole.ADMIN,
        user_performing_action=admin_user,
    )

    # Verify account_type
    account_type = _get_service_account_account_type(admin_user, api_key.user_id)
    assert (
        account_type == AccountType.SERVICE_ACCOUNT
    ), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"

    # Verify Admin group membership
    admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
    user_id_str = str(api_key.user_id)
    assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
    assert (
        user_id_str not in basic_ids
    ), "ADMIN API key should NOT be in Basic default group"

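The API-key tests above all assert the same simple mapping from key role to default-group membership. A minimal, self-contained sketch of that mapping — hypothetical helper names, not Onyx's actual implementation:

```python
from enum import Enum


class Role(Enum):
    LIMITED = "limited"
    BASIC = "basic"
    ADMIN = "admin"


def default_groups_for_api_key(role: Role) -> set[str]:
    # Model of the behavior under test: LIMITED keys join no default
    # group, BASIC keys join only "Basic", ADMIN keys join only "Admin".
    if role is Role.ADMIN:
        return {"Admin"}
    if role is Role.BASIC:
        return {"Basic"}
    return set()  # LIMITED: service account with no group membership


# Mirrors the three test cases above
assert default_groups_for_api_key(Role.LIMITED) == set()
assert default_groups_for_api_key(Role.BASIC) == {"Basic"}
assert default_groups_for_api_key(Role.ADMIN) == {"Admin"}
```

Every role still yields `account_type=SERVICE_ACCOUNT`; only the group set varies.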
@@ -4,11 +4,32 @@ import pytest
import requests

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


def _simulate_saml_login(email: str, admin_user: DATestUser) -> dict:
    """Simulate a SAML login by calling the test upsert endpoint."""
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": email},
        headers=admin_user.headers,
    )
    response.raise_for_status()
    return response.json()


def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
    """Get the set of emails of all members in the Basic default group."""
    all_groups = UserGroupManager.get_all(admin_user, include_default=True)
    basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
    assert basic_default, "Basic default group not found"
    return {u.email for u in basic_default[0].users}


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
@@ -49,15 +70,9 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
    assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)

    # Simulate SAML login by calling the test endpoint
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": test_user_email},
        headers=admin_user.headers, # Use admin headers for authorization
    )
    response.raise_for_status()
    user_data = _simulate_saml_login(test_user_email, admin_user)

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify user role was changed in the database
@@ -82,16 +97,237 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
    assert UserManager.is_role(slack_user, UserRole.SLACK_USER)

    # Simulate SAML login again
    response = requests.post(
        f"{API_SERVER_URL}/manage/users/test-upsert-user",
        json={"email": slack_user_email},
        headers=admin_user.headers,
    )
    response.raise_for_status()
    user_data = _simulate_saml_login(slack_user_email, admin_user)

    # Verify the response indicates the role changed to BASIC
    user_data = response.json()
    assert user_data["role"] == UserRole.BASIC.value

    # Verify the user's role was changed in the database
    assert UserManager.is_role(slack_user, UserRole.BASIC)


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_sets_account_type_and_group(
    reset: None, # noqa: ARG001
) -> None:
    """
    Test that SAML login sets account_type to STANDARD when converting a
    non-web user (EXT_PERM_USER) and that the user receives the correct role
    (BASIC) after conversion.

    This validates the permissions-migration-phase2 changes which ensure that:
    1. account_type is updated to 'standard' on SAML conversion
    2. The converted user is assigned to the Basic default group
    """
    # Create an admin user (first user is automatically admin)
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    # Create a user and set them as EXT_PERM_USER
    test_email = "ext_convert@example.com"
    test_user = UserManager.create(email=test_email)
    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)

    # Simulate SAML login
    user_data = _simulate_saml_login(test_email, admin_user)

    # Verify account_type is set to standard after conversion
    assert (
        user_data["account_type"] == AccountType.STANDARD.value
    ), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"

    # Verify role is BASIC after conversion
    assert user_data["role"] == UserRole.BASIC.value

    # Verify the user was assigned to the Basic default group
    assert test_email in _get_basic_group_member_emails(
        admin_user
    ), f"Converted user '{test_email}' not found in Basic default group"


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_normal_signin_assigns_group(
    reset: None, # noqa: ARG001
) -> None:
    """
    Test that a brand-new user signing in via SAML for the first time
    is created with the correct role, account_type, and group membership.

    This validates that normal SAML sign-in (not an upgrade from
    SLACK_USER/EXT_PERM_USER) correctly:
    1. Creates the user with role=BASIC and account_type=STANDARD
    2. Assigns the user to the Basic default group
    """
    # First user becomes admin
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    # New user signs in via SAML (no prior account)
    new_email = "new_saml_user@example.com"
    user_data = _simulate_saml_login(new_email, admin_user)

    # Verify role and account_type
    assert user_data["role"] == UserRole.BASIC.value
    assert user_data["account_type"] == AccountType.STANDARD.value

    # Verify user is in the Basic default group
    assert new_email in _get_basic_group_member_emails(
        admin_user
    ), f"New SAML user '{new_email}' not found in Basic default group"


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_user_conversion_restores_group_membership(
    reset: None, # noqa: ARG001
) -> None:
    """
    Test that SAML login restores Basic group membership when converting
    a non-authenticated user (EXT_PERM_USER or SLACK_USER) to BASIC.

    Group membership implies 'basic' permission (verified by
    test_new_group_gets_basic_permission).
    """
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    # --- EXT_PERM_USER path ---
    ext_email = "ext_perm_perms@example.com"
    ext_user = UserManager.create(email=ext_email)
    assert ext_email in _get_basic_group_member_emails(admin_user)

    UserManager.set_role(
        user_to_set=ext_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert ext_email not in _get_basic_group_member_emails(admin_user)

    user_data = _simulate_saml_login(ext_email, admin_user)
    assert user_data["role"] == UserRole.BASIC.value
    assert ext_email in _get_basic_group_member_emails(
        admin_user
    ), "EXT_PERM_USER should be back in Basic group after SAML conversion"

    # --- SLACK_USER path ---
    slack_email = "slack_perms@example.com"
    slack_user = UserManager.create(email=slack_email)

    UserManager.set_role(
        user_to_set=slack_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert slack_email not in _get_basic_group_member_emails(admin_user)

    user_data = _simulate_saml_login(slack_email, admin_user)
    assert user_data["role"] == UserRole.BASIC.value
    assert slack_email in _get_basic_group_member_emails(
        admin_user
    ), "SLACK_USER should be back in Basic group after SAML conversion"


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_round_trip_group_lifecycle(
    reset: None, # noqa: ARG001
) -> None:
    """
    Test the full round-trip: BASIC -> EXT_PERM -> SAML(BASIC) -> EXT_PERM -> SAML(BASIC).

    Verifies group membership is correctly removed and restored at each transition.
    """
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    test_email = "roundtrip@example.com"
    test_user = UserManager.create(email=test_email)

    # Step 1: BASIC user is in Basic group
    assert test_email in _get_basic_group_member_emails(admin_user)

    # Step 2: Downgrade to EXT_PERM_USER — loses Basic group
    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert test_email not in _get_basic_group_member_emails(admin_user)

    # Step 3: SAML login — converts back to BASIC, regains Basic group
    _simulate_saml_login(test_email, admin_user)
    assert test_email in _get_basic_group_member_emails(
        admin_user
    ), "Should be in Basic group after first SAML conversion"

    # Step 4: Downgrade again
    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert test_email not in _get_basic_group_member_emails(admin_user)

    # Step 5: SAML login again — should still restore correctly
    _simulate_saml_login(test_email, admin_user)
    assert test_email in _get_basic_group_member_emails(
        admin_user
    ), "Should be in Basic group after second SAML conversion"


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="SAML tests are enterprise only",
)
def test_saml_slack_user_conversion_sets_account_type_and_group(
    reset: None, # noqa: ARG001
) -> None:
    """
    Test that SAML login sets account_type to STANDARD and assigns Basic group
    when converting a SLACK_USER (BOT account_type).

    Mirrors test_saml_user_conversion_sets_account_type_and_group but for
    SLACK_USER instead of EXT_PERM_USER, and additionally verifies permissions.
    """
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    test_email = "slack_convert@example.com"
    test_user = UserManager.create(email=test_email)

    UserManager.set_role(
        user_to_set=test_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert UserManager.is_role(test_user, UserRole.SLACK_USER)

    # SAML login
    user_data = _simulate_saml_login(test_email, admin_user)

    # Verify account_type and role
    assert (
        user_data["account_type"] == AccountType.STANDARD.value
    ), f"Expected STANDARD, got {user_data['account_type']}"
    assert user_data["role"] == UserRole.BASIC.value

    # Verify Basic group membership (implies 'basic' permission)
    assert test_email in _get_basic_group_member_emails(
        admin_user
    ), f"Converted SLACK_USER '{test_email}' not found in Basic default group"

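The SAML tests above all exercise one conversion rule: a SAML sign-in creates or upgrades the user to role BASIC, sets account_type to STANDARD, and (re)adds Basic group membership. A self-contained sketch of that rule — an illustrative model with hypothetical names, not Onyx's upsert code:

```python
from dataclasses import dataclass


@dataclass
class ModelUser:
    email: str
    role: str          # e.g. "basic", "ext_perm_user", "slack_user"
    account_type: str  # e.g. "standard", "bot"
    groups: set


def saml_login(users: dict, email: str) -> ModelUser:
    # Model of the conversion under test: SAML sign-in creates a new user,
    # or upgrades an existing non-web user, to role=basic with
    # account_type=standard and Basic default-group membership.
    user = users.get(email) or ModelUser(email, "basic", "standard", set())
    user.role = "basic"
    user.account_type = "standard"
    user.groups.add("Basic")  # membership restored on every conversion
    users[email] = user
    return user


# Upgrade path: an EXT_PERM_USER (BOT) converts on SAML login
users: dict = {}
users["ext@example.com"] = ModelUser("ext@example.com", "ext_perm_user", "bot", set())
converted = saml_login(users, "ext@example.com")
assert converted.role == "basic"
assert converted.account_type == "standard"
assert "Basic" in converted.groups
```

The round-trip test then just repeats downgrade (drop the group) and `saml_login` (restore it).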
@@ -0,0 +1,82 @@
"""Integration tests for permission propagation across auth-triggered group changes.

These tests verify that effective permissions (via /me/permissions) actually
propagate when users are added/removed from default groups through role changes.
Custom permission grant tests will be added once the permission grant API is built.
"""

import os

import pytest

from onyx.auth.schemas import UserRole
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
    all_groups = UserGroupManager.get_all(admin_user, include_default=True)
    basic_group = next(
        (g for g in all_groups if g.is_default and g.name == "Basic"), None
    )
    assert basic_group is not None, "Basic default group not found"
    return {u.email for u in basic_group.users}


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission propagation tests require enterprise features",
)
def test_basic_permission_granted_on_registration(
    reset: None, # noqa: ARG001
) -> None:
    """New users should get 'basic' permission through default group assignment."""
    admin_user: DATestUser = UserManager.create(email="admin@example.com")
    basic_user: DATestUser = UserManager.create(email="basic@example.com")

    # Admin should have permissions from Admin group
    admin_perms = UserManager.get_permissions(admin_user)
    assert "basic" in admin_perms

    # Basic user should have 'basic' from Basic default group
    basic_perms = UserManager.get_permissions(basic_user)
    assert "basic" in basic_perms

    # Verify group membership matches
    assert basic_user.email in _get_basic_group_member_emails(admin_user)


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="Permission propagation tests require enterprise features",
)
def test_role_downgrade_removes_basic_group_and_permission(
    reset: None, # noqa: ARG001
) -> None:
    """Downgrading to EXT_PERM_USER or SLACK_USER should remove from Basic group."""
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    # --- EXT_PERM_USER ---
    ext_user: DATestUser = UserManager.create(email="ext@example.com")
    assert ext_user.email in _get_basic_group_member_emails(admin_user)

    UserManager.set_role(
        user_to_set=ext_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert ext_user.email not in _get_basic_group_member_emails(admin_user)

    # --- SLACK_USER ---
    slack_user: DATestUser = UserManager.create(email="slack@example.com")
    assert slack_user.email in _get_basic_group_member_emails(admin_user)

    UserManager.set_role(
        user_to_set=slack_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )
    assert slack_user.email not in _get_basic_group_member_emails(admin_user)
@@ -21,8 +21,15 @@ import pytest
import requests

from onyx.auth.schemas import UserRole
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import API_SERVER_URL
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.managers.scim_client import ScimClient
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
@@ -44,13 +51,6 @@ def scim_token(idp_style: str) -> str:
    per IdP-style run and reuse. Uses UserManager directly to avoid
    fixture-scope conflicts with the function-scoped admin_user fixture.
    """
    from tests.integration.common_utils.constants import ADMIN_USER_NAME
    from tests.integration.common_utils.constants import GENERAL_HEADERS
    from tests.integration.common_utils.managers.user import build_email
    from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
    from tests.integration.common_utils.managers.user import UserManager
    from tests.integration.common_utils.test_models import DATestUser

    try:
        admin = UserManager.create(name=ADMIN_USER_NAME)
    except Exception:
@@ -550,3 +550,145 @@ def test_patch_add_duplicate_member_is_idempotent(
    )
    assert resp.status_code == 200
    assert len(resp.json()["members"]) == 1 # still just one member


def test_create_group_reserved_name_admin(scim_token: str) -> None:
    """POST /Groups with reserved name 'Admin' returns 409."""
    resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin")
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_create_group_reserved_name_basic(scim_token: str) -> None:
    """POST /Groups with reserved name 'Basic' returns 409."""
    resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic")
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_replace_group_cannot_rename_to_reserved(
    scim_token: str, idp_style: str
) -> None:
    """PUT /Groups/{id} renaming a group to 'Admin' returns 409."""
    created = _create_scim_group(
        scim_token,
        f"Rename To Reserved {idp_style}",
        external_id=f"ext-rtr-{idp_style}",
    ).json()

    resp = ScimClient.put(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_group_resource(
            display_name="Admin", external_id=f"ext-rtr-{idp_style}"
        ),
    )
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} renaming a group to 'Basic' returns 409."""
    created = _create_scim_group(
        scim_token,
        f"Patch Rename Reserved {idp_style}",
        external_id=f"ext-prr-{idp_style}",
    ).json()

    resp = ScimClient.patch(
        f"/Groups/{created['id']}",
        scim_token,
        json=_make_patch_request(
            [{"op": "replace", "path": "displayName", "value": "Basic"}],
            idp_style,
        ),
    )
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_delete_reserved_group_rejected(scim_token: str) -> None:
    """DELETE /Groups/{id} on a reserved group ('Admin') returns 409."""
    # Look up the reserved 'Admin' group via SCIM filter
    resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
    assert resp.status_code == 200
    resources = resp.json()["Resources"]
    assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
    admin_group_id = resources[0]["id"]

    resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token)
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_scim_created_group_has_basic_permission(
    scim_token: str, idp_style: str
) -> None:
    """POST /Groups assigns the 'basic' permission to the group itself."""
    # Create a SCIM group (no members needed — we check the group's permissions)
    resp = _create_scim_group(
        scim_token,
        f"Basic Perm Group {idp_style}",
        external_id=f"ext-basic-perm-{idp_style}",
    )
    assert resp.status_code == 201
    group_id = resp.json()["id"]

    # Log in as the admin user (created by the scim_token fixture).
    admin = DATestUser(
        id="",
        email=build_email(ADMIN_USER_NAME),
        password=DEFAULT_PASSWORD,
        headers=GENERAL_HEADERS,
        role=UserRole.ADMIN,
        is_active=True,
    )
    admin = UserManager.login_as_user(admin)

    # Verify the group itself was granted the basic permission
    perms_resp = requests.get(
        f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions",
        headers=admin.headers,
    )
    perms_resp.raise_for_status()
    perms = perms_resp.json()
    assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}"


def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None:
    """PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409."""
    resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
    assert resp.status_code == 200
    resources = resp.json()["Resources"]
    assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
    admin_group_id = resources[0]["id"]

    resp = ScimClient.put(
        f"/Groups/{admin_group_id}",
        scim_token,
        json=_make_group_resource(
            display_name="RenamedAdmin", external_id="ext-rename-from-reserved"
        ),
    )
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()


def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None:
    """PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409."""
    resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
    assert resp.status_code == 200
    resources = resp.json()["Resources"]
    assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
    admin_group_id = resources[0]["id"]

    resp = ScimClient.patch(
        f"/Groups/{admin_group_id}",
        scim_token,
        json=_make_patch_request(
            [{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}],
            idp_style,
        ),
    )
    assert resp.status_code == 409
    assert "reserved" in resp.json()["detail"].lower()

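The SCIM group tests above all reduce to one guard: any create, rename (to or from), or delete that touches a reserved default-group name is rejected with HTTP 409. A minimal sketch of that guard as a pure function — an illustrative model, not the server's handler:

```python
RESERVED_GROUP_NAMES = {"Admin", "Basic"}


def scim_group_mutation_status(current_name, new_name):
    # Model of the guard under test. current_name is None for a create,
    # new_name is None for a delete; both set means a rename.
    if current_name in RESERVED_GROUP_NAMES:
        return 409  # rename-from or delete of a reserved group
    if new_name in RESERVED_GROUP_NAMES:
        return 409  # create with, or rename-to, a reserved name
    return 200


assert scim_group_mutation_status(None, "Admin") == 409        # create
assert scim_group_mutation_status("My Group", "Basic") == 409  # rename-to
assert scim_group_mutation_status("Admin", None) == 409        # delete
assert scim_group_mutation_status("My Group", "Renamed") == 200
```

Both directions are blocked, which is why the tests cover rename-to and rename-from separately for PUT and PATCH.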
@@ -35,9 +35,16 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.test_models import DATestUser


SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
@@ -211,6 +218,49 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
    _assert_entra_emails(body, email)


def test_create_user_default_group_and_account_type(
    scim_token: str, idp_style: str
) -> None:
    """SCIM-provisioned users get Basic default group and STANDARD account_type."""
    email = f"scim_defaults_{idp_style}@example.com"
    ext_id = f"ext-defaults-{idp_style}"
    resp = _create_scim_user(scim_token, email, ext_id, idp_style)
    assert resp.status_code == 201
    user_id = resp.json()["id"]

    # --- Verify group assignment via SCIM GET ---
    get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
    assert get_resp.status_code == 200
    groups = get_resp.json().get("groups", [])
    group_names = {g["display"] for g in groups}
    assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
    assert "Admin" not in group_names, "SCIM user should not be in Admin group"

    # --- Verify account_type via admin API ---
    admin = UserManager.login_as_user(
        DATestUser(
            id="",
            email=build_email(ADMIN_USER_NAME),
            password=DEFAULT_PASSWORD,
            headers=GENERAL_HEADERS,
            role=UserRole.ADMIN,
            is_active=True,
        )
    )
    page = UserManager.get_user_page(
        user_performing_action=admin,
        search_query=email,
    )
    assert page.total_items >= 1
    scim_user_snapshot = next((u for u in page.items if u.email == email), None)
    assert (
        scim_user_snapshot is not None
    ), f"SCIM user {email} not found in user listing"
    assert (
        scim_user_snapshot.account_type == AccountType.STANDARD
    ), f"Expected STANDARD, got {scim_user_snapshot.account_type}"


def test_get_user(scim_token: str, idp_style: str) -> None:
    """GET /Users/{id} returns the user resource with all stored fields."""
    email = f"scim_get_{idp_style}@example.com"
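The hunk above drives `_create_scim_user` without showing its request body. A minimal SCIM 2.0 User payload (RFC 7643 core schema, matching the `SCIM_USER_SCHEMA` constant above) that such a helper might POST to `/Users` could look like this; the exact field choices are assumptions, not the helper's actual code:

```python
def build_scim_user_payload(email: str, external_id: str) -> dict:
    """Minimal SCIM 2.0 User resource (RFC 7643 core schema) for POST /Users."""
    return {
        "schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
        "userName": email,
        "externalId": external_id,
        "emails": [{"value": email, "primary": True}],
        "active": True,
    }


payload = build_scim_user_payload(
    "scim_defaults_okta@example.com", "ext-defaults-okta"
)
```

An IdP would POST this JSON to the SCIM `/Users` endpoint, authenticated with the bearer token held in `scim_token`.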
@@ -0,0 +1,118 @@
import os

import pytest

from onyx.db.engine.sql_engine import get_session_with_current_tenant
from onyx.db.enums import Permission
from onyx.db.models import PermissionGrant
from onyx.db.models import UserGroup as UserGroupModel
from onyx.db.permissions import recompute_permissions_for_group__no_commit
from onyx.db.permissions import recompute_user_permissions__no_commit
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="User group tests are enterprise only",
)
def test_user_gets_permissions_when_added_to_group(
    reset: None,  # noqa: ARG001
) -> None:
    admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
    basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")

    # basic_user starts with only "basic" from the default group
    initial_permissions = UserManager.get_permissions(basic_user)
    assert "basic" in initial_permissions
    assert "add:agents" not in initial_permissions

    # Create a new group and add basic_user
    group = UserGroupManager.create(
        name="perm-test-group",
        user_ids=[admin_user.id, basic_user.id],
        user_performing_action=admin_user,
    )

    # Grant a non-basic permission to the group and recompute
    with get_session_with_current_tenant() as db_session:
        db_group = db_session.get(UserGroupModel, group.id)
        assert db_group is not None
        db_session.add(
            PermissionGrant(
                group_id=db_group.id,
                permission=Permission.ADD_AGENTS,
                grant_source="SYSTEM",
            )
        )
        db_session.flush()
        recompute_user_permissions__no_commit(basic_user.id, db_session)
        db_session.commit()

    # Verify the user gained the new permission (expanded includes read:agents)
    updated_permissions = UserManager.get_permissions(basic_user)
    assert (
        "add:agents" in updated_permissions
    ), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
    assert (
        "read:agents" in updated_permissions
    ), f"User should have implied 'read:agents', got: {updated_permissions}"
    assert "basic" in updated_permissions


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="User group tests are enterprise only",
)
def test_group_permission_change_propagates_to_all_members(
    reset: None,  # noqa: ARG001
) -> None:
    admin_user: DATestUser = UserManager.create(name="admin_propagate")
    user_a: DATestUser = UserManager.create(name="user_a_propagate")
    user_b: DATestUser = UserManager.create(name="user_b_propagate")

    group = UserGroupManager.create(
        name="propagate-test-group",
        user_ids=[admin_user.id, user_a.id, user_b.id],
        user_performing_action=admin_user,
    )

    # Neither user should have add:agents yet
    for u in (user_a, user_b):
        assert "add:agents" not in UserManager.get_permissions(u)

    # Grant add:agents to the group, then batch-recompute
    with get_session_with_current_tenant() as db_session:
        grant = PermissionGrant(
            group_id=group.id,
            permission=Permission.ADD_AGENTS,
            grant_source="SYSTEM",
        )
        db_session.add(grant)
        db_session.flush()
        recompute_permissions_for_group__no_commit(group.id, db_session)
        db_session.commit()

    # Both users should now have the permission (plus implied read:agents)
    for u in (user_a, user_b):
        perms = UserManager.get_permissions(u)
        assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
        assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"

    # Soft-delete the grant and recompute — permission should be removed
    with get_session_with_current_tenant() as db_session:
        db_grant = (
            db_session.query(PermissionGrant)
            .filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
            .first()
        )
        assert db_grant is not None
        db_grant.is_deleted = True
        db_session.flush()
        recompute_permissions_for_group__no_commit(group.id, db_session)
        db_session.commit()

    for u in (user_a, user_b):
        perms = UserManager.get_permissions(u)
        assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
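The tests above exercise `recompute_user_permissions__no_commit` and `recompute_permissions_for_group__no_commit`, whose bodies are outside this diff. A hedged, in-memory sketch of the recompute rule the assertions imply: a user's stored permissions are the union of all non-soft-deleted grants across their groups, with implication expansion left to read time. All names here are illustrative, not the real implementation:

```python
from dataclasses import dataclass, field


@dataclass
class Grant:
    permission: str
    is_deleted: bool = False  # soft-delete flag, mirrors PermissionGrant


@dataclass
class Group:
    grants: list[Grant] = field(default_factory=list)


def recompute_user_permissions(groups: list[Group]) -> set[str]:
    """Union of all active (non-soft-deleted) grants across the user's groups."""
    return {
        grant.permission
        for group in groups
        for grant in group.grants
        if not grant.is_deleted
    }


basic_group = Group(grants=[Grant("basic")])
perm_group = Group(grants=[Grant("add:agents")])
perms = recompute_user_permissions([basic_group, perm_group])

# Soft-deleting the grant removes the permission on the next recompute
perm_group.grants[0].is_deleted = True
perms_after = recompute_user_permissions([basic_group, perm_group])
```

This mirrors the test's soft-delete path: the grant row stays, but a recompute no longer counts it.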
@@ -0,0 +1,30 @@
import os

import pytest

from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


@pytest.mark.skipif(
    os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
    reason="User group tests are enterprise only",
)
def test_new_group_gets_basic_permission(reset: None) -> None:  # noqa: ARG001
    admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")

    user_group = UserGroupManager.create(
        name="basic-perm-test-group",
        user_ids=[admin_user.id],
        user_performing_action=admin_user,
    )

    permissions = UserGroupManager.get_permissions(
        user_group=user_group,
        user_performing_action=admin_user,
    )

    assert (
        "basic" in permissions
    ), f"New group should have 'basic' permission, got: {permissions}"
@@ -0,0 +1,78 @@
"""Integration tests for default group assignment on user registration.

Verifies that:
- The first registered user is assigned to the Admin default group
- Subsequent registered users are assigned to the Basic default group
- account_type is set to STANDARD for email/password registrations
"""

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


def test_default_group_assignment_on_registration(reset: None) -> None:  # noqa: ARG001
    # Register first user — should become admin
    admin_user: DATestUser = UserManager.create(name="first_user")
    assert admin_user.role == UserRole.ADMIN

    # Register second user — should become basic
    basic_user: DATestUser = UserManager.create(name="second_user")
    assert basic_user.role == UserRole.BASIC

    # Fetch all groups including default ones
    all_groups = UserGroupManager.get_all(
        user_performing_action=admin_user,
        include_default=True,
    )

    # Find the default Admin and Basic groups
    admin_group = next(
        (g for g in all_groups if g.name == "Admin" and g.is_default), None
    )
    basic_group = next(
        (g for g in all_groups if g.name == "Basic" and g.is_default), None
    )
    assert admin_group is not None, "Admin default group not found"
    assert basic_group is not None, "Basic default group not found"

    # Verify admin user is in Admin group and NOT in Basic group
    admin_group_user_ids = {str(u.id) for u in admin_group.users}
    basic_group_user_ids = {str(u.id) for u in basic_group.users}

    assert (
        admin_user.id in admin_group_user_ids
    ), "First user should be in Admin default group"
    assert (
        admin_user.id not in basic_group_user_ids
    ), "First user should NOT be in Basic default group"

    # Verify basic user is in Basic group and NOT in Admin group
    assert (
        basic_user.id in basic_group_user_ids
    ), "Second user should be in Basic default group"
    assert (
        basic_user.id not in admin_group_user_ids
    ), "Second user should NOT be in Admin default group"

    # Verify account_type is STANDARD for both users via user listing API
    paginated_result = UserManager.get_user_page(
        user_performing_action=admin_user,
        page_num=0,
        page_size=10,
    )
    users_by_id = {str(u.id): u for u in paginated_result.items}

    admin_snapshot = users_by_id.get(admin_user.id)
    basic_snapshot = users_by_id.get(basic_user.id)
    assert admin_snapshot is not None, "Admin user not found in user listing"
    assert basic_snapshot is not None, "Basic user not found in user listing"

    assert (
        admin_snapshot.account_type == AccountType.STANDARD
    ), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
    assert (
        basic_snapshot.account_type == AccountType.STANDARD
    ), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
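The registration behavior this file verifies reduces to one small decision: the very first registered user lands in the Admin default group, and everyone after lands in Basic. A toy sketch of that decision, with hypothetical names (the real logic lives in the registration flow, not shown here):

```python
def pick_default_group(existing_user_count: int) -> str:
    """First registered user becomes the admin; later users get the Basic group."""
    return "Admin" if existing_user_count == 0 else "Basic"


first_user_group = pick_default_group(0)
second_user_group = pick_default_group(1)
```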
@@ -0,0 +1,135 @@
"""Integration tests for password signup upgrade paths.

Verifies that when a BOT or EXT_PERM_USER user signs up via email/password:
- Their account_type is upgraded to STANDARD
- They are assigned to the Basic default group
- They gain the correct effective permissions
"""

import pytest

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


def _get_default_group_member_emails(
    admin_user: DATestUser,
    group_name: str,
) -> set[str]:
    """Get the set of emails of all members in a named default group."""
    all_groups = UserGroupManager.get_all(admin_user, include_default=True)
    matched = [g for g in all_groups if g.is_default and g.name == group_name]
    assert matched, f"Default group '{group_name}' not found"
    return {u.email for u in matched[0].users}


@pytest.mark.parametrize(
    "target_role",
    [UserRole.EXT_PERM_USER, UserRole.SLACK_USER],
    ids=["ext_perm_user", "slack_user"],
)
def test_password_signup_upgrade(
    reset: None,  # noqa: ARG001
    target_role: UserRole,
) -> None:
    """When a non-web user signs up via email/password, they should be
    upgraded to STANDARD account_type and assigned to the Basic default group."""
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    test_email = f"{target_role.value}_upgrade@example.com"
    test_user = UserManager.create(email=test_email)

    test_user = UserManager.set_role(
        user_to_set=test_user,
        target_role=target_role,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    # Verify user was removed from Basic group after downgrade
    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert (
        test_email not in basic_emails
    ), f"{target_role.value} should not be in Basic default group"

    # Re-register with the same email — triggers the password signup upgrade
    upgraded_user = UserManager.create(email=test_email)

    assert upgraded_user.role == UserRole.BASIC

    paginated = UserManager.get_user_page(
        user_performing_action=admin_user,
        page_num=0,
        page_size=10,
    )
    user_snapshot = next(
        (u for u in paginated.items if str(u.id) == upgraded_user.id), None
    )
    assert user_snapshot is not None
    assert (
        user_snapshot.account_type == AccountType.STANDARD
    ), f"Expected STANDARD, got {user_snapshot.account_type}"

    # Verify user is now in the Basic default group
    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert (
        test_email in basic_emails
    ), f"Upgraded user '{test_email}' not found in Basic default group"


def test_password_signup_upgrade_propagates_permissions(
    reset: None,  # noqa: ARG001
) -> None:
    """When an EXT_PERM_USER or SLACK_USER signs up via password, they should
    gain the 'basic' permission through the Basic default group assignment."""
    admin_user: DATestUser = UserManager.create(email="admin@example.com")

    # --- EXT_PERM_USER path ---
    ext_email = "ext_perms_check@example.com"
    ext_user = UserManager.create(email=ext_email)

    initial_perms = UserManager.get_permissions(ext_user)
    assert "basic" in initial_perms

    ext_user = UserManager.set_role(
        user_to_set=ext_user,
        target_role=UserRole.EXT_PERM_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert ext_email not in basic_emails

    upgraded = UserManager.create(email=ext_email)
    assert upgraded.role == UserRole.BASIC

    perms = UserManager.get_permissions(upgraded)
    assert (
        "basic" in perms
    ), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}"

    # --- SLACK_USER path ---
    slack_email = "slack_perms_check@example.com"
    slack_user = UserManager.create(email=slack_email)

    slack_user = UserManager.set_role(
        user_to_set=slack_user,
        target_role=UserRole.SLACK_USER,
        user_performing_action=admin_user,
        explicit_override=True,
    )

    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert slack_email not in basic_emails

    upgraded = UserManager.create(email=slack_email)
    assert upgraded.role == UserRole.BASIC

    perms = UserManager.get_permissions(upgraded)
    assert (
        "basic" in perms
    ), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}"
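The upgrade path these tests pin down (an EXT_PERM_USER or SLACK_USER re-registering with a password) can be sketched as: on password signup for an existing non-web account, flip the role to BASIC, flip account_type to STANDARD, and reconcile Basic default group membership. A minimal sketch with hypothetical names, not the production signup handler:

```python
from dataclasses import dataclass


@dataclass
class Account:
    email: str
    role: str
    account_type: str
    groups: set[str]


def upgrade_on_password_signup(account: Account) -> Account:
    """Upgrade a non-web account (EXT_PERM_USER/SLACK_USER) on password signup."""
    if account.role in {"EXT_PERM_USER", "SLACK_USER"}:
        account.role = "BASIC"
        account.account_type = "STANDARD"
        account.groups.add("Basic")  # reconcile default group membership
    return account


acct = Account("ext_perms_check@example.com", "EXT_PERM_USER", "EXT_PERM_USER", set())
upgraded_acct = upgrade_on_password_signup(acct)
```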
@@ -0,0 +1,54 @@
"""Integration tests for default group reconciliation on user reactivation.

Verifies that:
- A deactivated user retains default group membership after reactivation
- Reactivation via the admin API reconciles missing group membership
"""

from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.managers.user_group import UserGroupManager
from tests.integration.common_utils.test_models import DATestUser


def _get_default_group_member_emails(
    admin_user: DATestUser,
    group_name: str,
) -> set[str]:
    """Get the set of emails of all members in a named default group."""
    all_groups = UserGroupManager.get_all(admin_user, include_default=True)
    matched = [g for g in all_groups if g.is_default and g.name == group_name]
    assert matched, f"Default group '{group_name}' not found"
    return {u.email for u in matched[0].users}


def test_reactivated_user_retains_default_group(
    reset: None,  # noqa: ARG001
) -> None:
    """Deactivating and reactivating a user should preserve their
    default group membership."""
    admin_user: DATestUser = UserManager.create(name="admin_user")
    basic_user: DATestUser = UserManager.create(name="basic_user")

    # Verify user is in Basic group initially
    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert basic_user.email in basic_emails

    # Deactivate the user
    UserManager.set_status(
        user_to_set=basic_user,
        target_status=False,
        user_performing_action=admin_user,
    )

    # Reactivate the user
    UserManager.set_status(
        user_to_set=basic_user,
        target_status=True,
        user_performing_action=admin_user,
    )

    # Verify user is still in Basic group after reactivation
    basic_emails = _get_default_group_member_emails(admin_user, "Basic")
    assert (
        basic_user.email in basic_emails
    ), "Reactivated user should still be in Basic default group"
176
backend/tests/unit/onyx/auth/test_permissions.py
Normal file
@@ -0,0 +1,176 @@
"""
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
"""

from unittest.mock import MagicMock

import pytest

from onyx.auth.permissions import ALL_PERMISSIONS
from onyx.auth.permissions import get_effective_permissions
from onyx.auth.permissions import require_permission
from onyx.auth.permissions import resolve_effective_permissions
from onyx.db.enums import Permission
from onyx.error_handling.error_codes import OnyxErrorCode
from onyx.error_handling.exceptions import OnyxError


# ---------------------------------------------------------------------------
# resolve_effective_permissions
# ---------------------------------------------------------------------------


class TestResolveEffectivePermissions:
    def test_empty_set(self) -> None:
        assert resolve_effective_permissions(set()) == set()

    def test_basic_no_implications(self) -> None:
        result = resolve_effective_permissions({"basic"})
        assert result == {"basic"}

    def test_single_implication(self) -> None:
        result = resolve_effective_permissions({"add:agents"})
        assert result == {"add:agents", "read:agents"}

    def test_manage_agents_implies_add_and_read(self) -> None:
        """manage:agents directly maps to {add:agents, read:agents}."""
        result = resolve_effective_permissions({"manage:agents"})
        assert result == {"manage:agents", "add:agents", "read:agents"}

    def test_manage_connectors_chain(self) -> None:
        result = resolve_effective_permissions({"manage:connectors"})
        assert result == {"manage:connectors", "add:connectors", "read:connectors"}

    def test_manage_document_sets(self) -> None:
        result = resolve_effective_permissions({"manage:document_sets"})
        assert result == {
            "manage:document_sets",
            "read:document_sets",
            "read:connectors",
        }

    def test_manage_user_groups_implies_all_reads(self) -> None:
        result = resolve_effective_permissions({"manage:user_groups"})
        assert result == {
            "manage:user_groups",
            "read:connectors",
            "read:document_sets",
            "read:agents",
            "read:users",
        }

    def test_admin_override(self) -> None:
        result = resolve_effective_permissions({"admin"})
        assert result == set(ALL_PERMISSIONS)

    def test_admin_with_others(self) -> None:
        result = resolve_effective_permissions({"admin", "basic"})
        assert result == set(ALL_PERMISSIONS)

    def test_multi_group_union(self) -> None:
        result = resolve_effective_permissions(
            {"add:agents", "manage:connectors", "basic"}
        )
        assert result == {
            "basic",
            "add:agents",
            "read:agents",
            "manage:connectors",
            "add:connectors",
            "read:connectors",
        }

    def test_toggle_permission_no_implications(self) -> None:
        result = resolve_effective_permissions({"read:agent_analytics"})
        assert result == {"read:agent_analytics"}

    def test_all_permissions_for_admin(self) -> None:
        result = resolve_effective_permissions({"admin"})
        assert len(result) == len(ALL_PERMISSIONS)


# ---------------------------------------------------------------------------
# get_effective_permissions (expands implied at read time)
# ---------------------------------------------------------------------------


class TestGetEffectivePermissions:
    def test_expands_implied_permissions(self) -> None:
        """Column stores only granted; get_effective_permissions expands implied."""
        user = MagicMock()
        user.effective_permissions = ["add:agents"]
        result = get_effective_permissions(user)
        assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}

    def test_admin_expands_to_all(self) -> None:
        user = MagicMock()
        user.effective_permissions = ["admin"]
        result = get_effective_permissions(user)
        assert result == set(Permission)

    def test_basic_stays_basic(self) -> None:
        user = MagicMock()
        user.effective_permissions = ["basic"]
        result = get_effective_permissions(user)
        assert result == {Permission.BASIC_ACCESS}

    def test_empty_column(self) -> None:
        user = MagicMock()
        user.effective_permissions = []
        result = get_effective_permissions(user)
        assert result == set()


# ---------------------------------------------------------------------------
# require_permission (FastAPI dependency)
# ---------------------------------------------------------------------------


class TestRequirePermission:
    @pytest.mark.asyncio
    async def test_admin_bypass(self) -> None:
        """Admin stored in column should pass any permission check."""
        user = MagicMock()
        user.effective_permissions = ["admin"]

        dep = require_permission(Permission.MANAGE_CONNECTORS)
        result = await dep(user=user)
        assert result is user

    @pytest.mark.asyncio
    async def test_has_required_permission(self) -> None:
        user = MagicMock()
        user.effective_permissions = ["manage:connectors"]

        dep = require_permission(Permission.MANAGE_CONNECTORS)
        result = await dep(user=user)
        assert result is user

    @pytest.mark.asyncio
    async def test_implied_permission_passes(self) -> None:
        """manage:connectors implies read:connectors at read time."""
        user = MagicMock()
        user.effective_permissions = ["manage:connectors"]

        dep = require_permission(Permission.READ_CONNECTORS)
        result = await dep(user=user)
        assert result is user

    @pytest.mark.asyncio
    async def test_missing_permission_raises(self) -> None:
        user = MagicMock()
        user.effective_permissions = ["basic"]

        dep = require_permission(Permission.MANAGE_CONNECTORS)
        with pytest.raises(OnyxError) as exc_info:
            await dep(user=user)
        assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS

    @pytest.mark.asyncio
    async def test_empty_permissions_fails(self) -> None:
        user = MagicMock()
        user.effective_permissions = []

        dep = require_permission(Permission.BASIC_ACCESS)
        with pytest.raises(OnyxError):
            await dep(user=user)
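These tests pin down the permission-implication graph without showing the implementation of `resolve_effective_permissions`. The following sketch satisfies every expectation above, assuming a static implication map plus a transitive-closure expansion; the real module may hard-code the expansions or derive the map differently, and the `IMPLIED`/`ALL_PERMISSIONS` values here are reconstructed from the test assertions, not copied from source:

```python
# Implication map reconstructed from the test expectations (illustrative only)
IMPLIED: dict[str, set[str]] = {
    "add:agents": {"read:agents"},
    "manage:agents": {"add:agents"},
    "add:connectors": {"read:connectors"},
    "manage:connectors": {"add:connectors"},
    "manage:document_sets": {"read:document_sets", "read:connectors"},
    "manage:user_groups": {
        "read:connectors", "read:document_sets", "read:agents", "read:users",
    },
}

ALL_PERMISSIONS: set[str] = {"admin", "basic", "read:agent_analytics", *IMPLIED} | {
    perm for implied in IMPLIED.values() for perm in implied
}


def resolve_effective_permissions(granted: set[str]) -> set[str]:
    """Expand granted permissions to a fixed point over the implication map."""
    if "admin" in granted:
        return set(ALL_PERMISSIONS)  # admin overrides everything
    effective = set(granted)
    frontier = list(granted)
    while frontier:  # transitive closure: follow chains like manage -> add -> read
        perm = frontier.pop()
        for implied in IMPLIED.get(perm, set()):
            if implied not in effective:
                effective.add(implied)
                frontier.append(implied)
    return effective


effective = resolve_effective_permissions({"add:agents", "basic"})
```

The closure formulation makes chained cases such as `test_manage_connectors_chain` fall out naturally: `manage:connectors` implies `add:connectors`, which in turn implies `read:connectors`.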
29
backend/tests/unit/onyx/auth/test_user_create_schema.py
Normal file
@@ -0,0 +1,29 @@
"""
Unit tests for UserCreate schema dict methods.

Verifies that account_type is always included in create_update_dict
and create_update_dict_superuser.
"""

from onyx.auth.schemas import UserCreate
from onyx.db.enums import AccountType


def test_create_update_dict_includes_default_account_type() -> None:
    uc = UserCreate(email="a@b.com", password="secret123")
    d = uc.create_update_dict()
    assert d["account_type"] == AccountType.STANDARD


def test_create_update_dict_includes_explicit_account_type() -> None:
    uc = UserCreate(
        email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
    )
    d = uc.create_update_dict()
    assert d["account_type"] == AccountType.STANDARD


def test_create_update_dict_superuser_includes_account_type() -> None:
    uc = UserCreate(email="a@b.com", password="secret123")
    d = uc.create_update_dict_superuser()
    assert d["account_type"] == AccountType.STANDARD
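Note that the second test above passes `account_type=AccountType.SERVICE_ACCOUNT` yet still expects `STANDARD` back. One reading consistent with these assertions is that the non-superuser `create_update_dict` deliberately overrides whatever the caller supplied, so self-signup can never choose a privileged account type. A toy sketch of that reading (the real method lives on the schema class and may differ):

```python
def create_update_dict(payload: dict) -> dict:
    """Non-superuser serialization sketch: self-signup may not pick a
    privileged account_type, so the field is forced to STANDARD."""
    sanitized = dict(payload)
    sanitized["account_type"] = "STANDARD"  # override any caller-supplied value
    return sanitized


d = create_update_dict({"email": "a@b.com", "account_type": "SERVICE_ACCOUNT"})
```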
176
backend/tests/unit/onyx/db/test_assign_default_groups.py
Normal file
@@ -0,0 +1,176 @@
"""
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.

Covers:
1. Standard/service-account users get assigned to the correct default group
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
3. Missing default group raises RuntimeError
4. Already-in-group is a no-op
5. IntegrityError race condition is handled gracefully
6. The function never commits the session
"""

from unittest.mock import MagicMock
from uuid import uuid4

import pytest
from sqlalchemy.exc import IntegrityError

from onyx.db.enums import AccountType
from onyx.db.models import User__UserGroup
from onyx.db.models import UserGroup
from onyx.db.users import assign_user_to_default_groups__no_commit


def _mock_user(
    account_type: AccountType = AccountType.STANDARD,
    email: str = "test@example.com",
) -> MagicMock:
    user = MagicMock()
    user.id = uuid4()
    user.email = email
    user.account_type = account_type
    return user


def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
    group = MagicMock()
    group.id = group_id
    group.name = name
    group.is_default = True
    return group


def _make_query_chain(first_return: object = None) -> MagicMock:
    """Returns a mock that supports .filter(...).filter(...).first() chaining."""
    chain = MagicMock()
    chain.filter.return_value = chain
    chain.first.return_value = first_return
    return chain


def _setup_db_session(
    group_result: object = None,
    membership_result: object = None,
) -> MagicMock:
    """Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
    db_session = MagicMock()

    group_chain = _make_query_chain(group_result)
    membership_chain = _make_query_chain(membership_result)

    def query_side_effect(model: type) -> MagicMock:
        if model is UserGroup:
            return group_chain
        if model is User__UserGroup:
            return membership_chain
        return MagicMock()

    db_session.query.side_effect = query_side_effect
    return db_session


def test_standard_user_assigned_to_basic_group() -> None:
    group = _mock_group("Basic")
    db_session = _setup_db_session(group_result=group, membership_result=None)
    savepoint = MagicMock()
    db_session.begin_nested.return_value = savepoint
    user = _mock_user(AccountType.STANDARD)

    assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)

    db_session.add.assert_called_once()
    added = db_session.add.call_args[0][0]
    assert isinstance(added, User__UserGroup)
    assert added.user_id == user.id
    assert added.user_group_id == group.id
    db_session.flush.assert_called_once()


def test_admin_user_assigned_to_admin_group() -> None:
    group = _mock_group("Admin", group_id=2)
    db_session = _setup_db_session(group_result=group, membership_result=None)
    savepoint = MagicMock()
    db_session.begin_nested.return_value = savepoint
    user = _mock_user(AccountType.STANDARD)

    assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)

    db_session.add.assert_called_once()
    added = db_session.add.call_args[0][0]
    assert isinstance(added, User__UserGroup)
    assert added.user_group_id == group.id


@pytest.mark.parametrize(
    "account_type",
    [AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
)
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
    db_session = MagicMock()
    user = _mock_user(account_type)

    assign_user_to_default_groups__no_commit(db_session, user)

    db_session.query.assert_not_called()
    db_session.add.assert_not_called()


def test_service_account_not_skipped() -> None:
    group = _mock_group("Basic")
    db_session = _setup_db_session(group_result=group, membership_result=None)
    savepoint = MagicMock()
    db_session.begin_nested.return_value = savepoint
    user = _mock_user(AccountType.SERVICE_ACCOUNT)

    assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)

    db_session.add.assert_called_once()


def test_missing_default_group_raises_error() -> None:
    db_session = _setup_db_session(group_result=None)
    user = _mock_user()

    with pytest.raises(RuntimeError, match="Default group .* not found"):
        assign_user_to_default_groups__no_commit(db_session, user)


def test_already_in_group_is_noop() -> None:
    group = _mock_group("Basic")
    existing_membership = MagicMock()
    db_session = _setup_db_session(
        group_result=group, membership_result=existing_membership
    )
    user = _mock_user()

    assign_user_to_default_groups__no_commit(db_session, user)

    db_session.add.assert_not_called()
    db_session.begin_nested.assert_not_called()


def test_integrity_error_race_condition_handled() -> None:
    group = _mock_group("Basic")
    db_session = _setup_db_session(group_result=group, membership_result=None)
    savepoint = MagicMock()
    db_session.begin_nested.return_value = savepoint
    db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
    user = _mock_user()

    # Should not raise
    assign_user_to_default_groups__no_commit(db_session, user)

    savepoint.rollback.assert_called_once()


def test_no_commit_called_on_successful_assignment() -> None:
    group = _mock_group("Basic")
    db_session = _setup_db_session(group_result=group, membership_result=None)
    savepoint = MagicMock()
    db_session.begin_nested.return_value = savepoint
    user = _mock_user()

    assign_user_to_default_groups__no_commit(db_session, user)

    db_session.commit.assert_not_called()
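Taken together, the tests above fully describe the control flow of `assign_user_to_default_groups__no_commit`: skip excluded account types before any query, look up the "Admin" or "Basic" default group, raise if it is missing, no-op if the user is already a member, and otherwise insert under a savepoint, swallowing IntegrityError races, without ever committing. A pure-Python sketch of that flow over an in-memory membership map (the real function operates on a SQLAlchemy session; names beyond the group names are illustrative):

```python
EXCLUDED_ACCOUNT_TYPES = {"BOT", "EXT_PERM_USER", "ANONYMOUS"}


def assign_user_to_default_groups(
    groups: dict[str, set[str]],  # default group name -> member user ids
    user_id: str,
    account_type: str,
    is_admin: bool = False,
) -> None:
    """Sketch of the assignment flow the tests above pin down.

    The real function wraps the INSERT in a savepoint (begin_nested + flush),
    rolls the savepoint back on IntegrityError races, and never commits.
    """
    if account_type in EXCLUDED_ACCOUNT_TYPES:
        return  # bot/external/anonymous accounts never join default groups
    name = "Admin" if is_admin else "Basic"
    if name not in groups:
        raise RuntimeError(f"Default group '{name}' not found")
    if user_id in groups[name]:
        return  # already a member: no-op, no savepoint needed
    groups[name].add(user_id)  # real code: savepoint + INSERT + flush


groups = {"Basic": set(), "Admin": set()}
assign_user_to_default_groups(groups, "u1", "STANDARD")
assign_user_to_default_groups(groups, "u2", "STANDARD", is_admin=True)
assign_user_to_default_groups(groups, "u3", "BOT")
```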
@@ -0,0 +1,203 @@
import pytest

from onyx.document_index.interfaces_new import TenantState
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
from onyx.document_index.opensearch.string_filtering import (
    MAX_DOCUMENT_ID_ENCODED_LENGTH,
)
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE


SINGLE_TENANT_STATE = TenantState(
    tenant_id=POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE, multitenant=False
)
MULTI_TENANT_STATE = TenantState(
    tenant_id="tenant_abcdef12-3456-7890-abcd-ef1234567890", multitenant=True
)
EXPECTED_SHORT_TENANT = "abcdef12"


class TestGetOpensearchDocChunkIdSingleTenant:
    def test_basic(self) -> None:
        result = get_opensearch_doc_chunk_id(
            SINGLE_TENANT_STATE, "my-doc-id", chunk_index=0
        )
        assert result == f"my-doc-id__{DEFAULT_MAX_CHUNK_SIZE}__0"

    def test_custom_chunk_size(self) -> None:
        result = get_opensearch_doc_chunk_id(
            SINGLE_TENANT_STATE, "doc1", chunk_index=3, max_chunk_size=1024
        )
        assert result == "doc1__1024__3"

    def test_special_chars_are_stripped(self) -> None:
        """Tests characters not matching [A-Za-z0-9_.-~] are removed."""
        result = get_opensearch_doc_chunk_id(
            SINGLE_TENANT_STATE, "doc/with?special#chars&more%stuff", chunk_index=0
        )
        assert "/" not in result
        assert "?" not in result
        assert "#" not in result
        assert result == f"docwithspecialcharsmorestuff__{DEFAULT_MAX_CHUNK_SIZE}__0"

    def test_short_doc_id_not_hashed(self) -> None:
        """
        Tests that a short doc ID should appear directly in the result, not as a
        hash.
        """
        doc_id = "short-id"
        result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
        assert "short-id" in result

    def test_long_doc_id_is_hashed(self) -> None:
        """
        Tests that a doc ID exceeding the max length should be replaced with a
|
||||
blake2b hash.
|
||||
"""
|
||||
# Create a doc ID that will exceed max length after the suffix is
|
||||
# appended.
|
||||
doc_id = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
# The original doc ID should NOT appear in the result.
|
||||
assert doc_id not in result
|
||||
# The suffix should still be present.
|
||||
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
|
||||
|
||||
def test_long_doc_id_hash_is_deterministic(self) -> None:
|
||||
doc_id = "x" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result1 = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=5
|
||||
)
|
||||
result2 = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=5
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
def test_long_doc_id_different_inputs_produce_different_hashes(self) -> None:
|
||||
doc_id_a = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
doc_id_b = "b" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result_a = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id_a, chunk_index=0
|
||||
)
|
||||
result_b = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id_b, chunk_index=0
|
||||
)
|
||||
assert result_a != result_b
|
||||
|
||||
def test_result_never_exceeds_max_length(self) -> None:
|
||||
"""
|
||||
Tests that the final result should always be under
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH bytes.
|
||||
"""
|
||||
doc_id = "z" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
|
||||
)
|
||||
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
|
||||
def test_no_tenant_prefix_in_single_tenant(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "mydoc", chunk_index=0
|
||||
)
|
||||
assert not result.startswith(SINGLE_TENANT_STATE.tenant_id)
|
||||
|
||||
|
||||
class TestGetOpensearchDocChunkIdMultiTenant:
|
||||
def test_includes_tenant_prefix(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, "mydoc", chunk_index=0)
|
||||
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
|
||||
|
||||
def test_format(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
MULTI_TENANT_STATE, "mydoc", chunk_index=2, max_chunk_size=256
|
||||
)
|
||||
assert result == f"{EXPECTED_SHORT_TENANT}__mydoc__256__2"
|
||||
|
||||
def test_long_doc_id_is_hashed_multitenant(self) -> None:
|
||||
doc_id = "d" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
|
||||
# Should still have tenant prefix.
|
||||
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
|
||||
# The original doc ID should NOT appear in the result.
|
||||
assert doc_id not in result
|
||||
# The suffix should still be present.
|
||||
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
|
||||
|
||||
def test_result_never_exceeds_max_length_multitenant(self) -> None:
|
||||
doc_id = "q" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
MULTI_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
|
||||
)
|
||||
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
|
||||
def test_different_tenants_produce_different_ids(self) -> None:
|
||||
tenant_a = TenantState(
|
||||
tenant_id="tenant_aaaaaaaa-0000-0000-0000-000000000000", multitenant=True
|
||||
)
|
||||
tenant_b = TenantState(
|
||||
tenant_id="tenant_bbbbbbbb-0000-0000-0000-000000000000", multitenant=True
|
||||
)
|
||||
result_a = get_opensearch_doc_chunk_id(tenant_a, "same-doc", chunk_index=0)
|
||||
result_b = get_opensearch_doc_chunk_id(tenant_b, "same-doc", chunk_index=0)
|
||||
assert result_a != result_b
|
||||
|
||||
|
||||
class TestGetOpensearchDocChunkIdEdgeCases:
|
||||
def test_chunk_index_zero(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "doc", chunk_index=0)
|
||||
assert result.endswith("__0")
|
||||
|
||||
def test_large_chunk_index(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "doc", chunk_index=99999
|
||||
)
|
||||
assert result.endswith("__99999")
|
||||
|
||||
def test_doc_id_with_only_special_chars_raises(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID that becomes empty after filtering should raise
|
||||
ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="empty after filtering"):
|
||||
get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "###???///", chunk_index=0)
|
||||
|
||||
def test_doc_id_at_boundary_length(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID right at the boundary should not be hashed.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
# Max doc ID length that won't trigger hashing (must be <
|
||||
# max_encoded_length).
|
||||
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - 1
|
||||
doc_id = "a" * max_doc_len
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id in result
|
||||
|
||||
def test_doc_id_at_boundary_length_multitenant(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID right at the boundary should not be hashed in
|
||||
multitenant mode.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
prefix = f"{EXPECTED_SHORT_TENANT}__"
|
||||
prefix_len = len(prefix.encode("utf-8"))
|
||||
# Max doc ID length that won't trigger hashing (must be <
|
||||
# max_encoded_length).
|
||||
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - prefix_len - 1
|
||||
doc_id = "a" * max_doc_len
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id in result
|
||||
|
||||
def test_doc_id_one_over_boundary_is_hashed(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID one byte over the boundary should be hashed.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
# This length will trigger the >= check in filter_and_validate_document_id
|
||||
doc_id = "a" * (MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len)
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id not in result
|
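Read together, the tests above imply an ID scheme of the form `[short_tenant__]doc_id__chunk_size__chunk_index`, with over-long doc IDs replaced by a hash. The sketch below reproduces that contract under stated assumptions: the constant value, the 8-character tenant prefix, and the blake2b digest size are guesses, and the real logic lives in `onyx.document_index.opensearch`, not here.

```python
import hashlib
import re

MAX_DOC_ID_BYTES = 512  # assumed stand-in for MAX_DOCUMENT_ID_ENCODED_LENGTH


def chunk_id_sketch(
    tenant_id: str,
    multitenant: bool,
    doc_id: str,
    chunk_index: int,
    max_chunk_size: int = 2048,
) -> str:
    # Keep only the characters the tests call safe: [A-Za-z0-9_.-~]
    filtered = re.sub(r"[^A-Za-z0-9_.\-~]", "", doc_id)
    if not filtered:
        raise ValueError("document id is empty after filtering")

    prefix = ""
    if multitenant:
        # "tenant_abcdef12-..." -> "abcdef12" (first 8 chars of the UUID)
        prefix = tenant_id.removeprefix("tenant_")[:8] + "__"

    suffix = f"__{max_chunk_size}__{chunk_index}"
    candidate = prefix + filtered + suffix
    if len(candidate.encode("utf-8")) >= MAX_DOC_ID_BYTES:
        # Doc ID too long: substitute a fixed-width, deterministic hash.
        digest = hashlib.blake2b(filtered.encode("utf-8"), digest_size=16)
        candidate = prefix + digest.hexdigest() + suffix
    return candidate
```

For example, `chunk_id_sketch("public", False, "doc1", 3, 1024)` yields `"doc1__1024__3"`, matching `test_custom_chunk_size`; the hash path keeps the result deterministic and bounded, matching the boundary tests.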
91  backend/tests/unit/onyx/file_store/test_delete_file.py  (new file)
@@ -0,0 +1,91 @@
"""Tests for FileStore.delete_file error_on_missing behavior."""

from unittest.mock import MagicMock
from unittest.mock import patch

import pytest

_S3_MODULE = "onyx.file_store.file_store"
_PG_MODULE = "onyx.file_store.postgres_file_store"


def _mock_db_session() -> MagicMock:
    session = MagicMock()
    session.__enter__ = MagicMock(return_value=session)
    session.__exit__ = MagicMock(return_value=False)
    return session


# ── S3BackedFileStore ────────────────────────────────────────────────


@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
def test_s3_delete_missing_file_raises_by_default(
    _mock_get_record: MagicMock,
    mock_ctx: MagicMock,
) -> None:
    from onyx.file_store.file_store import S3BackedFileStore

    mock_ctx.return_value = _mock_db_session()
    store = S3BackedFileStore(bucket_name="b")

    with pytest.raises(RuntimeError, match="does not exist"):
        store.delete_file("nonexistent")


@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
@patch(f"{_S3_MODULE}.delete_filerecord_by_file_id")
def test_s3_delete_missing_file_silent_when_error_on_missing_false(
    mock_delete_record: MagicMock,
    _mock_get_record: MagicMock,
    mock_ctx: MagicMock,
) -> None:
    from onyx.file_store.file_store import S3BackedFileStore

    mock_ctx.return_value = _mock_db_session()
    store = S3BackedFileStore(bucket_name="b")

    store.delete_file("nonexistent", error_on_missing=False)

    mock_delete_record.assert_not_called()


# ── PostgresBackedFileStore ──────────────────────────────────────────


@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
def test_pg_delete_missing_file_raises_by_default(
    _mock_get_content: MagicMock,
    mock_ctx: MagicMock,
) -> None:
    from onyx.file_store.postgres_file_store import PostgresBackedFileStore

    mock_ctx.return_value = _mock_db_session()
    store = PostgresBackedFileStore()

    with pytest.raises(RuntimeError, match="does not exist"):
        store.delete_file("nonexistent")


@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
@patch(f"{_PG_MODULE}.delete_file_content_by_file_id")
@patch(f"{_PG_MODULE}.delete_filerecord_by_file_id")
def test_pg_delete_missing_file_silent_when_error_on_missing_false(
    mock_delete_record: MagicMock,
    mock_delete_content: MagicMock,
    _mock_get_content: MagicMock,
    mock_ctx: MagicMock,
) -> None:
    from onyx.file_store.postgres_file_store import PostgresBackedFileStore

    mock_ctx.return_value = _mock_db_session()
    store = PostgresBackedFileStore()

    store.delete_file("nonexistent", error_on_missing=False)

    mock_delete_record.assert_not_called()
    mock_delete_content.assert_not_called()
@@ -113,6 +113,7 @@ def make_db_group(**kwargs: Any) -> MagicMock:
    group.name = kwargs.get("name", "Engineering")
    group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
    group.is_up_to_date = kwargs.get("is_up_to_date", True)
    group.is_default = kwargs.get("is_default", False)
    return group


@@ -3,6 +3,7 @@ from unittest.mock import MagicMock
from uuid import uuid4

from onyx.auth.schemas import UserRole
from onyx.db.enums import AccountType
from onyx.server.models import FullUserSnapshot
from onyx.server.models import UserGroupInfo

@@ -25,6 +26,7 @@ def _mock_user(
    user.updated_at = updated_at or datetime.datetime(
        2025, 6, 15, tzinfo=datetime.timezone.utc
    )
    user.account_type = AccountType.STANDARD
    return user


@@ -98,6 +98,7 @@ Useful hardening flags:
| `serve` | Serve the interactive chat TUI over SSH |
| `configure` | Configure server URL and API key |
| `validate-config` | Validate configuration and test connection |
| `install-skill` | Install the agent skill file into a project |

## Slash Commands (in TUI)


@@ -7,6 +7,7 @@ import (

	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
	"github.com/spf13/cobra"
)

@@ -16,16 +17,23 @@ func newAgentsCmd() *cobra.Command {
	cmd := &cobra.Command{
		Use:   "agents",
		Short: "List available agents",
		Long: `List all visible agents configured on the Onyx server.

By default, output is a human-readable table with ID, name, and description.
Use --json for machine-readable output.`,
		Example: `  onyx-cli agents
  onyx-cli agents --json
  onyx-cli agents --json | jq '.[].name'`,
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg := config.Load()
			if !cfg.IsConfigured() {
				return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
				return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n  Run: onyx-cli configure")
			}

			client := api.NewClient(cfg)
			agents, err := client.ListAgents(cmd.Context())
			if err != nil {
				return fmt.Errorf("failed to list agents: %w", err)
				return fmt.Errorf("failed to list agents: %w\n  Check your connection with: onyx-cli validate-config", err)
			}

			if agentsJSON {
140  cli/cmd/ask.go
@@ -4,33 +4,65 @@ import (
	"context"
	"encoding/json"
	"fmt"
	"io"
	"os"
	"os/signal"
	"strings"
	"syscall"

	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
	"github.com/onyx-dot-app/onyx/cli/internal/models"
	"github.com/onyx-dot-app/onyx/cli/internal/overflow"
	"github.com/spf13/cobra"
	"golang.org/x/term"
)

const defaultMaxOutputBytes = 4096

func newAskCmd() *cobra.Command {
	var (
		askAgentID int
		askJSON    bool
		askQuiet   bool
		askPrompt  string
		maxOutput  int
	)

	cmd := &cobra.Command{
		Use:   "ask [question]",
		Short: "Ask a one-shot question (non-interactive)",
		Args:  cobra.ExactArgs(1),
		Long: `Send a one-shot question to an Onyx agent and print the response.

The question can be provided as a positional argument, via --prompt, or piped
through stdin. When stdin contains piped data, it is sent as context along
with the question from --prompt (or used as the question itself).

When stdout is not a TTY (e.g., called by a script or AI agent), output is
automatically truncated to --max-output bytes and the full response is saved
to a temp file. Set --max-output 0 to disable truncation.`,
		Args: cobra.MaximumNArgs(1),
		Example: `  onyx-cli ask "What connectors are available?"
  onyx-cli ask --agent-id 3 "Summarize our Q4 revenue"
  onyx-cli ask --json "List all users" | jq '.event.content'
  cat error.log | onyx-cli ask --prompt "Find the root cause"
  echo "what is onyx?" | onyx-cli ask`,
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg := config.Load()
			if !cfg.IsConfigured() {
				return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
				return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n  Run: onyx-cli configure")
			}

			if askJSON && askQuiet {
				return exitcodes.New(exitcodes.BadRequest, "--json and --quiet cannot be used together")
			}

			question, err := resolveQuestion(args, askPrompt)
			if err != nil {
				return err
			}

			question := args[0]
			agentID := cfg.DefaultAgentID
			if cmd.Flags().Changed("agent-id") {
				agentID = askAgentID
@@ -50,9 +82,23 @@ func newAskCmd() *cobra.Command {
				nil,
			)

			// Determine truncation threshold.
			isTTY := term.IsTerminal(int(os.Stdout.Fd()))
			truncateAt := 0 // 0 means no truncation
			if cmd.Flags().Changed("max-output") {
				truncateAt = maxOutput
			} else if !isTTY {
				truncateAt = defaultMaxOutputBytes
			}

			var sessionID string
			var lastErr error
			gotStop := false

			// Overflow writer: tees to stdout and optionally to a temp file.
			// In quiet mode, buffer everything and print once at the end.
			ow := &overflow.Writer{Limit: truncateAt, Quiet: askQuiet}

			for event := range ch {
				if e, ok := event.(models.SessionCreatedEvent); ok {
					sessionID = e.ChatSessionID
@@ -82,22 +128,50 @@ func newAskCmd() *cobra.Command {

				switch e := event.(type) {
				case models.MessageDeltaEvent:
					fmt.Print(e.Content)
					ow.Write(e.Content)
				case models.SearchStartEvent:
					if isTTY && !askQuiet {
						if e.IsInternetSearch {
							fmt.Fprintf(os.Stderr, "\033[2mSearching the web...\033[0m\n")
						} else {
							fmt.Fprintf(os.Stderr, "\033[2mSearching documents...\033[0m\n")
						}
					}
				case models.SearchQueriesEvent:
					if isTTY && !askQuiet {
						for _, q := range e.Queries {
							fmt.Fprintf(os.Stderr, "\033[2m  → %s\033[0m\n", q)
						}
					}
				case models.SearchDocumentsEvent:
					if isTTY && !askQuiet && len(e.Documents) > 0 {
						fmt.Fprintf(os.Stderr, "\033[2mFound %d documents\033[0m\n", len(e.Documents))
					}
				case models.ReasoningStartEvent:
					if isTTY && !askQuiet {
						fmt.Fprintf(os.Stderr, "\033[2mThinking...\033[0m\n")
					}
				case models.ToolStartEvent:
					if isTTY && !askQuiet && e.ToolName != "" {
						fmt.Fprintf(os.Stderr, "\033[2mUsing %s...\033[0m\n", e.ToolName)
					}
				case models.ErrorEvent:
					ow.Finish()
					return fmt.Errorf("%s", e.Error)
				case models.StopEvent:
					fmt.Println()
					ow.Finish()
					return nil
				}
			}

			if !askJSON {
				ow.Finish()
			}

			if ctx.Err() != nil {
				if sessionID != "" {
					client.StopChatSession(context.Background(), sessionID)
				}
				if !askJSON {
					fmt.Println()
				}
				return nil
			}

@@ -105,20 +179,56 @@ func newAskCmd() *cobra.Command {
				return lastErr
			}
			if !gotStop {
				if !askJSON {
					fmt.Println()
				}
				return fmt.Errorf("stream ended unexpectedly")
			}
			if !askJSON {
				fmt.Println()
			}
			return nil
		},
	}

	cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
	cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
	// Suppress cobra's default error/usage on RunE errors
	cmd.Flags().BoolVarP(&askQuiet, "quiet", "q", false, "Buffer output and print once at end (no streaming)")
	cmd.Flags().StringVar(&askPrompt, "prompt", "", "Question text (use with piped stdin context)")
	cmd.Flags().IntVar(&maxOutput, "max-output", defaultMaxOutputBytes,
		"Max bytes to print before truncating (0 to disable, auto-enabled for non-TTY)")
	return cmd
}

// resolveQuestion builds the final question string from args, --prompt, and stdin.
func resolveQuestion(args []string, prompt string) (string, error) {
	hasArg := len(args) > 0
	hasPrompt := prompt != ""
	hasStdin := !term.IsTerminal(int(os.Stdin.Fd()))

	if hasArg && hasPrompt {
		return "", exitcodes.New(exitcodes.BadRequest, "specify the question as an argument or --prompt, not both")
	}

	var stdinContent string
	if hasStdin {
		const maxStdinBytes = 10 * 1024 * 1024 // 10MB
		data, err := io.ReadAll(io.LimitReader(os.Stdin, maxStdinBytes))
		if err != nil {
			return "", fmt.Errorf("failed to read stdin: %w", err)
		}
		stdinContent = strings.TrimSpace(string(data))
	}

	switch {
	case hasArg && stdinContent != "":
		// arg is the question, stdin is context
		return args[0] + "\n\n" + stdinContent, nil
	case hasArg:
		return args[0], nil
	case hasPrompt && stdinContent != "":
		// --prompt is the question, stdin is context
		return prompt + "\n\n" + stdinContent, nil
	case hasPrompt:
		return prompt, nil
	case stdinContent != "":
		return stdinContent, nil
	default:
		return "", exitcodes.New(exitcodes.BadRequest, "no question provided\n  Usage: onyx-cli ask \"your question\"\n  Or: echo \"context\" | onyx-cli ask --prompt \"your question\"")
	}
}
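The `overflow.Writer` used in `ask` is only referenced in this diff, but the help text describes its job: stream at most `--max-output` bytes to stdout while saving the full response to a temp file. A rough Python rendering of that idea follows; the class name, fields, and truncation note are guesses, not the Go package's actual API.

```python
import sys
import tempfile


class OverflowWriterSketch:
    """Tee streamed chunks: print up to `limit` bytes, spill everything
    to a temp file so the full response is never lost."""

    def __init__(self, limit: int = 0, out=None) -> None:
        self.limit = limit  # 0 disables truncation
        self.printed = 0
        self.spill = None
        self.out = out or sys.stdout

    def write(self, chunk: str) -> None:
        if not self.limit:
            self.out.write(chunk)
            return
        if self.spill is None:
            self.spill = tempfile.NamedTemporaryFile(
                "w", delete=False, suffix=".txt"
            )
        self.spill.write(chunk)  # the full copy always goes to the file
        remaining = self.limit - self.printed
        if remaining > 0:
            self.out.write(chunk[:remaining])
            self.printed += min(len(chunk), remaining)

    def finish(self) -> None:
        if self.spill is not None:
            self.spill.close()
            if self.printed >= self.limit:
                self.out.write(
                    f"\n[truncated; full output: {self.spill.name}]\n"
                )
```

Taking an injectable `out` stream keeps the sketch testable; the real writer additionally supports the `--quiet` buffer-then-print mode.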
@@ -4,6 +4,7 @@ import (
	tea "github.com/charmbracelet/bubbletea"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
	"github.com/onyx-dot-app/onyx/cli/internal/starprompt"
	"github.com/onyx-dot-app/onyx/cli/internal/tui"
	"github.com/spf13/cobra"
)

@@ -12,6 +13,11 @@ func newChatCmd() *cobra.Command {
	return &cobra.Command{
		Use:   "chat",
		Short: "Launch the interactive chat TUI (default)",
		Long: `Launch the interactive terminal UI for chatting with your Onyx agent.
This is the default command when no subcommand is specified. On first run,
an interactive setup wizard will guide you through configuration.`,
		Example: `  onyx-cli chat
  onyx-cli`,
		RunE: func(cmd *cobra.Command, args []string) error {
			cfg := config.Load()

@@ -24,6 +30,8 @@ func newChatCmd() *cobra.Command {
				cfg = *result
			}

			starprompt.MaybePrompt()

			m := tui.NewModel(cfg)
			p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
			_, err := p.Run()
@@ -1,19 +1,126 @@
package cmd

import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"strings"
	"time"

	"github.com/onyx-dot-app/onyx/cli/internal/api"
	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
	"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
	"github.com/spf13/cobra"
	"golang.org/x/term"
)

func newConfigureCmd() *cobra.Command {
	return &cobra.Command{
	var (
		serverURL   string
		apiKey      string
		apiKeyStdin bool
		dryRun      bool
	)

	cmd := &cobra.Command{
		Use:   "configure",
		Short: "Configure server URL and API key",
		Long: `Set up the Onyx CLI with your server URL and API key.

When --server-url and --api-key are both provided, the configuration is saved
non-interactively (useful for scripts and AI agents). Otherwise, an interactive
setup wizard is launched.

If --api-key is omitted but stdin has piped data, the API key is read from
stdin automatically. You can also use --api-key-stdin to make this explicit.
This avoids leaking the key in shell history.

Use --dry-run to test the connection without saving the configuration.`,
		Example: `  onyx-cli configure
  onyx-cli configure --server-url https://my-onyx.com --api-key sk-...
  echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com
  echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com --api-key-stdin
  onyx-cli configure --server-url https://my-onyx.com --api-key sk-... --dry-run`,
		RunE: func(cmd *cobra.Command, args []string) error {
			// Read API key from stdin if piped (implicit) or --api-key-stdin (explicit)
			if apiKeyStdin && apiKey != "" {
				return exitcodes.New(exitcodes.BadRequest, "--api-key and --api-key-stdin cannot be used together")
			}
			if (apiKey == "" && !term.IsTerminal(int(os.Stdin.Fd()))) || apiKeyStdin {
				data, err := io.ReadAll(os.Stdin)
				if err != nil {
					return fmt.Errorf("failed to read API key from stdin: %w", err)
				}
				apiKey = strings.TrimSpace(string(data))
			}

			if serverURL != "" && apiKey != "" {
				return configureNonInteractive(serverURL, apiKey, dryRun)
			}

			if dryRun {
				return exitcodes.New(exitcodes.BadRequest, "--dry-run requires --server-url and --api-key")
			}

			if serverURL != "" || apiKey != "" {
				return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n  Run 'onyx-cli configure' without flags for interactive setup")
			}

			cfg := config.Load()
			onboarding.Run(&cfg)
			return nil
		},
	}

	cmd.Flags().StringVar(&serverURL, "server-url", "", "Onyx server URL (e.g., https://cloud.onyx.app)")
	cmd.Flags().StringVar(&apiKey, "api-key", "", "API key for authentication (or pipe via stdin)")
	cmd.Flags().BoolVar(&apiKeyStdin, "api-key-stdin", false, "Read API key from stdin (explicit; also happens automatically when stdin is piped)")
	cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Test connection without saving config (requires --server-url and --api-key)")

	return cmd
}

func configureNonInteractive(serverURL, apiKey string, dryRun bool) error {
	cfg := config.OnyxCliConfig{
		ServerURL:      serverURL,
		APIKey:         apiKey,
		DefaultAgentID: 0,
	}

	// Preserve existing default agent ID from disk (not env overrides)
	if existing := config.LoadFromDisk(); existing.DefaultAgentID != 0 {
		cfg.DefaultAgentID = existing.DefaultAgentID
	}

	// Test connection
	client := api.NewClient(cfg)
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
	defer cancel()

	if err := client.TestConnection(ctx); err != nil {
		var authErr *api.AuthError
		if errors.As(err, &authErr) {
			return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n  Check your API key", err)
		}
		return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n  Check your server URL", err)
	}

	if dryRun {
		fmt.Printf("Server: %s\n", serverURL)
		fmt.Println("Status: connected and authenticated")
		fmt.Println("Dry run: config was NOT saved")
		return nil
	}

	if err := config.Save(cfg); err != nil {
		return fmt.Errorf("could not save config: %w", err)
	}

	fmt.Printf("Config: %s\n", config.ConfigFilePath())
	fmt.Printf("Server: %s\n", serverURL)
	fmt.Println("Status: connected and authenticated")
	return nil
}
176  cli/cmd/install_skill.go  (new file)
@@ -0,0 +1,176 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/embedded"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/fsutil"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// agentSkillDirs maps agent names to their skill directory paths (relative to
|
||||
// the project or home root). "Universal" agents like Cursor and Codex read
|
||||
// from .agents/skills directly, so they don't need their own entry here.
|
||||
var agentSkillDirs = map[string]string{
|
||||
"claude-code": filepath.Join(".claude", "skills"),
|
||||
}
|
||||
|
||||
const (
|
||||
canonicalDir = ".agents/skills"
|
||||
skillName = "onyx-cli"
|
||||
)
|
||||
|
||||
func newInstallSkillCmd() *cobra.Command {
|
||||
var (
|
||||
global bool
|
||||
copyMode bool
|
||||
agents []string
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "install-skill",
|
||||
Short: "Install the Onyx CLI agent skill file",
|
||||
Long: `Install the bundled SKILL.md so that AI coding agents can discover and use
|
||||
the Onyx CLI as a tool.
|
||||
|
||||
Files are written to the canonical .agents/skills/onyx-cli/ directory. For
|
||||
agents that use their own skill directory (e.g. Claude Code uses .claude/skills/),
|
||||
a symlink is created pointing back to the canonical copy.
|
||||
|
||||
By default the skill is installed at the project level (current directory).
|
||||
Use --global to install under your home directory instead.
|
||||
|
||||
Use --copy to write independent copies instead of symlinks.
|
||||
Use --agent to target specific agents (can be repeated).`,
|
||||
Example: ` onyx-cli install-skill
|
||||
onyx-cli install-skill --global
|
||||
onyx-cli install-skill --agent claude-code
|
||||
onyx-cli install-skill --copy`,
|
||||
		RunE: func(cmd *cobra.Command, args []string) error {
			base, err := installBase(global)
			if err != nil {
				return err
			}

			// Write the canonical copy.
			canonicalSkillDir := filepath.Join(base, canonicalDir, skillName)
			dest := filepath.Join(canonicalSkillDir, "SKILL.md")
			content := []byte(embedded.SkillMD)

			status, err := fsutil.CompareFile(dest, content)
			if err != nil {
				return err
			}
			switch status {
			case fsutil.StatusUpToDate:
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Up to date %s\n", dest)
			case fsutil.StatusDiffers:
				_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: overwriting modified %s\n", dest)
				if err := os.WriteFile(dest, content, 0o644); err != nil {
					return fmt.Errorf("could not write skill file: %w", err)
				}
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
			default: // StatusMissing
				if err := os.MkdirAll(canonicalSkillDir, 0o755); err != nil {
					return fmt.Errorf("could not create directory: %w", err)
				}
				if err := os.WriteFile(dest, content, 0o644); err != nil {
					return fmt.Errorf("could not write skill file: %w", err)
				}
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
			}

			// Determine which agents to link.
			targets := agentSkillDirs
			if len(agents) > 0 {
				targets = make(map[string]string)
				for _, a := range agents {
					dir, ok := agentSkillDirs[a]
					if !ok {
						_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unknown agent %q (skipped) — known agents:", a)
						for name := range agentSkillDirs {
							_, _ = fmt.Fprintf(cmd.ErrOrStderr(), " %s", name)
						}
						_, _ = fmt.Fprintln(cmd.ErrOrStderr())
						continue
					}
					targets[a] = dir
				}
			}

			// Create symlinks (or copies) from agent-specific dirs to canonical.
			for name, skillsDir := range targets {
				agentSkillDir := filepath.Join(base, skillsDir, skillName)

				if copyMode {
					copyDest := filepath.Join(agentSkillDir, "SKILL.md")
					if err := fsutil.EnsureDirForCopy(agentSkillDir); err != nil {
						return fmt.Errorf("could not prepare %s directory: %w", name, err)
					}
					if err := os.MkdirAll(agentSkillDir, 0o755); err != nil {
						return fmt.Errorf("could not create %s directory: %w", name, err)
					}
					if err := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); err != nil {
						return fmt.Errorf("could not write %s skill file: %w", name, err)
					}
					_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s\n", copyDest)
					continue
				}

				// Compute relative symlink target. Symlinks resolve relative to
				// the parent directory of the link, not the link itself.
				rel, err := filepath.Rel(filepath.Dir(agentSkillDir), canonicalSkillDir)
				if err != nil {
					return fmt.Errorf("could not compute relative path for %s: %w", name, err)
				}

				if err := os.MkdirAll(filepath.Dir(agentSkillDir), 0o755); err != nil {
					return fmt.Errorf("could not create %s directory: %w", name, err)
				}

				// Remove existing symlink/dir before creating.
				_ = os.Remove(agentSkillDir)

				if err := os.Symlink(rel, agentSkillDir); err != nil {
					// Fall back to copy if symlink fails (e.g. Windows without dev mode).
					copyDest := filepath.Join(agentSkillDir, "SKILL.md")
					if mkErr := os.MkdirAll(agentSkillDir, 0o755); mkErr != nil {
						return fmt.Errorf("could not create %s directory: %w", name, mkErr)
					}
					if wErr := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); wErr != nil {
						return fmt.Errorf("could not write %s skill file: %w", name, wErr)
					}
					_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s (symlink failed)\n", copyDest)
					continue
				}
				_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Linked %s -> %s\n", agentSkillDir, rel)
			}

			return nil
		},
	}

	cmd.Flags().BoolVarP(&global, "global", "g", false, "Install to home directory instead of project")
	cmd.Flags().BoolVar(&copyMode, "copy", false, "Copy files instead of symlinking")
	cmd.Flags().StringSliceVarP(&agents, "agent", "a", nil, "Target specific agents (e.g. claude-code)")

	return cmd
}

func installBase(global bool) (string, error) {
	if global {
		home, err := os.UserHomeDir()
		if err != nil {
			return "", fmt.Errorf("could not determine home directory: %w", err)
		}
		return home, nil
	}
	cwd, err := os.Getwd()
	if err != nil {
		return "", fmt.Errorf("could not determine working directory: %w", err)
	}
	return cwd, nil
}
@@ -97,6 +97,7 @@ func Execute() error {
 	rootCmd.AddCommand(newConfigureCmd())
 	rootCmd.AddCommand(newValidateConfigCmd())
 	rootCmd.AddCommand(newServeCmd())
+	rootCmd.AddCommand(newInstallSkillCmd())

 	// Default command is chat, but intercept --version first
 	rootCmd.RunE = func(cmd *cobra.Command, args []string) error {
@@ -23,6 +23,7 @@ import (
 	"github.com/charmbracelet/wish/ratelimiter"
 	"github.com/onyx-dot-app/onyx/cli/internal/api"
 	"github.com/onyx-dot-app/onyx/cli/internal/config"
+	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
 	"github.com/onyx-dot-app/onyx/cli/internal/tui"
 	"github.com/spf13/cobra"
 	"golang.org/x/time/rate"

@@ -295,15 +296,15 @@ provided via the ONYX_API_KEY environment variable to skip the prompt:
 The server URL is taken from the server operator's config. The server
 auto-generates an Ed25519 host key on first run if the key file does not
 already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
-environment variable (the --host-key flag takes precedence).
-
-Example:
-  onyx-cli serve --port 2222
-  ssh localhost -p 2222`,
+environment variable (the --host-key flag takes precedence).`,
+		Example: `  onyx-cli serve --port 2222
+  ssh localhost -p 2222
+  onyx-cli serve --host 0.0.0.0 --port 2222
+  onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
 	RunE: func(cmd *cobra.Command, args []string) error {
 		serverCfg := config.Load()
 		if serverCfg.ServerURL == "" {
-			return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
+			return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
 		}
 		if !cmd.Flags().Changed("host-key") {
 			if v := os.Getenv(config.EnvSSHHostKey); v != "" {
@@ -2,11 +2,13 @@ package cmd

 import (
 	"context"
+	"errors"
 	"fmt"
 	"time"

 	"github.com/onyx-dot-app/onyx/cli/internal/api"
 	"github.com/onyx-dot-app/onyx/cli/internal/config"
+	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
 	"github.com/onyx-dot-app/onyx/cli/internal/version"
 	log "github.com/sirupsen/logrus"
 	"github.com/spf13/cobra"

@@ -16,17 +18,21 @@ func newValidateConfigCmd() *cobra.Command {
 	return &cobra.Command{
 		Use:   "validate-config",
 		Short: "Validate configuration and test server connection",
+		Long: `Check that the CLI is configured, the server is reachable, and the API key
+is valid. Also reports the server version and warns if it is below the
+minimum required.`,
+		Example: `  onyx-cli validate-config`,
 		RunE: func(cmd *cobra.Command, args []string) error {
 			// Check config file
 			if !config.ConfigExists() {
-				return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
+				return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
 			}

 			cfg := config.Load()

 			// Check API key
 			if !cfg.IsConfigured() {
-				return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
+				return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
 			}

 			_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
@@ -35,7 +41,11 @@ func newValidateConfigCmd() *cobra.Command {
 			// Test connection
 			client := api.NewClient(cfg)
 			if err := client.TestConnection(cmd.Context()); err != nil {
-				return fmt.Errorf("connection failed: %w", err)
+				var authErr *api.AuthError
+				if errors.As(err, &authErr) {
+					return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Reconfigure with: onyx-cli configure", err)
+				}
+				return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Reconfigure with: onyx-cli configure", err)
 			}

 			_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
@@ -149,12 +149,12 @@ func (c *Client) TestConnection(ctx context.Context) error {
|
||||
|
||||
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
|
||||
if isHTML || strings.Contains(respServer, "awselb") {
|
||||
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
|
||||
return &AuthError{Message: fmt.Sprintf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)}
|
||||
}
|
||||
if resp2.StatusCode == 401 {
|
||||
return fmt.Errorf("invalid API key or token.\n %s", body)
|
||||
return &AuthError{Message: fmt.Sprintf("invalid API key or token.\n %s", body)}
|
||||
}
|
||||
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
|
||||
return &AuthError{Message: fmt.Sprintf("access denied — check that the API key is valid.\n %s", body)}
|
||||
}
|
||||
|
||||
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
|
||||
|
||||
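The 401/403 handling in this hunk disambiguates a reverse proxy answering in place of the Onyx backend from a genuine auth failure. That decision can be sketched as a small classifier; a Python sketch for illustration only (the function name and return labels are mine, not part of the client):

```python
def classify_auth_failure(status: int, server_header: str, is_html: bool) -> str:
    """Mirror of the TestConnection 401/403 branch: decide whether an
    auth-looking status came from a proxy or from the Onyx backend."""
    if status not in (401, 403):
        return "not-auth"
    # An HTML body or an ELB Server header means a proxy answered, not Onyx.
    if is_html or "awselb" in server_header:
        return "proxy"
    if status == 401:
        return "invalid-key"
    return "access-denied"
```

The CLI wraps the last three outcomes in `AuthError` so `validate-config` can map them to the auth-failure exit code.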
@@ -11,3 +11,12 @@ type OnyxAPIError struct {
 func (e *OnyxAPIError) Error() string {
 	return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
 }
+
+// AuthError is returned when authentication or authorization fails.
+type AuthError struct {
+	Message string
+}
+
+func (e *AuthError) Error() string {
+	return e.Message
+}
@@ -59,8 +59,10 @@ func ConfigExists() bool {
 	return err == nil
 }

-// Load reads config from file and applies environment variable overrides.
-func Load() OnyxCliConfig {
+// LoadFromDisk reads config from the file only, without applying environment
+// variable overrides. Use this when you need the persisted config values
+// (e.g., to preserve them during a save operation).
+func LoadFromDisk() OnyxCliConfig {
 	cfg := DefaultConfig()

 	data, err := os.ReadFile(ConfigFilePath())
@@ -70,6 +72,13 @@ func Load() OnyxCliConfig {
 		}
 	}

+	return cfg
+}
+
+// Load reads config from file and applies environment variable overrides.
+func Load() OnyxCliConfig {
+	cfg := LoadFromDisk()
+
 	// Environment overrides
 	if v := os.Getenv(EnvServerURL); v != "" {
 		cfg.ServerURL = v
187
cli/internal/embedded/SKILL.md
Normal file
@@ -0,0 +1,187 @@
---
name: onyx-cli
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
---

# Onyx CLI — Agent Tool

Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.

## Prerequisites

### 1. Check if installed

```bash
which onyx-cli
```

### 2. Install (if needed)

**Primary — pip:**

```bash
pip install onyx-cli
```

**From source (Go):**

```bash
go build -o onyx-cli github.com/onyx-dot-app/onyx/cli && sudo mv onyx-cli /usr/local/bin/
```

### 3. Check if configured

```bash
onyx-cli validate-config
```

This checks that the config file exists and an API key is present, then tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.

If unconfigured, you have two options:

**Option A — Interactive setup (requires user input):**

```bash
onyx-cli configure
```

This prompts for the Onyx server URL and API key, tests the connection, and saves the config.

**Option B — Environment variables (non-interactive, preferred for agents):**

```bash
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
export ONYX_API_KEY="your-api-key"
```

Environment variables override the config file. If these are set, no config file is needed.

| Variable | Required | Description |
| ----------------- | -------- | -------------------------------------------------------- |
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
| `ONYX_API_KEY` | Yes | API key for authentication |
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |

If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:

- Run `onyx-cli configure` interactively, or
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables

## Commands

### Validate configuration

```bash
onyx-cli validate-config
```

Checks that the config file exists and an API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.

### List available agents

```bash
onyx-cli agents
```

Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:

```bash
onyx-cli agents --json
```

Use agent IDs with `ask --agent-id` to query a specific agent.

### Basic query (plain text output)

```bash
onyx-cli ask "What is our company's PTO policy?"
```

Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.

### JSON output (structured events)

```bash
onyx-cli ask --json "What authentication methods do we support?"
```

Outputs parsed stream events as JSON, one object per line. Key events include message deltas, stop, errors, search start, and citation payloads.

Each line is a JSON object with this envelope:

```json
{"type": "<event_type>", "event": { ... }}
```

| Event Type | Description |
| ------------------- | -------------------------------------------------------------------- |
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
| `stop` | Stream complete |
| `error` | Error with `error` message field |
| `search_tool_start` | Onyx started searching documents |
| `citation_info` | Source citation — see shape below |

`citation_info` event shape:

```json
{
  "type": "citation_info",
  "event": {
    "citation_number": 1,
    "document_id": "abc123def456",
    "placement": { "turn_index": 0, "tab_index": 0, "sub_turn_index": null }
  }
}
```

`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.

### Specify an agent

```bash
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
```

Uses a specific Onyx agent/persona instead of the default.

### All flags

| Flag | Type | Description |
| ------------ | ---- | ---------------------------------------------- |
| `--agent-id` | int | Agent ID to use (overrides default) |
| `--json` | bool | Output raw NDJSON events instead of plain text |

## Statelessness

Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.

## When to Use

Use `onyx-cli ask` when:

- The user asks about company-specific information (policies, docs, processes)
- You need to search internal knowledge bases or connected data sources
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources

Do NOT use when:

- The question is about general programming knowledge (use your own knowledge)
- The user is asking about code in the current repository (use grep/read tools)
- The user hasn't mentioned Onyx and the question doesn't require internal company data

## Examples

```bash
# Simple question
onyx-cli ask "What are the steps to deploy to production?"

# Get structured output for parsing
onyx-cli ask --json "List all active API integrations"

# Use a specialized agent
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"

# Pipe the answer into another command
onyx-cli ask "What is the database schema for users?" | head -20
```
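The `ask --json` stream documented in SKILL.md above is straightforward for a consumer to assemble. A minimal Python sketch, assuming only the envelope and event shapes shown there (the sample event lines are illustrative):

```python
import json

def assemble_answer(ndjson_lines):
    """Concatenate message_delta content and collect citation document IDs
    from onyx-cli ask --json output (one JSON object per line)."""
    answer, citations = [], []
    for line in ndjson_lines:
        line = line.strip()
        if not line:
            continue
        obj = json.loads(line)
        kind, event = obj.get("type"), obj.get("event", {})
        if kind == "message_delta":
            answer.append(event.get("content", ""))
        elif kind == "citation_info":
            citations.append(event.get("document_id"))
        elif kind == "error":
            raise RuntimeError(event.get("error"))
        elif kind == "stop":
            break
    return "".join(answer), citations

# Illustrative event lines, matching the documented envelope:
lines = [
    '{"type": "message_delta", "event": {"content": "PTO is "}}',
    '{"type": "citation_info", "event": {"citation_number": 1, "document_id": "abc123"}}',
    '{"type": "message_delta", "event": {"content": "20 days."}}',
    '{"type": "stop", "event": {}}',
]
text, cites = assemble_answer(lines)  # text == "PTO is 20 days."
```

In practice the lines would come from the subprocess's stdout rather than a literal list.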
7
cli/internal/embedded/embed.go
Normal file
@@ -0,0 +1,7 @@
// Package embedded holds files that are compiled into the onyx-cli binary.
package embedded

import _ "embed"

//go:embed SKILL.md
var SkillMD string
33
cli/internal/exitcodes/codes.go
Normal file
@@ -0,0 +1,33 @@
// Package exitcodes defines semantic exit codes for the Onyx CLI.
package exitcodes

import "fmt"

const (
	Success       = 0
	General       = 1
	BadRequest    = 2 // invalid args / command-line errors (convention)
	NotConfigured = 3
	AuthFailure   = 4
	Unreachable   = 5
)

// ExitError wraps an error with a specific exit code.
type ExitError struct {
	Code int
	Err  error
}

func (e *ExitError) Error() string {
	return e.Err.Error()
}

// New creates an ExitError with the given code and message.
func New(code int, msg string) *ExitError {
	return &ExitError{Code: code, Err: fmt.Errorf("%s", msg)}
}

// Newf creates an ExitError with a formatted message.
func Newf(code int, format string, args ...any) *ExitError {
	return &ExitError{Code: code, Err: fmt.Errorf(format, args...)}
}
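A wrapper script or agent driving `onyx-cli` can branch on these semantic exit codes. A minimal Python sketch for illustration only — the mapping mirrors the constants in the package above, and the description strings are my own, not part of the CLI:

```python
# Mirrors the constants in cli/internal/exitcodes/codes.go.
EXIT_CODES = {
    0: "success",
    1: "general error",
    2: "bad request (invalid args)",
    3: "not configured",
    4: "auth failure",
    5: "server unreachable",
}

def describe_exit(code: int) -> str:
    """Map an onyx-cli exit status to a human-readable label."""
    return EXIT_CODES.get(code, f"unknown exit code {code}")
```

A caller would feed this the `returncode` of an `onyx-cli` subprocess and, for example, re-run `onyx-cli configure` on code 3.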
40
cli/internal/exitcodes/codes_test.go
Normal file
@@ -0,0 +1,40 @@
package exitcodes

import (
	"errors"
	"fmt"
	"testing"
)

func TestExitError_Error(t *testing.T) {
	e := New(NotConfigured, "not configured")
	if e.Error() != "not configured" {
		t.Fatalf("expected 'not configured', got %q", e.Error())
	}
	if e.Code != NotConfigured {
		t.Fatalf("expected code %d, got %d", NotConfigured, e.Code)
	}
}

func TestExitError_Newf(t *testing.T) {
	e := Newf(Unreachable, "cannot reach %s", "server")
	if e.Error() != "cannot reach server" {
		t.Fatalf("expected 'cannot reach server', got %q", e.Error())
	}
	if e.Code != Unreachable {
		t.Fatalf("expected code %d, got %d", Unreachable, e.Code)
	}
}

func TestExitError_ErrorsAs(t *testing.T) {
	e := New(BadRequest, "bad input")
	wrapped := fmt.Errorf("wrapper: %w", e)

	var exitErr *ExitError
	if !errors.As(wrapped, &exitErr) {
		t.Fatal("errors.As should find ExitError")
	}
	if exitErr.Code != BadRequest {
		t.Fatalf("expected code %d, got %d", BadRequest, exitErr.Code)
	}
}
50
cli/internal/fsutil/fsutil.go
Normal file
@@ -0,0 +1,50 @@
// Package fsutil provides filesystem helper functions.
package fsutil

import (
	"bytes"
	"errors"
	"fmt"
	"os"
)

// FileStatus describes how an on-disk file compares to expected content.
type FileStatus int

const (
	StatusMissing  FileStatus = iota
	StatusUpToDate // file exists with identical content
	StatusDiffers  // file exists with different content
)

// CompareFile checks whether the file at path matches the expected content.
func CompareFile(path string, expected []byte) (FileStatus, error) {
	existing, err := os.ReadFile(path)
	if err != nil {
		if errors.Is(err, os.ErrNotExist) {
			return StatusMissing, nil
		}
		return 0, fmt.Errorf("could not read %s: %w", path, err)
	}
	if bytes.Equal(existing, expected) {
		return StatusUpToDate, nil
	}
	return StatusDiffers, nil
}

// EnsureDirForCopy makes sure path is a real directory, not a symlink or
// regular file. If a symlink or file exists at path it is removed so the
// caller can create a directory with independent content.
func EnsureDirForCopy(path string) error {
	info, err := os.Lstat(path)
	if err == nil {
		if info.Mode()&os.ModeSymlink != 0 || !info.IsDir() {
			if err := os.Remove(path); err != nil {
				return err
			}
		}
	} else if !errors.Is(err, os.ErrNotExist) {
		return err
	}
	return nil
}
116
cli/internal/fsutil/fsutil_test.go
Normal file
@@ -0,0 +1,116 @@
package fsutil

import (
	"os"
	"path/filepath"
	"testing"
)

// TestCompareFile verifies that CompareFile correctly distinguishes between a
// missing file, a file with matching content, and a file with different content.
func TestCompareFile(t *testing.T) {
	tmpDir := t.TempDir()
	path := filepath.Join(tmpDir, "skill.md")
	expected := []byte("expected content")

	status, err := CompareFile(path, expected)
	if err != nil {
		t.Fatalf("CompareFile on missing file failed: %v", err)
	}
	if status != StatusMissing {
		t.Fatalf("expected StatusMissing, got %v", status)
	}

	if err := os.WriteFile(path, expected, 0o644); err != nil {
		t.Fatalf("write expected file failed: %v", err)
	}
	status, err = CompareFile(path, expected)
	if err != nil {
		t.Fatalf("CompareFile on matching file failed: %v", err)
	}
	if status != StatusUpToDate {
		t.Fatalf("expected StatusUpToDate, got %v", status)
	}

	if err := os.WriteFile(path, []byte("different content"), 0o644); err != nil {
		t.Fatalf("write different file failed: %v", err)
	}
	status, err = CompareFile(path, expected)
	if err != nil {
		t.Fatalf("CompareFile on different file failed: %v", err)
	}
	if status != StatusDiffers {
		t.Fatalf("expected StatusDiffers, got %v", status)
	}
}

// TestEnsureDirForCopy verifies that EnsureDirForCopy clears symlinks and
// regular files so --copy can write a real directory, while leaving existing
// directories and missing paths untouched.
func TestEnsureDirForCopy(t *testing.T) {
	t.Run("removes symlink", func(t *testing.T) {
		tmpDir := t.TempDir()
		targetDir := filepath.Join(tmpDir, "target")
		linkPath := filepath.Join(tmpDir, "link")

		if err := os.MkdirAll(targetDir, 0o755); err != nil {
			t.Fatalf("mkdir target failed: %v", err)
		}
		if err := os.Symlink(targetDir, linkPath); err != nil {
			t.Fatalf("create symlink failed: %v", err)
		}

		if err := EnsureDirForCopy(linkPath); err != nil {
			t.Fatalf("EnsureDirForCopy failed: %v", err)
		}

		if _, err := os.Lstat(linkPath); !os.IsNotExist(err) {
			t.Fatalf("expected symlink path to be removed, got err=%v", err)
		}
	})

	t.Run("removes regular file", func(t *testing.T) {
		tmpDir := t.TempDir()
		filePath := filepath.Join(tmpDir, "onyx-cli")
		if err := os.WriteFile(filePath, []byte("x"), 0o644); err != nil {
			t.Fatalf("write file failed: %v", err)
		}

		if err := EnsureDirForCopy(filePath); err != nil {
			t.Fatalf("EnsureDirForCopy failed: %v", err)
		}

		if _, err := os.Lstat(filePath); !os.IsNotExist(err) {
			t.Fatalf("expected file path to be removed, got err=%v", err)
		}
	})

	t.Run("keeps existing directory", func(t *testing.T) {
		tmpDir := t.TempDir()
		dirPath := filepath.Join(tmpDir, "onyx-cli")
		if err := os.MkdirAll(dirPath, 0o755); err != nil {
			t.Fatalf("mkdir failed: %v", err)
		}

		if err := EnsureDirForCopy(dirPath); err != nil {
			t.Fatalf("EnsureDirForCopy failed: %v", err)
		}

		info, err := os.Lstat(dirPath)
		if err != nil {
			t.Fatalf("lstat directory failed: %v", err)
		}
		if !info.IsDir() {
			t.Fatalf("expected directory to remain, got mode %v", info.Mode())
		}
	})

	t.Run("missing path is no-op", func(t *testing.T) {
		tmpDir := t.TempDir()
		missingPath := filepath.Join(tmpDir, "does-not-exist")

		if err := EnsureDirForCopy(missingPath); err != nil {
			t.Fatalf("EnsureDirForCopy failed: %v", err)
		}
	})
}
121
cli/internal/overflow/writer.go
Normal file
@@ -0,0 +1,121 @@
// Package overflow provides a streaming writer that auto-truncates output
// for non-TTY callers (e.g., AI agents, scripts). Full content is saved to
// a temp file on disk; only the first N bytes are printed to stdout.
package overflow

import (
	"fmt"
	"os"
	"strings"

	log "github.com/sirupsen/logrus"
)

// Writer handles streaming output with optional truncation.
// When Limit > 0, it streams to a temp file on disk (not memory) and stops
// writing to stdout after Limit bytes. When Limit == 0, it writes directly
// to stdout. In Quiet mode, it buffers in memory and prints once at the end.
type Writer struct {
	Limit      int
	Quiet      bool
	written    int
	totalBytes int
	truncated  bool
	buf        strings.Builder // used only in quiet mode
	tmpFile    *os.File        // used only in truncation mode (Limit > 0)
}

// Write sends a chunk of content through the writer.
func (w *Writer) Write(s string) {
	w.totalBytes += len(s)

	// Quiet mode: buffer in memory, print nothing
	if w.Quiet {
		w.buf.WriteString(s)
		return
	}

	if w.Limit <= 0 {
		fmt.Print(s)
		return
	}

	// Truncation mode: stream all content to temp file on disk
	if w.tmpFile == nil {
		f, err := os.CreateTemp("", "onyx-ask-*.txt")
		if err != nil {
			// Fall back to no-truncation if we can't create the file
			fmt.Fprintf(os.Stderr, "warning: could not create temp file: %v\n", err)
			w.Limit = 0
			fmt.Print(s)
			return
		}
		w.tmpFile = f
	}
	if _, err := w.tmpFile.WriteString(s); err != nil {
		// Disk write failed — abandon truncation, stream directly to stdout
		fmt.Fprintf(os.Stderr, "warning: temp file write failed: %v\n", err)
		w.closeTmpFile(true)
		w.Limit = 0
		w.truncated = false
		fmt.Print(s)
		return
	}

	if w.truncated {
		return
	}

	remaining := w.Limit - w.written
	if len(s) <= remaining {
		fmt.Print(s)
		w.written += len(s)
	} else {
		if remaining > 0 {
			fmt.Print(s[:remaining])
			w.written += remaining
		}
		w.truncated = true
	}
}

// Finish flushes remaining output. Call once after all Write calls are done.
func (w *Writer) Finish() {
	// Quiet mode: print buffered content at once
	if w.Quiet {
		fmt.Println(w.buf.String())
		return
	}

	if !w.truncated {
		w.closeTmpFile(true) // clean up unused temp file
		fmt.Println()
		return
	}

	// Close the temp file so it's readable
	tmpPath := w.tmpFile.Name()
	w.closeTmpFile(false) // close but keep the file

	fmt.Printf("\n\n--- response truncated (%d bytes total) ---\n", w.totalBytes)
	fmt.Printf("Full response: %s\n", tmpPath)
	fmt.Printf("Explore:\n")
	fmt.Printf("  cat %s | grep \"<pattern>\"\n", tmpPath)
	fmt.Printf("  cat %s | tail -50\n", tmpPath)
}

// closeTmpFile closes and optionally removes the temp file.
func (w *Writer) closeTmpFile(remove bool) {
	if w.tmpFile == nil {
		return
	}
	if err := w.tmpFile.Close(); err != nil {
		log.Debugf("warning: failed to close temp file: %v", err)
	}
	if remove {
		if err := os.Remove(w.tmpFile.Name()); err != nil {
			log.Debugf("warning: failed to remove temp file: %v", err)
		}
	}
	w.tmpFile = nil
}
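Setting the temp-file plumbing aside, the byte-budget accounting in `Write` above can be sketched on its own. A Python sketch for illustration only (not the CLI's code): emit up to `limit` bytes, keep counting the total, and stop emitting once the budget is spent, exactly as the Go `written`/`totalBytes`/`truncated` fields do.

```python
def truncate_stream(chunks, limit):
    """Mirror of overflow.Writer's stdout accounting: emit at most `limit`
    bytes, track the full length, and flag truncation once over budget."""
    emitted, written, total, truncated = [], 0, 0, False
    for s in chunks:
        total += len(s)
        if truncated:
            continue  # past the budget: keep counting, emit nothing
        remaining = limit - written
        if len(s) <= remaining:
            emitted.append(s)
            written += len(s)
        else:
            if remaining > 0:
                emitted.append(s[:remaining])  # partial chunk up to the limit
                written += remaining
            truncated = True
    return "".join(emitted), total, truncated
```

This reproduces the `TestWriter_MultipleChunks` scenario below: with a limit of 10, the chunks `"hello"`, `" "`, `"world"`, `"!"` emit `"hello worl"` while the total advances to 12.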
95
cli/internal/overflow/writer_test.go
Normal file
@@ -0,0 +1,95 @@
package overflow
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriter_NoLimit(t *testing.T) {
|
||||
w := &Writer{Limit: 0}
|
||||
w.Write("hello world")
|
||||
if w.truncated {
|
||||
t.Fatal("should not be truncated with limit 0")
|
||||
}
|
||||
if w.totalBytes != 11 {
|
||||
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
	}
}

func TestWriter_UnderLimit(t *testing.T) {
	w := &Writer{Limit: 100}
	w.Write("hello")
	w.Write(" world")
	if w.truncated {
		t.Fatal("should not be truncated when under limit")
	}
	if w.written != 11 {
		t.Fatalf("expected 11 written bytes, got %d", w.written)
	}
}

func TestWriter_OverLimit(t *testing.T) {
	w := &Writer{Limit: 5}
	w.Write("hello world") // 11 bytes, limit 5
	if !w.truncated {
		t.Fatal("should be truncated")
	}
	if w.written != 5 {
		t.Fatalf("expected 5 written bytes, got %d", w.written)
	}
	if w.totalBytes != 11 {
		t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
	}
	if w.tmpFile == nil {
		t.Fatal("temp file should have been created")
	}
	_ = w.tmpFile.Close()
	data, _ := os.ReadFile(w.tmpFile.Name())
	_ = os.Remove(w.tmpFile.Name())
	if string(data) != "hello world" {
		t.Fatalf("temp file should contain full content, got %q", string(data))
	}
}

func TestWriter_MultipleChunks(t *testing.T) {
	w := &Writer{Limit: 10}
	w.Write("hello") // 5 bytes
	w.Write(" ")     // 6 bytes
	w.Write("world") // 11 bytes, crosses limit
	w.Write("!")     // 12 bytes, already truncated

	if !w.truncated {
		t.Fatal("should be truncated")
	}
	if w.written != 10 {
		t.Fatalf("expected 10 written bytes, got %d", w.written)
	}
	if w.totalBytes != 12 {
		t.Fatalf("expected 12 total bytes, got %d", w.totalBytes)
	}
	if w.tmpFile == nil {
		t.Fatal("temp file should have been created")
	}
	_ = w.tmpFile.Close()
	data, _ := os.ReadFile(w.tmpFile.Name())
	_ = os.Remove(w.tmpFile.Name())
	if string(data) != "hello world!" {
		t.Fatalf("temp file should contain full content, got %q", string(data))
	}
}

func TestWriter_QuietMode(t *testing.T) {
	w := &Writer{Limit: 0, Quiet: true}
	w.Write("hello")
	w.Write(" world")

	if w.written != 0 {
		t.Fatalf("quiet mode should not write to stdout, got %d written", w.written)
	}
	if w.totalBytes != 11 {
		t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
	}
	if w.buf.String() != "hello world" {
		t.Fatalf("buffer should contain full content, got %q", w.buf.String())
	}
}
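The tests above pin down the `Writer` contract: forward at most `Limit` bytes to stdout, count every byte received in `totalBytes`, keep the full stream in `buf`, and spill the complete content to a temp file once the limit is crossed (quiet mode buffers only). A minimal sketch consistent with those tests follows; the field names come from the tests themselves, but the exact spill timing and temp-file naming are assumptions, not the actual `cli` implementation:

```go
package main

import (
	"bytes"
	"fmt"
	"os"
)

// Writer mirrors the semantics exercised by the tests above: it forwards
// output to stdout up to Limit bytes, records totals, and spills the full
// stream to a temp file once the limit is crossed. (Sketch only — the real
// implementation may differ in details.)
type Writer struct {
	Limit      int
	Quiet      bool
	buf        bytes.Buffer // full content, always accumulated
	written    int          // bytes actually sent to stdout
	totalBytes int          // bytes received overall
	truncated  bool
	tmpFile    *os.File
}

func (w *Writer) Write(s string) {
	w.buf.WriteString(s)
	w.totalBytes += len(s)

	if w.Quiet {
		return // buffer only; nothing goes to stdout
	}
	if !w.truncated {
		remaining := w.Limit - w.written
		if len(s) <= remaining {
			os.Stdout.WriteString(s)
			w.written += len(s)
			return
		}
		// Crossing the limit: emit what still fits, then switch to a
		// temp file seeded with the full content seen so far.
		os.Stdout.WriteString(s[:remaining])
		w.written += remaining
		w.truncated = true
		w.tmpFile, _ = os.CreateTemp("", "writer-*")
		w.tmpFile.WriteString(w.buf.String())
		return
	}
	w.tmpFile.WriteString(s)
}

func main() {
	w := &Writer{Limit: 5}
	w.Write("hello world") // "hello" reaches stdout; full text spills to the temp file
	fmt.Printf("\nwritten=%d total=%d truncated=%v\n", w.written, w.totalBytes, w.truncated)
}
```

Walking the `MultipleChunks` test through this sketch reproduces its expectations: writes of 5, 1, and 5 bytes against `Limit: 10` emit 10 bytes to stdout, flip `truncated` on the third write, and leave `"hello world!"` in the temp file after the fourth.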
83	cli/internal/starprompt/starprompt.go	Normal file
@@ -0,0 +1,83 @@
// Package starprompt implements a one-time GitHub star prompt shown before the TUI.
// Skipped when stdin/stdout is not a TTY, when gh CLI is not installed,
// or when the user has already been prompted. State is stored in the
// config directory so it shows at most once per user.
package starprompt

import (
	"bufio"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"strings"
	"time"

	"github.com/onyx-dot-app/onyx/cli/internal/config"
	"golang.org/x/term"
)

const repo = "onyx-dot-app/onyx"

func statePath() string {
	return filepath.Join(config.ConfigDir(), ".star-prompted")
}

func hasBeenPrompted() bool {
	_, err := os.Stat(statePath())
	return err == nil
}

func markPrompted() {
	_ = os.MkdirAll(config.ConfigDir(), 0o755)
	f, err := os.Create(statePath())
	if err == nil {
		_ = f.Close()
	}
}

func isGHInstalled() bool {
	_, err := exec.LookPath("gh")
	return err == nil
}

// MaybePrompt shows a one-time star prompt if conditions are met.
// It is safe to call unconditionally — it no-ops when not appropriate.
func MaybePrompt() {
	if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
		return
	}
	if hasBeenPrompted() {
		return
	}
	if !isGHInstalled() {
		return
	}

	// Mark before asking so Ctrl+C won't cause a re-prompt.
	markPrompted()

	fmt.Print("Enjoying Onyx? Star the repo on GitHub? [Y/n] ")
	reader := bufio.NewReader(os.Stdin)
	answer, _ := reader.ReadString('\n')
	answer = strings.TrimSpace(strings.ToLower(answer))

	if answer == "n" || answer == "no" {
		return
	}

	cmd := exec.Command("gh", "api", "-X", "PUT", "/user/starred/"+repo)
	cmd.Env = append(os.Environ(), "GH_PAGER=")
	if devnull, err := os.Open(os.DevNull); err == nil {
		defer func() { _ = devnull.Close() }()
		cmd.Stdin = devnull
		cmd.Stdout = devnull
		cmd.Stderr = devnull
	}
	if err := cmd.Run(); err != nil {
		fmt.Println("Star us at: https://github.com/" + repo)
	} else {
		fmt.Println("Thanks for the star!")
		time.Sleep(500 * time.Millisecond)
	}
}
@@ -1,10 +1,12 @@
package main

import (
	"errors"
	"fmt"
	"os"

	"github.com/onyx-dot-app/onyx/cli/cmd"
	"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
)

var (
@@ -18,6 +20,10 @@ func main() {

	if err := cmd.Execute(); err != nil {
		fmt.Fprintf(os.Stderr, "Error: %v\n", err)
		var exitErr *exitcodes.ExitError
		if errors.As(err, &exitErr) {
			os.Exit(exitErr.Code)
		}
		os.Exit(1)
	}
}
@@ -1302,4 +1302,18 @@ echo ""
print_info "Refer to the README in the ${INSTALL_ROOT} directory for more information."
echo ""
print_info "For help or issues, contact: founders@onyx.app"
echo ""
echo ""

# --- GitHub star prompt (inspired by oh-my-codex) ---
# Only prompt in interactive mode and only if gh CLI is available.
# Uses the GitHub API directly (PUT /user/starred) like oh-my-codex.
if is_interactive && command -v gh &>/dev/null; then
  prompt_yn_or_default "Enjoying Onyx? Star the repo on GitHub? [Y/n] " "Y"
  if [[ ! "$REPLY" =~ ^[Nn] ]]; then
    if GH_PAGER= gh api -X PUT /user/starred/onyx-dot-app/onyx < /dev/null >/dev/null 2>&1; then
      print_success "Thanks for the star!"
    else
      print_info "Star us at: https://github.com/onyx-dot-app/onyx"
    fi
  fi
fi
@@ -193,9 +193,9 @@ hover, active, and disabled states.

### Disabled (`core/disabled/`)

Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
interactive descendants.
A pure CSS wrapper that applies disabled visuals (`opacity-50`, `cursor-not-allowed`,
`pointer-events: none`) to a single child element via Radix `Slot`. Has no React context —
Interactive primitives and buttons manage their own disabled state via a `disabled` prop.

### Hoverable (`core/animations/`)

@@ -1,9 +1,5 @@
import "@opal/components/tooltip.css";
import {
  Disabled,
  Interactive,
  type InteractiveStatelessProps,
} from "@opal/core";
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
import type {
  ContainerSizeVariants,
  ExtremaSizeVariants,
@@ -49,7 +45,7 @@ type ButtonProps = InteractiveStatelessProps &
  /** Which side the tooltip appears on. */
  tooltipSide?: TooltipSide;

  /** Wraps the button in a Disabled context. `false` overrides parent contexts. */
  /** Applies disabled styling and suppresses clicks. */
  disabled?: boolean;
};

@@ -94,7 +90,11 @@ function Button({
  ) : null;

  const button = (
    <Interactive.Stateless type={type} {...interactiveProps}>
    <Interactive.Stateless
      type={type}
      disabled={disabled}
      {...interactiveProps}
    >
      <Interactive.Container
        type={type}
        border={interactiveProps.prominence === "secondary"}
@@ -118,28 +118,24 @@ function Button({
    </Interactive.Stateless>
  );

  const result = tooltip ? (
    <TooltipPrimitive.Root>
      <TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
      <TooltipPrimitive.Portal>
        <TooltipPrimitive.Content
          className="opal-tooltip"
          side={tooltipSide}
          sideOffset={4}
        >
          {tooltip}
        </TooltipPrimitive.Content>
      </TooltipPrimitive.Portal>
    </TooltipPrimitive.Root>
  ) : (
    button
  );

  if (disabled != null) {
    return <Disabled disabled={disabled}>{result}</Disabled>;
  if (tooltip) {
    return (
      <TooltipPrimitive.Root>
        <TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
        <TooltipPrimitive.Portal>
          <TooltipPrimitive.Content
            className="opal-tooltip"
            side={tooltipSide}
            sideOffset={4}
          >
            {tooltip}
          </TooltipPrimitive.Content>
        </TooltipPrimitive.Portal>
      </TooltipPrimitive.Root>
    );
  }

  return result;
  return button;
}

export { Button, type ButtonProps };
@@ -1,6 +1,5 @@
import {
  Interactive,
  useDisabled,
  type InteractiveStatefulProps,
  type InteractiveStatefulInteraction,
} from "@opal/core";
@@ -74,6 +73,9 @@ type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {

  /** Override the default rounding derived from `size`. */
  roundingVariant?: InteractiveContainerRoundingVariant;

  /** Applies disabled styling and suppresses clicks. */
  disabled?: boolean;
};

// ---------------------------------------------------------------------------
@@ -92,10 +94,9 @@ function OpenButton({
  roundingVariant: roundingVariantOverride,
  interaction,
  variant = "select-heavy",
  disabled,
  ...statefulProps
}: OpenButtonProps) {
  const { isDisabled } = useDisabled();

  // Derive open state: explicit prop → Radix data-state (injected via Slot chain)
  const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
    | string
@@ -119,6 +120,7 @@ function OpenButton({
    <Interactive.Stateful
      variant={variant}
      interaction={resolvedInteraction}
      disabled={disabled}
      {...statefulProps}
    >
      <Interactive.Container
@@ -168,7 +170,7 @@ function OpenButton({
  );

  const resolvedTooltip =
    tooltip ?? (foldable && isDisabled && children ? children : undefined);
    tooltip ?? (foldable && disabled && children ? children : undefined);

  if (!resolvedTooltip) return button;

@@ -1,11 +1,7 @@
"use client";

import "@opal/components/buttons/select-button/styles.css";
import {
  Interactive,
  useDisabled,
  type InteractiveStatefulProps,
} from "@opal/core";
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
import type {
  ContainerSizeVariants,
  ExtremaSizeVariants,
@@ -64,6 +60,9 @@ type SelectButtonProps = InteractiveStatefulProps &

  /** Which side the tooltip appears on. */
  tooltipSide?: TooltipSide;

  /** Applies disabled styling and suppresses clicks. */
  disabled?: boolean;
};

// ---------------------------------------------------------------------------
@@ -80,9 +79,9 @@ function SelectButton({
  width,
  tooltip,
  tooltipSide = "top",
  disabled,
  ...statefulProps
}: SelectButtonProps) {
  const { isDisabled } = useDisabled();
  const isLarge = size === "lg";

  const labelEl = children ? (
@@ -96,7 +95,7 @@ function SelectButton({
  ) : null;

  const button = (
    <Interactive.Stateful {...statefulProps}>
    <Interactive.Stateful disabled={disabled} {...statefulProps}>
      <Interactive.Container
        type={type}
        heightVariant={size}
@@ -128,7 +127,7 @@ function SelectButton({
  );

  const resolvedTooltip =
    tooltip ?? (foldable && isDisabled && children ? children : undefined);
    tooltip ?? (foldable && disabled && children ? children : undefined);

  if (!resolvedTooltip) return button;

64	web/lib/opal/src/components/buttons/sidebar-tab/README.md	Normal file
@@ -0,0 +1,64 @@
# SidebarTab

**Import:** `import { SidebarTab, type SidebarTabProps } from "@opal/components";`

A sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`. Designed for admin and app sidebars.

## Architecture

```
div.relative
└─ Interactive.Stateful        <- variant (sidebar-heavy | sidebar-light), state, disabled
   └─ Interactive.Container    <- rounding, height, width
      ├─ Link? (absolute overlay for client-side navigation)
      ├─ rightChildren? (absolute, above Link for inline actions)
      └─ ContentAction (icon + title + truncation spacer)
```

- **`sidebar-heavy`** (default) — muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
- **`sidebar-light`** — uniformly muted across all states (text-02/text-02)
- **Disabled** — both variants use text-02 foreground, transparent background, no hover/active states
- **Navigation** uses an absolutely positioned `<Link>` overlay rather than `href` on the Interactive element, so `rightChildren` can sit above it with `pointer-events-auto`.

## Props

| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `variant` | `"sidebar-heavy" \| "sidebar-light"` | `"sidebar-heavy"` | Sidebar color variant |
| `selected` | `boolean` | `false` | Active/selected state |
| `icon` | `IconFunctionComponent` | — | Left icon |
| `children` | `ReactNode` | — | Label text or custom content |
| `disabled` | `boolean` | `false` | Disables the tab |
| `folded` | `boolean` | `false` | Collapses label, shows tooltip on hover |
| `nested` | `boolean` | `false` | Renders spacer instead of icon for indented items |
| `href` | `string` | — | Client-side navigation URL |
| `onClick` | `MouseEventHandler` | — | Click handler |
| `type` | `ButtonType` | — | HTML button type |
| `rightChildren` | `ReactNode` | — | Actions rendered on the right side |

## Usage

```tsx
import { SidebarTab } from "@opal/components";
import { SvgSettings, SvgLock } from "@opal/icons";

// Active tab
<SidebarTab icon={SvgSettings} href="/admin/settings" selected>
  Settings
</SidebarTab>

// Muted variant
<SidebarTab icon={SvgSettings} variant="sidebar-light">
  Exit Admin Panel
</SidebarTab>

// Disabled enterprise-only tab
<SidebarTab icon={SvgLock} disabled>
  Groups
</SidebarTab>

// Folded sidebar (icon only, tooltip on hover)
<SidebarTab icon={SvgSettings} href="/admin/settings" folded>
  Settings
</SidebarTab>
```
@@ -0,0 +1,94 @@
import type { Meta, StoryObj } from "@storybook/react";
import { SidebarTab } from "@opal/components/buttons/sidebar-tab/components";
import {
  SvgSettings,
  SvgUsers,
  SvgLock,
  SvgArrowUpCircle,
  SvgTrash,
} from "@opal/icons";
import { Button } from "@opal/components";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";

const meta: Meta<typeof SidebarTab> = {
  title: "opal/components/SidebarTab",
  component: SidebarTab,
  tags: ["autodocs"],
  decorators: [
    (Story) => (
      <TooltipPrimitive.Provider>
        <div style={{ width: 260, background: "var(--background-neutral-01)" }}>
          <Story />
        </div>
      </TooltipPrimitive.Provider>
    ),
  ],
};

export default meta;
type Story = StoryObj<typeof SidebarTab>;

export const Default: Story = {
  args: {
    icon: SvgSettings,
    children: "Settings",
  },
};

export const Selected: Story = {
  args: {
    icon: SvgSettings,
    children: "Settings",
    selected: true,
  },
};

export const Light: Story = {
  args: {
    icon: SvgSettings,
    children: "Settings",
    variant: "sidebar-light",
  },
};

export const Disabled: Story = {
  args: {
    icon: SvgLock,
    children: "Enterprise Only",
    disabled: true,
  },
};

export const WithRightChildren: Story = {
  args: {
    icon: SvgUsers,
    children: "Users",
    rightChildren: (
      <Button
        icon={SvgTrash}
        size="xs"
        prominence="tertiary"
        variant="danger"
      />
    ),
  },
};

export const SidebarExample: Story = {
  render: () => (
    <div className="flex flex-col">
      <SidebarTab icon={SvgSettings} selected>
        LLM Models
      </SidebarTab>
      <SidebarTab icon={SvgSettings}>Web Search</SidebarTab>
      <SidebarTab icon={SvgUsers}>Users</SidebarTab>
      <SidebarTab icon={SvgLock} disabled>
        Groups
      </SidebarTab>
      <SidebarTab icon={SvgLock} disabled>
        SCIM
      </SidebarTab>
      <SidebarTab icon={SvgArrowUpCircle}>Upgrade Plan</SidebarTab>
    </div>
  ),
};
@@ -3,32 +3,66 @@
import React from "react";
import type { ButtonType, IconFunctionComponent } from "@opal/types";
import type { Route } from "next";
import { Interactive } from "@opal/core";
import { Interactive, type InteractiveStatefulVariant } from "@opal/core";
import { ContentAction } from "@opal/layouts";
import { Text } from "@opal/components";
import Link from "next/link";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import "@opal/components/tooltip.css";

export interface SidebarTabProps {
  // Button states:
// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------

interface SidebarTabProps {
  /** Collapses the label, showing only the icon. */
  folded?: boolean;

  /** Marks this tab as the currently active/selected item. */
  selected?: boolean;
  lowlight?: boolean;

  /**
   * Sidebar color variant.
   * @default "sidebar-heavy"
   */
  variant?: Extract<
    InteractiveStatefulVariant,
    "sidebar-light" | "sidebar-heavy"
  >;

  /** Renders an empty spacer in place of the icon for nested items. */
  nested?: boolean;

  // Button properties:
  /** Disables the tab — applies muted colors and suppresses clicks. */
  disabled?: boolean;

  onClick?: React.MouseEventHandler<HTMLElement>;
  href?: string;
  type?: ButtonType;
  icon?: IconFunctionComponent;
  children?: React.ReactNode;

  /** Content rendered on the right side (e.g. action buttons). */
  rightChildren?: React.ReactNode;
}

export default function SidebarTab({
// ---------------------------------------------------------------------------
// SidebarTab
// ---------------------------------------------------------------------------

/**
 * Sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`.
 *
 * Uses `sidebar-heavy` (default) or `sidebar-light` (via `variant`) variants
 * for color styling. Supports an overlay `Link` for client-side navigation,
 * `rightChildren` for inline actions, and folded mode with an auto-tooltip.
 */
function SidebarTab({
  folded,
  selected,
  lowlight,
  variant = "sidebar-heavy",
  nested,
  disabled,

  onClick,
  href,
@@ -45,11 +79,8 @@ export default function SidebarTab({
      )) as IconFunctionComponent)
    : null);

  // NOTE (@raunakab)
  //
  // The `rightChildren` node NEEDS to be absolutely positioned since it needs to live on top of the absolutely positioned `Link`.
  // However, having the `rightChildren` be absolutely positioned means that it cannot appropriately truncate the title.
  // Therefore, we add a dummy node solely for the truncation effects that we obtain.
  // The `rightChildren` node is absolutely positioned to sit on top of the
  // overlay Link. A zero-width spacer reserves truncation space for the title.
  const truncationSpacer = rightChildren && (
    <div className="w-0 group-hover/SidebarTab:w-6" />
  );
@@ -57,8 +88,9 @@ export default function SidebarTab({
  const content = (
    <div className="relative">
      <Interactive.Stateful
        variant={lowlight ? "sidebar-light" : "sidebar-heavy"}
        variant={variant}
        state={selected ? "selected" : "empty"}
        disabled={disabled}
        onClick={onClick}
        type="button"
        group="group/SidebarTab"
@@ -69,7 +101,7 @@ export default function SidebarTab({
          widthVariant="full"
          type={type}
        >
          {href && (
          {href && !disabled && (
            <Link
              href={href as Route}
              scroll={false}
@@ -95,19 +127,14 @@ export default function SidebarTab({
              rightChildren={truncationSpacer}
            />
          ) : (
            <div className="flex flex-row items-center gap-2 flex-1">
            <div className="flex flex-row items-center gap-2 w-full">
              {Icon && (
                <div className="flex items-center justify-center p-0.5">
                  <Icon className="h-[1rem] w-[1rem] text-text-03" />
                </div>
              )}
              {children}
              {
                // NOTE (@raunakab)
                //
                // Adding the `truncationSpacer` here for the same reason as above.
                truncationSpacer
              }
              {truncationSpacer}
            </div>
          )}
        </Interactive.Container>
@@ -116,7 +143,23 @@ export default function SidebarTab({
  );

  if (typeof children !== "string") return content;
  if (folded)
    return <SimpleTooltip tooltip={children}>{content}</SimpleTooltip>;
  if (folded) {
    return (
      <TooltipPrimitive.Root>
        <TooltipPrimitive.Trigger asChild>{content}</TooltipPrimitive.Trigger>
        <TooltipPrimitive.Portal>
          <TooltipPrimitive.Content
            className="opal-tooltip"
            side="right"
            sideOffset={4}
          >
            {children}
          </TooltipPrimitive.Content>
        </TooltipPrimitive.Portal>
      </TooltipPrimitive.Root>
    );
  }
  return content;
}

export { SidebarTab, type SidebarTabProps };
@@ -33,6 +33,12 @@ export {
  type LineItemButtonProps,
} from "@opal/components/buttons/line-item-button/components";

/* SidebarTab */
export {
  SidebarTab,
  type SidebarTabProps,
} from "@opal/components/buttons/sidebar-tab/components";

/* Text */
export {
  Text,

@@ -586,7 +586,10 @@ export function Table<TData>(props: DataTableProps<TData>) {

  // Data / Display cell
  return (
    <TableCell key={cell.id}>
    <TableCell
      key={cell.id}
      data-column-id={cell.column.id}
    >
      {flexRender(
        cell.column.columnDef.cell,
        cell.getContext()
@@ -1,33 +1,7 @@
"use client";

import "@opal/core/disabled/styles.css";
import React, { createContext, useContext } from "react";
import React from "react";
import { Slot } from "@radix-ui/react-slot";

// ---------------------------------------------------------------------------
// Context
// ---------------------------------------------------------------------------

interface DisabledContextValue {
  isDisabled: boolean;
  allowClick: boolean;
}

const DisabledContext = createContext<DisabledContextValue>({
  isDisabled: false,
  allowClick: false,
});

/**
 * Returns the current disabled state from the nearest `<Disabled>` ancestor.
 *
 * Used internally by `Interactive.Stateless` and `Interactive.Stateful` to
 * derive `data-disabled` and `aria-disabled` attributes automatically.
 */
function useDisabled(): DisabledContextValue {
  return useContext(DisabledContext);
}

// ---------------------------------------------------------------------------
// Types
// ---------------------------------------------------------------------------
@@ -56,8 +30,8 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
// ---------------------------------------------------------------------------

/**
 * Wrapper component that propagates disabled state via context and applies
 * baseline disabled CSS (opacity, cursor, pointer-events) to its child.
 * Wrapper component that applies baseline disabled CSS (opacity, cursor,
 * pointer-events) to its child element.
 *
 * Uses Radix `Slot` — merges props onto the single child element without
 * adding any DOM node. Works correctly inside Radix `asChild` chains.
@@ -65,7 +39,7 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
 * @example
 * ```tsx
 * <Disabled disabled={!canSubmit}>
 *   <Button onClick={handleSubmit}>Save</Button>
 *   <div>...</div>
 * </Disabled>
 * ```
 */
@@ -77,20 +51,16 @@ function Disabled({
  ...rest
}: DisabledProps) {
  return (
    <DisabledContext.Provider
      value={{ isDisabled: !!disabled, allowClick: !!allowClick }}
    <Slot
      ref={ref}
      {...rest}
      aria-disabled={disabled || undefined}
      data-opal-disabled={disabled || undefined}
      data-allow-click={disabled && allowClick ? "" : undefined}
    >
      <Slot
        ref={ref}
        {...rest}
        aria-disabled={disabled || undefined}
        data-opal-disabled={disabled || undefined}
        data-allow-click={disabled && allowClick ? "" : undefined}
      >
        {children}
      </Slot>
    </DisabledContext.Provider>
      {children}
    </Slot>
  );
}

export { Disabled, useDisabled, type DisabledProps, type DisabledContextValue };
export { Disabled, type DisabledProps };
@@ -1,10 +1,5 @@
/* Disabled */
export {
  Disabled,
  useDisabled,
  type DisabledProps,
  type DisabledContextValue,
} from "@opal/core/disabled/components";
export { Disabled, type DisabledProps } from "@opal/core/disabled/components";

/* Animations (formerly Hoverable) */
export {

@@ -10,7 +10,6 @@ import {
  widthVariants,
  type ExtremaSizeVariants,
} from "@opal/shared";
import { useDisabled } from "@opal/core/disabled/components";

// ---------------------------------------------------------------------------
// Types
@@ -102,7 +101,6 @@ function InteractiveContainer({
  widthVariant = "fit",
  ...props
}: InteractiveContainerProps) {
  const { allowClick } = useDisabled();
  const {
    className: slotClassName,
    style: slotStyle,
@@ -148,8 +146,7 @@ function InteractiveContainer({
  if (type) {
    const ariaDisabled = (rest as Record<string, unknown>)["aria-disabled"];
    const nativeDisabled =
      (type === "submit" || !allowClick) &&
      (ariaDisabled === true || ariaDisabled === "true" || undefined);
      ariaDisabled === true || ariaDisabled === "true" || undefined;
    return (
      <button
        ref={ref as React.Ref<HTMLButtonElement>}

@@ -1,7 +1,6 @@
import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";

// ---------------------------------------------------------------------------
@@ -29,6 +28,11 @@ interface InteractiveSimpleProps
   * Link target (e.g. `"_blank"`). Only used when `href` is provided.
   */
  target?: string;

  /**
   * Applies disabled cursor and suppresses clicks.
   */
  disabled?: boolean;
}

// ---------------------------------------------------------------------------
@@ -38,8 +42,8 @@ interface InteractiveSimpleProps
/**
 * Minimal interactive surface primitive.
 *
 * Provides cursor styling, click handling, disabled integration, and
 * optional link/group support — but **no color or background styling**.
 * Provides cursor styling, click handling, and optional link/group
 * support — but **no color or background styling**.
 *
 * Use this for elements that need interactivity (click, cursor, disabled)
 * without participating in the Interactive color system.
@@ -59,9 +63,10 @@ function InteractiveSimple({
  group,
  href,
  target,
  disabled,
  ...props
}: InteractiveSimpleProps) {
  const { isDisabled, allowClick } = useDisabled();
  const isDisabled = !!disabled;

  const classes = cn(
    "cursor-pointer select-none",
@@ -88,7 +93,7 @@ function InteractiveSimple({
      {...linkAttrs}
      {...slotProps}
      onClick={
        isDisabled && !allowClick
        isDisabled
          ? href
            ? (e: React.MouseEvent) => e.preventDefault()
            : undefined

@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateful/styles.css";
import React from "react";
import { Slot } from "@radix-ui/react-slot";
import { cn } from "@opal/utils";
import { useDisabled } from "@opal/core/disabled/components";
import { guardPortalClick } from "@opal/core/interactive/utils";
import type { ButtonType, WithoutStyles } from "@opal/types";

@@ -87,6 +86,11 @@ interface InteractiveStatefulProps
   * Link target (e.g. `"_blank"`). Only used when `href` is provided.
   */
  target?: string;

  /**
   * Applies variant-specific disabled colors and suppresses clicks.
   */
  disabled?: boolean;
}

// ---------------------------------------------------------------------------
@@ -100,8 +104,7 @@ interface InteractiveStatefulProps
 * (empty/filled/selected). Applies variant/state color styling via CSS
 * data-attributes and merges onto a single child element via Radix `Slot`.
 *
 * Disabled state is consumed from the nearest `<Disabled>` ancestor via
 * context — there is no `disabled` prop on this component.
 * Disabled state is controlled via the `disabled` prop.
 */
function InteractiveStateful({
  ref,
@@ -112,9 +115,10 @@ function InteractiveStateful({
  type,
  href,
  target,
  disabled,
  ...props
}: InteractiveStatefulProps) {
  const { isDisabled, allowClick } = useDisabled();
  const isDisabled = !!disabled;

  // onClick/href are always passed directly — Stateful is the outermost Slot,
  // so Radix Slot-injected handlers don't bypass this guard.
@@ -150,7 +154,7 @@ function InteractiveStateful({
      {...linkAttrs}
      {...slotProps}
      onClick={
        isDisabled && !allowClick
        isDisabled
          ? href
            ? (e: React.MouseEvent) => e.preventDefault()
            : undefined

```diff
@@ -550,6 +550,14 @@
 ) {
   @apply bg-background-tint-03;
 }
+/* ---------------------------------------------------------------------------
+  Sidebar-Heavy — Disabled (all states)
+--------------------------------------------------------------------------- */
+.interactive[data-interactive-variant="sidebar-heavy"][data-disabled] {
+  @apply bg-transparent opacity-50;
+  --interactive-foreground: var(--text-03);
+  --interactive-foreground-icon: var(--text-03);
+}

 /* ===========================================================================
   Sidebar-Light
@@ -607,3 +615,11 @@
 ) {
   @apply bg-background-tint-03;
 }
+/* ---------------------------------------------------------------------------
+  Sidebar-Light — Disabled (all states)
+--------------------------------------------------------------------------- */
+.interactive[data-interactive-variant="sidebar-light"][data-disabled] {
+  @apply bg-transparent opacity-50;
+  --interactive-foreground: var(--text-03);
+  --interactive-foreground-icon: var(--text-03);
+}
```
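The new CSS rules above select on a `[data-disabled]` attribute. A minimal sketch of how a component could surface its disabled state to that selector (the helper is hypothetical, not the library's actual code; emitting `data-disabled=""` when disabled follows the common React/Radix convention of rendering an empty-string data attribute, which the bare `[data-disabled]` attribute selector matches):

```typescript
// Hypothetical helper: build the data attributes the CSS above keys off.
// Not part of the codebase; illustrates the [data-disabled] contract only.
function interactiveDataAttrs(
  variant: string,
  disabled: boolean
): Record<string, string> {
  const attrs: Record<string, string> = {
    "data-interactive-variant": variant,
  };
  if (disabled) {
    // An empty-string value renders as `data-disabled=""`, which is exactly
    // what the attribute selector [data-disabled] matches.
    attrs["data-disabled"] = "";
  }
  return attrs;
}
```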
||||
```diff
@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateless/styles.css";
 import React from "react";
 import { Slot } from "@radix-ui/react-slot";
 import { cn } from "@opal/utils";
-import { useDisabled } from "@opal/core/disabled/components";
 import { guardPortalClick } from "@opal/core/interactive/utils";
 import type { ButtonType, WithoutStyles } from "@opal/types";

@@ -70,6 +69,11 @@ interface InteractiveStatelessProps
    * Link target (e.g. `"_blank"`). Only used when `href` is provided.
    */
   target?: string;
+
+  /**
+   * Applies variant-specific disabled colors and suppresses clicks.
+   */
+  disabled?: boolean;
 }

 // ---------------------------------------------------------------------------
@@ -84,8 +88,7 @@ interface InteractiveStatelessProps
  * color styling via CSS data-attributes and merges onto a single child
  * element via Radix `Slot`.
  *
- * Disabled state is consumed from the nearest `<Disabled>` ancestor via
- * context — there is no `disabled` prop on this component.
+ * Disabled state is controlled via the `disabled` prop.
  */
 function InteractiveStateless({
   ref,
@@ -96,9 +99,10 @@ function InteractiveStateless({
   type,
   href,
   target,
+  disabled,
   ...props
 }: InteractiveStatelessProps) {
-  const { isDisabled, allowClick } = useDisabled();
+  const isDisabled = !!disabled;

   // onClick/href are always passed directly — Stateless is the outermost Slot,
   // so Radix Slot-injected handlers don't bypass this guard.
@@ -134,7 +138,7 @@ function InteractiveStateless({
       {...linkAttrs}
       {...slotProps}
       onClick={
-        isDisabled && !allowClick
+        isDisabled
           ? href
             ? (e: React.MouseEvent) => e.preventDefault()
             : undefined
```
||||
```diff
@@ -8,7 +8,6 @@ import * as InputLayouts from "@/layouts/input-layouts";
 import Card from "@/refresh-components/cards/Card";
 import Button from "@/refresh-components/buttons/Button";
 import { Button as OpalButton } from "@opal/components";
-import { Disabled } from "@opal/core";
 import Text from "@/refresh-components/texts/Text";
 import Message from "@/refresh-components/messages/Message";
 import InfoBlock from "@/refresh-components/messages/InfoBlock";
@@ -246,15 +245,14 @@ function SubscriptionCard({
           to make changes.
         </Text>
       ) : disabled ? (
-        <Disabled disabled={isReconnecting}>
-          <OpalButton
-            prominence="secondary"
-            onClick={handleReconnect}
-            rightIcon={SvgArrowRight}
-          >
-            {isReconnecting ? "Connecting..." : "Connect to Stripe"}
-          </OpalButton>
-        </Disabled>
+        <OpalButton
+          disabled={isReconnecting}
+          prominence="secondary"
+          onClick={handleReconnect}
+          rightIcon={SvgArrowRight}
+        >
+          {isReconnecting ? "Connecting..." : "Connect to Stripe"}
+        </OpalButton>
       ) : (
         <OpalButton onClick={handleManagePlan} rightIcon={SvgExternalLink}>
           Manage Plan
@@ -377,11 +375,13 @@ function SeatsCard({
           sizePreset="main-content"
           variant="section"
         />
-        <Disabled disabled={isSubmitting}>
-          <OpalButton prominence="secondary" onClick={handleCancel}>
-            Cancel
-          </OpalButton>
-        </Disabled>
+        <OpalButton
+          disabled={isSubmitting}
+          prominence="secondary"
+          onClick={handleCancel}
+        >
+          Cancel
+        </OpalButton>
       </Section>

       <div className="billing-content-area">
@@ -463,15 +463,14 @@ function SeatsCard({
           No changes to your billing.
         </Text>
       )}
-      <Disabled
-        disabled={
-          isSubmitting || newSeatCount === totalSeats || isBelowMinimum
-        }
-      >
-        <OpalButton onClick={handleConfirm}>
-          {isSubmitting ? "Saving..." : "Confirm Change"}
-        </OpalButton>
-      </Disabled>
+      <OpalButton
+        disabled={
+          isSubmitting || newSeatCount === totalSeats || isBelowMinimum
+        }
+        onClick={handleConfirm}
+      >
+        {isSubmitting ? "Saving..." : "Confirm Change"}
+      </OpalButton>
     </Section>
   </Card>
 );
@@ -509,15 +508,14 @@ function SeatsCard({
         View Users
       </OpalButton>
       {!hideUpdateSeats && (
-        <Disabled disabled={isLoadingUsers || disabled || !billing}>
-          <OpalButton
-            prominence="secondary"
-            onClick={handleStartEdit}
-            icon={SvgPlus}
-          >
-            Update Seats
-          </OpalButton>
-        </Disabled>
+        <OpalButton
+          disabled={isLoadingUsers || disabled || !billing}
+          prominence="secondary"
+          onClick={handleStartEdit}
+          icon={SvgPlus}
+        >
+          Update Seats
+        </OpalButton>
       )}
     </Section>
   </Section>
```
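The Confirm button's `disabled` expression above combines three flags. As a minimal sketch (the helper name and the derivation of `isBelowMinimum` from a minimum-seat count are assumptions; only the three-way OR is taken from the diff):

```typescript
// Hypothetical helper mirroring the Confirm button's disabled condition.
// Disabled while a save is in flight, when the seat count is unchanged,
// or when the requested count falls below the plan minimum (assumed rule).
function confirmDisabled(
  isSubmitting: boolean,
  newSeatCount: number,
  totalSeats: number,
  minimumSeats: number
): boolean {
  const isBelowMinimum = newSeatCount < minimumSeats; // assumption
  return isSubmitting || newSeatCount === totalSeats || isBelowMinimum;
}
```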
||||
Some files were not shown because too many files have changed in this diff.