mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-05 23:12:43 +00:00
Compare commits
2 Commits
main
...
mock-conne
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
67456412c8 | ||
|
|
c83a107390 |
@@ -1 +0,0 @@
|
||||
../../../cli/internal/embedded/SKILL.md
|
||||
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
186
.cursor/skills/onyx-cli/SKILL.md
Normal file
@@ -0,0 +1,186 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
cd cli && go build -o onyx-cli . && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
|----------|----------|-------------|
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
|------------|-------------|
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": {"turn_index": 0, "tab_index": 0, "sub_turn_index": null}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
|------|------|-------------|
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
from onyx.db.engine.iam_auth import get_iam_auth_token
|
||||
from onyx.configs.app_configs import USE_IAM_AUTH
|
||||
from onyx.configs.app_configs import POSTGRES_HOST
|
||||
@@ -19,6 +19,7 @@ from logging.config import fileConfig
|
||||
|
||||
from alembic import context
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.sql.schema import SchemaItem
|
||||
from onyx.configs.constants import SSL_CERT_FILE
|
||||
from shared_configs.configs import (
|
||||
MULTI_TENANT,
|
||||
@@ -44,6 +45,8 @@ if config.config_file_name is not None and config.attributes.get(
|
||||
|
||||
target_metadata = [Base.metadata, ResultModelBase.metadata]
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
ssl_context: ssl.SSLContext | None = None
|
||||
@@ -53,6 +56,25 @@ if USE_IAM_AUTH:
|
||||
ssl_context = ssl.create_default_context(cafile=SSL_CERT_FILE)
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def filter_tenants_by_range(
|
||||
tenant_ids: list[str], start_range: int | None = None, end_range: int | None = None
|
||||
) -> list[str]:
|
||||
@@ -209,6 +231,7 @@ def do_run_migrations(
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
@@ -382,6 +405,7 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -423,6 +447,7 @@ def run_migrations_offline() -> None:
|
||||
url=url,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
literal_binds=True,
|
||||
include_object=include_object,
|
||||
version_table_schema=schema,
|
||||
include_schemas=True,
|
||||
script_location=config.get_main_option("script_location"),
|
||||
@@ -465,6 +490,7 @@ def run_migrations_online() -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore
|
||||
include_object=include_object,
|
||||
version_table_schema=schema_name,
|
||||
include_schemas=True,
|
||||
compare_type=True,
|
||||
|
||||
@@ -1,108 +0,0 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -1,104 +0,0 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
@@ -1,139 +0,0 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists: Any = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -1,84 +0,0 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -1,125 +0,0 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# The no-auth placeholder user must NOT be assigned to default groups.
|
||||
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
|
||||
# user when the first real user registers; group membership rows would cause
|
||||
# an FK violation on that DELETE.
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -1,9 +1,11 @@
|
||||
import asyncio
|
||||
from logging.config import fileConfig
|
||||
from typing import Literal
|
||||
|
||||
from sqlalchemy import pool
|
||||
from sqlalchemy.engine import Connection
|
||||
from sqlalchemy.ext.asyncio import create_async_engine
|
||||
from sqlalchemy.schema import SchemaItem
|
||||
|
||||
from alembic import context
|
||||
from onyx.db.engine.sql_engine import build_connection_string
|
||||
@@ -33,6 +35,27 @@ target_metadata = [PublicBase.metadata]
|
||||
# my_important_option = config.get_main_option("my_important_option")
|
||||
# ... etc.
|
||||
|
||||
EXCLUDE_TABLES = {"kombu_queue", "kombu_message"}
|
||||
|
||||
|
||||
def include_object(
|
||||
object: SchemaItem, # noqa: ARG001
|
||||
name: str | None,
|
||||
type_: Literal[
|
||||
"schema",
|
||||
"table",
|
||||
"column",
|
||||
"index",
|
||||
"unique_constraint",
|
||||
"foreign_key_constraint",
|
||||
],
|
||||
reflected: bool, # noqa: ARG001
|
||||
compare_to: SchemaItem | None, # noqa: ARG001
|
||||
) -> bool:
|
||||
if type_ == "table" and name in EXCLUDE_TABLES:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def run_migrations_offline() -> None:
|
||||
"""Run migrations in 'offline' mode.
|
||||
@@ -62,6 +85,7 @@ def do_run_migrations(connection: Connection) -> None:
|
||||
context.configure(
|
||||
connection=connection,
|
||||
target_metadata=target_metadata, # type: ignore[arg-type]
|
||||
include_object=include_object,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
@@ -27,13 +27,13 @@ from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Maximum tenants to provision in a single task run.
|
||||
# Each tenant takes ~80s (alembic migrations), so 15 tenants ≈ 20 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 15
|
||||
# Each tenant takes ~80s (alembic migrations), so 5 tenants ≈ 7 minutes.
|
||||
_MAX_TENANTS_PER_RUN = 5
|
||||
|
||||
# Time limits sized for worst-case: provisioning up to _MAX_TENANTS_PER_RUN new tenants
|
||||
# (~90s each) plus migrating up to TARGET_AVAILABLE_TENANTS pool tenants (~90s each).
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 40 # 40 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 45 # 45 minutes
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 20 # 20 minutes
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 25 # 25 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
|
||||
@@ -1,14 +1,20 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from uuid import UUID
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
|
||||
from ee.onyx.background.celery_utils import should_perform_chat_ttl_check
|
||||
from ee.onyx.background.task_name_builders import name_chat_ttl_task
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.chat import delete_chat_session
|
||||
from onyx.db.chat import get_chat_sessions_older_than
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import TaskStatus
|
||||
from onyx.db.tasks import mark_task_as_finished_with_id
|
||||
from onyx.db.tasks import register_task
|
||||
from onyx.server.settings.store import load_settings
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -23,42 +29,59 @@ logger = setup_logger()
|
||||
trail=False,
|
||||
)
|
||||
def perform_ttl_management_task(
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str # noqa: ARG001
|
||||
self: Task, retention_limit_days: int, *, tenant_id: str
|
||||
) -> None:
|
||||
task_id = self.request.id
|
||||
if not task_id:
|
||||
raise RuntimeError("No task id defined for this task; cannot identify it")
|
||||
|
||||
start_time = datetime.now(tz=timezone.utc)
|
||||
|
||||
user_id: UUID | None = None
|
||||
session_id: UUID | None = None
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# we generally want to move off this, but keeping for now
|
||||
register_task(
|
||||
db_session=db_session,
|
||||
task_name=name_chat_ttl_task(retention_limit_days, tenant_id),
|
||||
task_id=task_id,
|
||||
status=TaskStatus.STARTED,
|
||||
start_time=start_time,
|
||||
)
|
||||
|
||||
old_chat_sessions = get_chat_sessions_older_than(
|
||||
retention_limit_days, db_session
|
||||
)
|
||||
|
||||
for user_id, session_id in old_chat_sessions:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to delete chat session "
|
||||
f"user_id={user_id} session_id={session_id}, "
|
||||
"continuing with remaining sessions"
|
||||
# one session per delete so that we don't blow up if a deletion fails.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
delete_chat_session(
|
||||
user_id,
|
||||
session_id,
|
||||
db_session,
|
||||
include_deleted=True,
|
||||
hard_delete=True,
|
||||
)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=True,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"delete_chat_session exceptioned. user_id={user_id} session_id={session_id}"
|
||||
)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
mark_task_as_finished_with_id(
|
||||
db_session=db_session,
|
||||
task_id=task_id,
|
||||
success=False,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
|
||||
@@ -36,16 +36,13 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -283,9 +280,7 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -526,22 +521,6 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,8 +19,6 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -30,7 +28,6 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -39,7 +36,6 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -259,7 +255,6 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -274,7 +269,6 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -282,8 +276,6 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -294,7 +286,6 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -304,8 +295,6 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -489,16 +478,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -510,8 +489,6 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -819,10 +796,6 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
recompute_user_permissions__no_commit(
|
||||
list(set(added_user_ids) | set(removed_user_ids)), db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -862,19 +835,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -903,10 +863,6 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -52,25 +52,16 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -495,7 +486,6 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -516,25 +506,13 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
@@ -564,8 +542,7 @@ def replace_user(
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
is_reactivation = user_resource.active and not user.is_active
|
||||
if is_reactivation:
|
||||
if user_resource.active and not user.is_active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
@@ -579,12 +556,6 @@ def replace_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
@@ -650,7 +621,6 @@ def patch_user(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
is_reactivation = patched.active and not user.is_active
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
@@ -679,12 +649,6 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
@@ -893,11 +857,6 @@ def create_group(
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if group_resource.displayName in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
@@ -920,18 +879,8 @@ def create_group(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
# Every group gets the "basic" permission by default.
|
||||
dal.add_permission_grant_to_group(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
# Recompute permissions for initial members.
|
||||
recompute_user_permissions__no_commit(member_uuids, db_session)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
@@ -962,36 +911,14 @@ def replace_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if (
|
||||
group_resource.displayName in _RESERVED_GROUP_NAMES
|
||||
and group_resource.displayName != group.name
|
||||
):
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
# Capture old member IDs before replacing so we can recompute their
|
||||
# permissions after they are removed from the group.
|
||||
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
# Recompute permissions for current members (batch) and removed members.
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
removed_ids = list(old_member_ids - set(member_uuids))
|
||||
recompute_user_permissions__no_commit(removed_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
@@ -1034,19 +961,8 @@ def patch_group(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and new_name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if new_name and new_name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
|
||||
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
affected_uuids: list[UUID] = []
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
@@ -1057,15 +973,10 @@ def patch_group(
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
affected_uuids.extend(add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
affected_uuids.extend(remove_uuids)
|
||||
|
||||
# Recompute permissions for all users whose group membership changed.
|
||||
recompute_user_permissions__no_commit(affected_uuids, db_session)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
@@ -1091,21 +1002,11 @@ def delete_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
|
||||
|
||||
# Capture member IDs before deletion so we can recompute their permissions.
|
||||
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
|
||||
# Recompute permissions for users who lost this group membership.
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -43,16 +43,12 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -60,50 +56,27 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -127,9 +100,6 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -215,9 +185,6 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,7 +22,6 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -75,21 +74,18 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,110 +0,0 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,8 +5,6 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -43,7 +41,6 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -53,19 +50,19 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -80,6 +80,7 @@ from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -119,13 +120,11 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -501,21 +500,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.account_type.is_web_login()
|
||||
user.role.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.account_type.is_web_login()
|
||||
or not user_create.role.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -529,21 +525,18 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
user.account_type.is_web_login()
|
||||
user.role.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.account_type.is_web_login()
|
||||
or not user_create.role.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -580,38 +573,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
def _upgrade_user_to_standard__sync(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
All writes happen in a single sync transaction so neither the field
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
"User %s not found in sync session during upgrade to standard; "
|
||||
"skipping upgrade",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -733,7 +694,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -766,7 +726,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
# We must use the existing user in the session if it matches
|
||||
# the user we just got by email/oauth. Note that this only applies
|
||||
# to multi-tenant, due to the overwriting of the user_db
|
||||
@@ -783,25 +743,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user.id)
|
||||
assert user is not None
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -887,16 +836,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -1036,7 +975,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
@@ -1532,7 +1471,7 @@ async def _get_or_create_user_from_jwt(
|
||||
if not user.is_active:
|
||||
logger.warning("Inactive user %s attempted JWT login; skipping", email)
|
||||
return None
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Provisioning user %s from JWT login", email)
|
||||
@@ -1553,7 +1492,7 @@ async def _get_or_create_user_from_jwt(
|
||||
email,
|
||||
)
|
||||
return None
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
logger.warning(
|
||||
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
|
||||
email,
|
||||
@@ -1615,7 +1554,6 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -56,6 +56,7 @@ Then it cycles through its tasks as scheduled by Celery Beat:
|
||||
| `check_for_user_file_processing` | 20s | Checks for user uploads → dispatches to `USER_FILE_PROCESSING` queue |
|
||||
| `check_for_checkpoint_cleanup` | 1h | Cleans up old indexing checkpoints |
|
||||
| `check_for_index_attempt_cleanup` | 30m | Cleans up old index attempts |
|
||||
| `kombu_message_cleanup_task` | periodic | Cleans orphaned Kombu messages from DB (Kombu being the messaging framework used by Celery) |
|
||||
| `celery_beat_heartbeat` | 1m | Heartbeat for Beat watchdog |
|
||||
|
||||
Watchdog is a separate Python process managed by supervisord which runs alongside celery workers. It checks the ONYX_CELERY_BEAT_HEARTBEAT_KEY in
|
||||
|
||||
@@ -317,6 +317,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.docprocessing",
|
||||
"onyx.background.celery.tasks.evals",
|
||||
"onyx.background.celery.tasks.hierarchyfetching",
|
||||
"onyx.background.celery.tasks.periodic",
|
||||
"onyx.background.celery.tasks.pruning",
|
||||
"onyx.background.celery.tasks.shared",
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
|
||||
@@ -302,7 +302,7 @@ beat_cloud_tasks: list[dict] = [
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CLOUD_CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=2),
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
|
||||
@@ -36,7 +36,6 @@ from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.opensearch_migration import build_sanitized_to_original_doc_id_mapping
|
||||
from onyx.db.opensearch_migration import get_vespa_visit_state
|
||||
from onyx.db.opensearch_migration import is_migration_completed
|
||||
from onyx.db.opensearch_migration import (
|
||||
mark_migration_completed_time_if_not_set_with_commit,
|
||||
)
|
||||
@@ -107,19 +106,14 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
acquired; effectively a no-op. True if the task completed
|
||||
successfully. False if the task errored.
|
||||
"""
|
||||
# 1. Check if we should run the task.
|
||||
# 1.a. If OpenSearch indexing is disabled, we don't run the task.
|
||||
if not ENABLE_OPENSEARCH_INDEXING_FOR_ONYX:
|
||||
task_logger.warning(
|
||||
"OpenSearch migration is not enabled, skipping chunk migration task."
|
||||
)
|
||||
return None
|
||||
|
||||
task_logger.info("Starting chunk-level migration from Vespa to OpenSearch.")
|
||||
task_start_time = time.monotonic()
|
||||
|
||||
# 1.b. Only one instance per tenant of this task may run concurrently at
|
||||
# once. If we fail to acquire a lock, we assume it is because another task
|
||||
# has one and we exit.
|
||||
r = get_redis_client()
|
||||
lock: RedisLock = r.lock(
|
||||
name=OnyxRedisLocks.OPENSEARCH_MIGRATION_BEAT_LOCK,
|
||||
@@ -142,11 +136,10 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"Token: {lock.local.token}"
|
||||
)
|
||||
|
||||
# 2. Prepare to migrate.
|
||||
total_chunks_migrated_this_task = 0
|
||||
total_chunks_errored_this_task = 0
|
||||
try:
|
||||
# 2.a. Double-check that tenant info is correct.
|
||||
# Double check that tenant info is correct.
|
||||
if tenant_id != get_current_tenant_id():
|
||||
err_str = (
|
||||
f"Tenant ID mismatch in the OpenSearch migration task: "
|
||||
@@ -155,62 +148,16 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
task_logger.error(err_str)
|
||||
return False
|
||||
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 2.b. Immediately check to see if this tenant is done, to save
|
||||
# having to do any other work. This function does not require a
|
||||
# migration record to necessarily exist.
|
||||
if is_migration_completed(db_session):
|
||||
return True
|
||||
|
||||
# 2.c. Try to insert the OpenSearchTenantMigrationRecord table if it
|
||||
# does not exist.
|
||||
with (
|
||||
get_session_with_current_tenant() as db_session,
|
||||
get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client,
|
||||
):
|
||||
try_insert_opensearch_tenant_migration_record_with_commit(db_session)
|
||||
|
||||
# 2.d. Get search settings.
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
|
||||
# 2.e. Build sanitized to original doc ID mapping to check for
|
||||
# conflicts in the event we sanitize a doc ID to an
|
||||
# already-existing doc ID.
|
||||
# We reconstruct this mapping for every task invocation because
|
||||
# a document may have been added in the time between two tasks.
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
# 2.f. Get the current migration state.
|
||||
continuation_token_map, total_chunks_migrated = get_vespa_visit_state(
|
||||
db_session
|
||||
)
|
||||
# 2.f.1. Double-check that the migration state does not imply
|
||||
# completion. Really we should never have to enter this block as we
|
||||
# would expect is_migration_completed to return True, but in the
|
||||
# strange event that the migration is complete but the migration
|
||||
# completed time was never stamped, we do so here.
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
with get_vespa_http_client(
|
||||
timeout=VESPA_MIGRATION_REQUEST_TIMEOUT_S
|
||||
) as vespa_client:
|
||||
# 2.g. Create the OpenSearch and Vespa document indexes.
|
||||
tenant_state = TenantState(tenant_id=tenant_id, multitenant=MULTI_TENANT)
|
||||
indexing_setting = IndexingSetting.from_db_model(search_settings)
|
||||
opensearch_document_index = OpenSearchDocumentIndex(
|
||||
tenant_state=tenant_state,
|
||||
index_name=search_settings.index_name,
|
||||
@@ -224,14 +171,22 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
httpx_client=vespa_client,
|
||||
)
|
||||
|
||||
# 2.h. Get the approximate chunk count in Vespa as of this time to
|
||||
# update the migration record.
|
||||
sanitized_doc_start_time = time.monotonic()
|
||||
# We reconstruct this mapping for every task invocation because a
|
||||
# document may have been added in the time between two tasks.
|
||||
sanitized_to_original_doc_id_mapping = (
|
||||
build_sanitized_to_original_doc_id_mapping(db_session)
|
||||
)
|
||||
task_logger.debug(
|
||||
f"Built sanitized_to_original_doc_id_mapping with {len(sanitized_to_original_doc_id_mapping)} entries "
|
||||
f"in {time.monotonic() - sanitized_doc_start_time:.3f} seconds."
|
||||
)
|
||||
|
||||
approx_chunk_count_in_vespa: int | None = None
|
||||
get_chunk_count_start_time = time.monotonic()
|
||||
try:
|
||||
approx_chunk_count_in_vespa = vespa_document_index.get_chunk_count()
|
||||
except Exception:
|
||||
# This failure should not be blocking.
|
||||
task_logger.exception(
|
||||
"Error getting approximate chunk count in Vespa. Moving on..."
|
||||
)
|
||||
@@ -240,12 +195,25 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"approximate chunk count in Vespa. Got {approx_chunk_count_in_vespa}."
|
||||
)
|
||||
|
||||
# 3. Do the actual migration in batches until we run out of time.
|
||||
while (
|
||||
time.monotonic() - task_start_time < MIGRATION_TASK_SOFT_TIME_LIMIT_S
|
||||
and lock.owned()
|
||||
):
|
||||
# 3.a. Get the next batch of raw chunks from Vespa.
|
||||
(
|
||||
continuation_token_map,
|
||||
total_chunks_migrated,
|
||||
) = get_vespa_visit_state(db_session)
|
||||
if is_continuation_token_done_for_all_slices(continuation_token_map):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
break
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
)
|
||||
|
||||
get_vespa_chunks_start_time = time.monotonic()
|
||||
raw_vespa_chunks, next_continuation_token_map = (
|
||||
vespa_document_index.get_all_raw_document_chunks_paginated(
|
||||
@@ -258,7 +226,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
f"seconds. Next continuation token map: {next_continuation_token_map}"
|
||||
)
|
||||
|
||||
# 3.b. Transform the raw chunks to OpenSearch chunks in memory.
|
||||
opensearch_document_chunks, errored_chunks = (
|
||||
transform_vespa_chunks_to_opensearch_chunks(
|
||||
raw_vespa_chunks,
|
||||
@@ -273,7 +240,6 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
"errored."
|
||||
)
|
||||
|
||||
# 3.c. Index the OpenSearch chunks into OpenSearch.
|
||||
index_opensearch_chunks_start_time = time.monotonic()
|
||||
opensearch_document_index.index_raw_chunks(
|
||||
chunks=opensearch_document_chunks
|
||||
@@ -285,38 +251,12 @@ def migrate_chunks_from_vespa_to_opensearch_task(
|
||||
|
||||
total_chunks_migrated_this_task += len(opensearch_document_chunks)
|
||||
total_chunks_errored_this_task += len(errored_chunks)
|
||||
|
||||
# Do as much as we can with a DB session in one spot to not hold a
|
||||
# session during a migration batch.
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# 3.d. Update the migration state.
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
|
||||
# 3.e. Get the current migration state. Even thought we
|
||||
# technically have it in-memory since we just wrote it, we
|
||||
# want to reference the DB as the source of truth at all
|
||||
# times.
|
||||
continuation_token_map, total_chunks_migrated = (
|
||||
get_vespa_visit_state(db_session)
|
||||
)
|
||||
# 3.e.1. Check if the migration is done.
|
||||
if is_continuation_token_done_for_all_slices(
|
||||
continuation_token_map
|
||||
):
|
||||
task_logger.info(
|
||||
f"OpenSearch migration COMPLETED for tenant {tenant_id}. Total chunks migrated: {total_chunks_migrated}."
|
||||
)
|
||||
mark_migration_completed_time_if_not_set_with_commit(db_session)
|
||||
return True
|
||||
task_logger.debug(
|
||||
f"Read the tenant migration record. Total chunks migrated: {total_chunks_migrated}. "
|
||||
f"Continuation token map: {continuation_token_map}"
|
||||
update_vespa_visit_progress_with_commit(
|
||||
db_session,
|
||||
continuation_token_map=next_continuation_token_map,
|
||||
chunks_processed=len(opensearch_document_chunks),
|
||||
chunks_errored=len(errored_chunks),
|
||||
approx_chunk_count_in_vespa=approx_chunk_count_in_vespa,
|
||||
)
|
||||
except Exception:
|
||||
traceback.print_exc()
|
||||
|
||||
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
138
backend/onyx/background/celery/tasks/periodic/tasks.py
Normal file
@@ -0,0 +1,138 @@
|
||||
#####
|
||||
# Periodic Tasks
|
||||
#####
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
from celery import shared_task
|
||||
from celery.contrib.abortable import AbortableTask # type: ignore
|
||||
from celery.exceptions import TaskRevokedError
|
||||
from sqlalchemy import inspect
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import PostgresAdvisoryLocks
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.KOMBU_MESSAGE_CLEANUP_TASK,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
bind=True,
|
||||
base=AbortableTask,
|
||||
)
|
||||
def kombu_message_cleanup_task(self: Any, tenant_id: str) -> int: # noqa: ARG001
|
||||
"""Runs periodically to clean up the kombu_message table"""
|
||||
|
||||
# we will select messages older than this amount to clean up
|
||||
KOMBU_MESSAGE_CLEANUP_AGE = 7 # days
|
||||
KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT = 1000
|
||||
|
||||
ctx = {}
|
||||
ctx["last_processed_id"] = 0
|
||||
ctx["deleted"] = 0
|
||||
ctx["cleanup_age"] = KOMBU_MESSAGE_CLEANUP_AGE
|
||||
ctx["page_limit"] = KOMBU_MESSAGE_CLEANUP_PAGE_LIMIT
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Exit the task if we can't take the advisory lock
|
||||
result = db_session.execute(
|
||||
text("SELECT pg_try_advisory_lock(:id)"),
|
||||
{"id": PostgresAdvisoryLocks.KOMBU_MESSAGE_CLEANUP_LOCK_ID.value},
|
||||
).scalar()
|
||||
if not result:
|
||||
return 0
|
||||
|
||||
while True:
|
||||
if self.is_aborted():
|
||||
raise TaskRevokedError("kombu_message_cleanup_task was aborted.")
|
||||
|
||||
b = kombu_message_cleanup_task_helper(ctx, db_session)
|
||||
if not b:
|
||||
break
|
||||
|
||||
db_session.commit()
|
||||
|
||||
if ctx["deleted"] > 0:
|
||||
task_logger.info(
|
||||
f"Deleted {ctx['deleted']} orphaned messages from kombu_message."
|
||||
)
|
||||
|
||||
return ctx["deleted"]
|
||||
|
||||
|
||||
def kombu_message_cleanup_task_helper(ctx: dict, db_session: Session) -> bool:
|
||||
"""
|
||||
Helper function to clean up old messages from the `kombu_message` table that are no longer relevant.
|
||||
|
||||
This function retrieves messages from the `kombu_message` table that are no longer visible and
|
||||
older than a specified interval. It checks if the corresponding task_id exists in the
|
||||
`celery_taskmeta` table. If the task_id does not exist, the message is deleted.
|
||||
|
||||
Args:
|
||||
ctx (dict): A context dictionary containing configuration parameters such as:
|
||||
- 'cleanup_age' (int): The age in days after which messages are considered old.
|
||||
- 'page_limit' (int): The maximum number of messages to process in one batch.
|
||||
- 'last_processed_id' (int): The ID of the last processed message to handle pagination.
|
||||
- 'deleted' (int): A counter to track the number of deleted messages.
|
||||
db_session (Session): The SQLAlchemy database session for executing queries.
|
||||
|
||||
Returns:
|
||||
bool: Returns True if there are more rows to process, False if not.
|
||||
"""
|
||||
|
||||
inspector = inspect(db_session.bind)
|
||||
if not inspector:
|
||||
return False
|
||||
|
||||
# With the move to redis as celery's broker and backend, kombu tables may not even exist.
|
||||
# We can fail silently.
|
||||
if not inspector.has_table("kombu_message"):
|
||||
return False
|
||||
|
||||
query = text(
|
||||
"""
|
||||
SELECT id, timestamp, payload
|
||||
FROM kombu_message WHERE visible = 'false'
|
||||
AND timestamp < CURRENT_TIMESTAMP - INTERVAL :interval_days
|
||||
AND id > :last_processed_id
|
||||
ORDER BY id
|
||||
LIMIT :page_limit
|
||||
"""
|
||||
)
|
||||
kombu_messages = db_session.execute(
|
||||
query,
|
||||
{
|
||||
"interval_days": f"{ctx['cleanup_age']} days",
|
||||
"page_limit": ctx["page_limit"],
|
||||
"last_processed_id": ctx["last_processed_id"],
|
||||
},
|
||||
).fetchall()
|
||||
|
||||
if len(kombu_messages) == 0:
|
||||
return False
|
||||
|
||||
for msg in kombu_messages:
|
||||
payload = json.loads(msg[2])
|
||||
task_id = payload["headers"]["id"]
|
||||
|
||||
# Check if task_id exists in celery_taskmeta
|
||||
task_exists = db_session.execute(
|
||||
text("SELECT 1 FROM celery_taskmeta WHERE task_id = :task_id"),
|
||||
{"task_id": task_id},
|
||||
).fetchone()
|
||||
|
||||
# If task_id does not exist, delete the message
|
||||
if not task_exists:
|
||||
result = db_session.execute(
|
||||
text("DELETE FROM kombu_message WHERE id = :message_id"),
|
||||
{"message_id": msg[0]},
|
||||
)
|
||||
if result.rowcount > 0: # type: ignore
|
||||
ctx["deleted"] += 1
|
||||
|
||||
ctx["last_processed_id"] = msg[0]
|
||||
|
||||
return True
|
||||
@@ -1,10 +1,5 @@
|
||||
# Overview of Context Management
|
||||
|
||||
This document reviews some design decisions around the main agent-loop powering Onyx's chat flow.
|
||||
It is highly recommended for all engineers contributing to this flow to be familiar with the concepts here.
|
||||
|
||||
> Note: it is assumed the reader is familiar with the Onyx product and features such as Projects, User files, Citations, etc.
|
||||
|
||||
## System Prompt
|
||||
|
||||
The system prompt is a default prompt that comes packaged with the system. Users can edit the default prompt and it will be persisted in the database.
|
||||
@@ -46,9 +41,9 @@ the system can RAG over the project files.
|
||||
|
||||
## How documents are represented
|
||||
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix string to
|
||||
make the context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of
|
||||
message rather than a user message.
|
||||
Documents from search or uploaded Project files are represented as a json so that the LLM can easily understand it. It is represented with a prefix to make the
|
||||
context clearer to the LLM. Note that for search results (whether web or internal, it will just be the json) and it will be a Tool Call type of message
|
||||
rather than a user message.
|
||||
|
||||
```
|
||||
Here are some documents provided for context, they may not all be relevant:
|
||||
@@ -60,12 +55,12 @@ Here are some documents provided for context, they may not all be relevant:
|
||||
}
|
||||
```
|
||||
|
||||
Documents are represented with the `document` key so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
Documents are represented with document so that the LLM can easily cite them with a single number. The tool returns have to be richer to be able to
|
||||
translate this into links and other UI elements. What the LLM sees is far simpler to reduce noise/hallucinations.
|
||||
|
||||
Note that documents included in a single turn should be collapsed into a single user message.
|
||||
|
||||
Search tools also give URLs to the LLM so that open_url (a separate tool) can be called on them.
|
||||
Search tools give URLs to the LLM though so that open_url (a separate tool) can be called on them.
|
||||
|
||||
## Reminders
|
||||
|
||||
@@ -77,13 +72,10 @@ If a search related tool is called at any point during the turn, the reminder wi
|
||||
|
||||
## Tool Calls
|
||||
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are current replaced with a hardcoded
|
||||
As tool call responses can get very long (like an internal search can be many thousands of tokens), tool responses are today replaced with a hardcoded
|
||||
string saying it is no longer available. Tool Call details like the search query and other arguments are kept in the history as this is information
|
||||
rich and generally very few tokens.
|
||||
|
||||
> Note: in the Internal Search flow with query expansion, the Tool Call which was actually run differs from what the LLM provided as arguments.
|
||||
> What the LLM sees in the history (to be most informative for future calls) is the full set of expanded queries.
|
||||
|
||||
**Possible Future Extension**:
|
||||
Instead of dropping the Tool Call response, we might summarize it using an LLM so that it is just 1-2 sentences and captures the main points. That said,
|
||||
this is questionable value add because anything relevant and useful should be already captured in the Agent response.
|
||||
@@ -111,7 +103,7 @@ Flow with Project and File Upload
|
||||
S, CA, P, F, U1, A1 -- user sends another message -> S, F, U1, A1, CA, P, U2, A2
|
||||
- File stays in place, above the user message
|
||||
- Project files move along the chain as new messages are sent
|
||||
- Custom Agent prompt comes before project files which come before user uploaded files in each turn
|
||||
- Custom Agent prompt comes before project files which comes before user uploaded files in each turn
|
||||
|
||||
Reminders during a single Turn
|
||||
S, U1, TC, TR, R -- agent calls another tool -> S, U1, TC, TR, TC, TR, R, A1
|
||||
@@ -132,7 +124,7 @@ and should be very targetted for it to work reliably and also not interfere with
|
||||
|
||||
## Reasons / Experiments
|
||||
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrades performance of the system especially when the instructions
|
||||
Custom Agent instructions being placed in the system prompt is poorly followed. It also degrade performance of the system especially when the instructions
|
||||
are orthogonal (or even possibly contradictory) to the system prompt. For weaker models, it causes strange artifacts in tool calls and final responses
|
||||
that completely ruins the user experience. Empirically, this way works better across a range of models especially when the history gets longer.
|
||||
Having the Custom Agent instructions not move means it fades more as the chat gets long which is also not ok from a UX perspective.
|
||||
@@ -159,7 +151,7 @@ In a similar concept, LLM instructions in the system prompt are structured speci
|
||||
fairly surprising actually but if there is a line of instructions effectively saying "If you try to use some tools and find that you need more information or
|
||||
need to call additional tools, you are encouraged to do this", having this in the Tool section of the System prompt makes all the LLMs follow it well but if it's
|
||||
even just a paragraph away like near the beginning of the prompt, it is often ignored. The difference is as drastic as a 30% follow rate to a 90% follow
|
||||
rate by even just moving the same statement a few sentences.
|
||||
rate even just moving the same statement a few sentences.
|
||||
|
||||
## Other related pointers
|
||||
|
||||
@@ -243,9 +235,8 @@ tool calls and returns that to the LLM Loop to execute.
|
||||
concept of a turn. The turn_index for the frontend is which block does this packet belong to. So while a reasoning + tool call
|
||||
comes from the same LLM inference (same backend LLM step), they are 2 turns to the frontend because that's how it's rendered.
|
||||
|
||||
- There are 3 representations of a message, each scoped to a different layer:
|
||||
1. **ChatMessage** — The database model. Should be converted into ChatMessageSimple early and never passed deep into the flow.
|
||||
2. **ChatMessageSimple** — The canonical data model used throughout the codebase. This is the rich, full-featured representation
|
||||
of a message. Any modifications or additions to message structure should be made here.
|
||||
3. **LanguageModelInput** — The LLM-facing representation. Intentionally minimal so the LLM interface layer stays clean and
|
||||
easy to maintain/extend.
|
||||
- There are 3 representations of "message". The first is the database model ChatMessage, this one should be translated away and
|
||||
not used deep into the flow. The second is ChatMessageSimple which is the data model which should be used throughout the code
|
||||
as much as possible. If modifications/additions are needed, it should be to this object. This is the rich representation of a
|
||||
message for the code. Finally there is the LanguageModelInput representation of a message. This one is for the LLM interface
|
||||
layer and is as stripped down as possible so that the LLM interface can be clean and easy to maintain/extend.
|
||||
|
||||
@@ -30,7 +30,7 @@ class Emitter:
|
||||
self._drain_done = drain_done
|
||||
|
||||
def emit(self, packet: Packet) -> None:
|
||||
if self._drain_done is not None and self._drain_done.is_set():
|
||||
if self._drain_done and self._drain_done.is_set():
|
||||
return
|
||||
base = packet.placement or Placement(turn_index=0)
|
||||
tagged = Packet(
|
||||
|
||||
@@ -3,7 +3,7 @@ IMPORTANT: familiarize yourself with the design concepts prior to contributing t
|
||||
An overview can be found in the README.md file in this directory.
|
||||
"""
|
||||
|
||||
import contextvars
|
||||
import functools
|
||||
import io
|
||||
import queue
|
||||
import re
|
||||
@@ -11,9 +11,7 @@ import threading
|
||||
import traceback
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from contextvars import Token
|
||||
from typing import Final
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -66,7 +64,6 @@ from onyx.db.chat import create_new_chat_message
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.chat import reserve_message_id
|
||||
from onyx.db.chat import reserve_multi_model_message_ids
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import HookPoint
|
||||
from onyx.db.memory import get_memories
|
||||
@@ -94,7 +91,6 @@ from onyx.llm.factory import get_llm_for_persona
|
||||
from onyx.llm.factory import get_llm_token_counter
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.request_context import reset_llm_mock_response
|
||||
from onyx.llm.request_context import set_llm_mock_response
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
@@ -102,8 +98,6 @@ from onyx.onyxbot.slack.models import SlackContext
|
||||
from onyx.server.query_and_chat.chat_utils import mime_type_to_chat_file_type
|
||||
from onyx.server.query_and_chat.models import AUTO_PLACE_AFTER_LATEST_MESSAGE
|
||||
from onyx.server.query_and_chat.models import MessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import ModelResponseSlot
|
||||
from onyx.server.query_and_chat.models import MultiModelMessageResponseIDInfo
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import AgentResponseDelta
|
||||
@@ -122,11 +116,13 @@ from onyx.tools.tool_constructor import FileReaderToolConfig
|
||||
from onyx.tools.tool_constructor import SearchToolConfig
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import mt_cloud_telemetry
|
||||
from onyx.utils.threadpool_concurrency import run_multiple_in_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
ERROR_TYPE_CANCELLED = "cancelled"
|
||||
|
||||
APPROX_CHARS_PER_TOKEN = 4
|
||||
|
||||
|
||||
@@ -486,8 +482,6 @@ def build_chat_turn(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
# None → single-model (persona default LLM); non-empty list → multi-model (one LLM per override)
|
||||
llm_overrides: list[LLMOverride] | None,
|
||||
*,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
@@ -500,23 +494,21 @@ def build_chat_turn(
|
||||
# NOTE: not stored in the database, only passed in to the LLM as context
|
||||
additional_context: str | None = None,
|
||||
) -> Generator[AnswerStreamPart, None, ChatTurnSetup]:
|
||||
"""Shared setup generator for both single-model and multi-model chat turns.
|
||||
"""Setup generator for a single-model chat turn.
|
||||
|
||||
Yields the packet(s) the frontend needs for request tracking, then returns an
|
||||
immutable ``ChatTurnSetup`` containing everything the execution strategy needs.
|
||||
|
||||
Callers use::
|
||||
|
||||
setup = yield from build_chat_turn(new_msg_req, ..., llm_overrides=...)
|
||||
setup = yield from build_chat_turn(new_msg_req, ...)
|
||||
|
||||
to forward yielded packets upstream while receiving the return value locally.
|
||||
|
||||
Args:
|
||||
llm_overrides: ``None`` → single-model (persona default LLM).
|
||||
Non-empty list → multi-model (one LLM per override).
|
||||
"""
|
||||
# TODO(nmgarza5): Consider refactoring so that yields move to handle_stream_message_objects
|
||||
# and build_chat_turn becomes a plain function returning ChatTurnSetup. This would make
|
||||
# the generator pattern (yield from build_chat_turn) unnecessary and easier to reason about.
|
||||
tenant_id = get_current_tenant_id()
|
||||
is_multi = bool(llm_overrides)
|
||||
|
||||
user_id = user.id
|
||||
llm_user_identifier = (
|
||||
@@ -527,25 +519,22 @@ def build_chat_turn(
|
||||
if not new_msg_req.chat_session_id:
|
||||
if not new_msg_req.chat_session_info:
|
||||
raise RuntimeError("Must specify a chat session id or chat session info")
|
||||
chat_session = create_chat_session_from_request(
|
||||
new_session = create_chat_session_from_request(
|
||||
chat_session_request=new_msg_req.chat_session_info,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
)
|
||||
yield CreateChatSessionID(chat_session_id=chat_session.id)
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=chat_session.id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
session_id = new_session.id
|
||||
yield CreateChatSessionID(chat_session_id=session_id)
|
||||
else:
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=new_msg_req.chat_session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
session_id = new_msg_req.chat_session_id
|
||||
|
||||
chat_session = get_chat_session_by_id(
|
||||
chat_session_id=session_id,
|
||||
user_id=user_id,
|
||||
db_session=db_session,
|
||||
eager_load_persona=True,
|
||||
)
|
||||
|
||||
persona = chat_session.persona
|
||||
message_text = new_msg_req.message
|
||||
@@ -574,33 +563,21 @@ def build_chat_turn(
|
||||
)
|
||||
|
||||
# Check LLM cost limits before using the LLM (only for Onyx-managed keys),
|
||||
# then build the LLM instance(s).
|
||||
llms: list[LLM] = []
|
||||
model_display_names: list[str] = []
|
||||
selected_overrides: list[LLMOverride | None] = (
|
||||
list(llm_overrides or [])
|
||||
if is_multi
|
||||
else [new_msg_req.llm_override or chat_session.llm_override]
|
||||
# then build the LLM instance.
|
||||
primary_llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=new_msg_req.llm_override or chat_session.llm_override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
for override in selected_overrides:
|
||||
llm = get_llm_for_persona(
|
||||
persona=persona,
|
||||
user=user,
|
||||
llm_override=override,
|
||||
additional_headers=litellm_additional_headers,
|
||||
)
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
llm_provider_api_key=llm.config.api_key,
|
||||
)
|
||||
llms.append(llm)
|
||||
model_display_names.append(_build_model_display_name(override))
|
||||
token_counter = get_llm_token_counter(llms[0])
|
||||
|
||||
# not sure why we do this, but to maintain parity with previous code:
|
||||
if not is_multi:
|
||||
model_display_names = [""]
|
||||
check_llm_cost_limit_for_provider(
|
||||
db_session=db_session,
|
||||
tenant_id=tenant_id,
|
||||
llm_provider_api_key=primary_llm.config.api_key,
|
||||
)
|
||||
llms = [primary_llm]
|
||||
model_display_names = [""]
|
||||
token_counter = get_llm_token_counter(primary_llm)
|
||||
|
||||
# Verify that the user-specified files actually belong to the user
|
||||
verify_user_files(
|
||||
@@ -761,8 +738,7 @@ def build_chat_turn(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# Use the smallest context window across models for safety (harmless for N=1).
|
||||
llm_max_context_window = min(llm.config.max_input_tokens for llm in llms)
|
||||
llm_max_context_window = llms[0].config.max_input_tokens
|
||||
|
||||
extracted_context_files = extract_context_files(
|
||||
user_files=context_user_files,
|
||||
@@ -806,34 +782,18 @@ def build_chat_turn(
|
||||
# Convert loaded files to ChatFile format for tools like PythonTool
|
||||
chat_files_for_tools = _convert_loaded_files_to_chat_files(files)
|
||||
|
||||
# ── Reserve assistant message ID(s) → yield to frontend ──────────────────
|
||||
if is_multi:
|
||||
assert llm_overrides is not None
|
||||
reserved_messages = reserve_multi_model_message_ids(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message_id=user_message.id,
|
||||
model_display_names=model_display_names,
|
||||
)
|
||||
yield MultiModelMessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
responses=[
|
||||
ModelResponseSlot(message_id=m.id, model_name=name)
|
||||
for m, name in zip(reserved_messages, model_display_names)
|
||||
],
|
||||
)
|
||||
else:
|
||||
assistant_response = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=user_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_messages = [assistant_response]
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
# ── Reserve assistant message ID → yield to frontend ─────────────────────
|
||||
assistant_response = reserve_message_id(
|
||||
db_session=db_session,
|
||||
chat_session_id=chat_session.id,
|
||||
parent_message=user_message.id,
|
||||
message_type=MessageType.ASSISTANT,
|
||||
)
|
||||
reserved_messages = [assistant_response]
|
||||
yield MessageResponseIDInfo(
|
||||
user_message_id=user_message.id,
|
||||
reserved_assistant_message_id=assistant_response.id,
|
||||
)
|
||||
|
||||
# Convert the chat history into a simple format that is free of any DB objects
|
||||
# and is easy to parse for the agent loop.
|
||||
@@ -936,9 +896,6 @@ def build_chat_turn(
|
||||
# Sentinel placed on the merged queue when a model thread finishes.
|
||||
_MODEL_DONE = object()
|
||||
|
||||
# How often the drain loop polls for user-initiated cancellation (stop button).
|
||||
_CANCEL_POLL_INTERVAL_S: Final[float] = 0.05
|
||||
|
||||
|
||||
def _run_models(
|
||||
setup: ChatTurnSetup,
|
||||
@@ -949,7 +906,7 @@ def _run_models(
|
||||
"""Stream packets from one or more LLM loops running in parallel worker threads.
|
||||
|
||||
Each model gets its own worker thread, DB session, and ``Emitter``. Threads write
|
||||
packets to a shared unbounded queue as they are produced; the drain loop yields them
|
||||
packets to a shared bounded queue as they are produced; the drain loop yields them
|
||||
in arrival order so the caller receives a single interleaved stream regardless of
|
||||
how many models are running.
|
||||
|
||||
@@ -977,6 +934,8 @@ def _run_models(
|
||||
|
||||
merged_queue: queue.Queue[tuple[int, Packet | Exception | object]] = queue.Queue()
|
||||
|
||||
# external_state_container is only non-None for single-model turns (n_models == 1),
|
||||
# so only index 0 can receive it. Multi-model turns always create fresh containers.
|
||||
state_containers: list[ChatStateContainer] = [
|
||||
(
|
||||
external_state_container
|
||||
@@ -986,12 +945,9 @@ def _run_models(
|
||||
for i in range(n_models)
|
||||
]
|
||||
model_succeeded: list[bool] = [False] * n_models
|
||||
# Set to True when a model raises an exception (distinct from "still running").
|
||||
# Used in the stop-button path to avoid calling completion for errored models.
|
||||
model_errored: list[bool] = [False] * n_models
|
||||
|
||||
# Set when the drain loop exits early (HTTP disconnect / GeneratorExit).
|
||||
# Signals emitters to skip future puts so workers exit promptly.
|
||||
# Signals emitters to skip future puts and workers to self-complete.
|
||||
drain_done = threading.Event()
|
||||
|
||||
def _run_model(model_idx: int) -> None:
|
||||
@@ -1096,84 +1052,85 @@ def _run_models(
|
||||
model_succeeded[model_idx] = True
|
||||
|
||||
except Exception as e:
|
||||
model_errored[model_idx] = True
|
||||
merged_queue.put((model_idx, e))
|
||||
|
||||
finally:
|
||||
merged_queue.put((model_idx, _MODEL_DONE))
|
||||
|
||||
def _delete_orphaned_message(model_idx: int, context: str) -> None:
|
||||
"""Delete a reserved ChatMessage that was never populated due to a model error."""
|
||||
try:
|
||||
orphaned = db_session.get(
|
||||
ChatMessage, setup.reserved_messages[model_idx].id
|
||||
)
|
||||
if orphaned is not None:
|
||||
db_session.delete(orphaned)
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"%s orphan cleanup failed for model %d (%s)",
|
||||
context,
|
||||
model_idx,
|
||||
setup.model_display_names[model_idx],
|
||||
)
|
||||
# Self-completion on disconnect: _MODEL_DONE was already posted in the finally
|
||||
# block above, so the drain loop has counted this model. If drain_done is set,
|
||||
# the main thread exited early and will NOT call llm_loop_completion_handle for
|
||||
# this model — open a fresh session and persist the response here instead.
|
||||
if drain_done.is_set() and model_succeeded[model_idx]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as self_complete_db:
|
||||
assistant_message = self_complete_db.get(
|
||||
ChatMessage, setup.reserved_messages[model_idx].id
|
||||
)
|
||||
if assistant_message is not None:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[model_idx],
|
||||
# Guard on line above already ensures model_succeeded is True.
|
||||
is_connected=lambda: True,
|
||||
db_session=self_complete_db,
|
||||
assistant_message=assistant_message,
|
||||
llm=setup.llms[model_idx],
|
||||
reserved_tokens=setup.reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"model %d (%s): self-completion after disconnect failed",
|
||||
model_idx,
|
||||
setup.model_display_names[model_idx],
|
||||
)
|
||||
|
||||
# Copy contextvars before submitting futures — ThreadPoolExecutor does NOT
|
||||
# auto-propagate contextvars in Python 3.11; threads would inherit a blank context.
|
||||
worker_context = contextvars.copy_context()
|
||||
executor = ThreadPoolExecutor(
|
||||
max_workers=n_models, thread_name_prefix="multi-model"
|
||||
executor = run_multiple_in_background(
|
||||
[functools.partial(_run_model, i) for i in range(n_models)],
|
||||
thread_name_prefix="multi-model",
|
||||
)
|
||||
completion_persisted: bool = False
|
||||
_completion_done: bool = False
|
||||
try:
|
||||
for i in range(n_models):
|
||||
executor.submit(worker_context.run, _run_model, i)
|
||||
|
||||
# ── Main thread: merge and yield packets ────────────────────────────
|
||||
models_remaining = n_models
|
||||
last_turn_index = 0
|
||||
while models_remaining > 0:
|
||||
try:
|
||||
model_idx, item = merged_queue.get(timeout=_CANCEL_POLL_INTERVAL_S)
|
||||
model_idx, item = merged_queue.get(timeout=0.05)
|
||||
except queue.Empty:
|
||||
# Check for user-initiated cancellation every 50 ms.
|
||||
if not setup.check_is_connected():
|
||||
# Save state for every model before exiting.
|
||||
# - Succeeded models: full answer (is_connected=True).
|
||||
# - Still-in-flight models: partial answer + "stopped by user".
|
||||
# - Errored models: delete the orphaned reserved message; do NOT
|
||||
# save "stopped by user" for a model that actually threw an exception.
|
||||
for i in range(n_models):
|
||||
if model_errored[i]:
|
||||
_delete_orphaned_message(i, "stop-button")
|
||||
continue
|
||||
try:
|
||||
succeeded = model_succeeded[i]
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[i],
|
||||
is_connected=lambda: succeeded,
|
||||
db_session=db_session,
|
||||
assistant_message=setup.reserved_messages[i],
|
||||
llm=setup.llms[i],
|
||||
reserved_tokens=setup.reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"stop-button completion failed for model %d (%s)",
|
||||
i,
|
||||
setup.model_display_names[i],
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
completion_persisted = True
|
||||
return
|
||||
continue
|
||||
if setup.check_is_connected():
|
||||
continue
|
||||
|
||||
# Save state for every model before exiting. Models that already
|
||||
# finished (model_succeeded[i]=True) get their full answer saved;
|
||||
# models still in-flight get partial answer + "stopped by user".
|
||||
for i in range(n_models):
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[i],
|
||||
# partial captures model_succeeded[i] by value at loop time, not by reference
|
||||
is_connected=functools.partial(bool, model_succeeded[i]),
|
||||
db_session=db_session,
|
||||
assistant_message=setup.reserved_messages[i],
|
||||
llm=setup.llms[i],
|
||||
reserved_tokens=setup.reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
f"Failed completion for model {i} on disconnect ({setup.model_display_names[i]})"
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(turn_index=last_turn_index + 1),
|
||||
obj=OverallStop(type="stop", stop_reason="user_cancelled"),
|
||||
)
|
||||
_completion_done = True
|
||||
return
|
||||
else:
|
||||
if item is _MODEL_DONE:
|
||||
models_remaining -= 1
|
||||
elif isinstance(item, Exception):
|
||||
continue
|
||||
|
||||
if isinstance(item, Exception):
|
||||
# Yield a tagged error for this model but keep the other models running.
|
||||
# Do NOT decrement models_remaining — _run_model's finally always posts
|
||||
# _MODEL_DONE, which is the sole completion signal.
|
||||
@@ -1200,7 +1157,14 @@ def _run_models(
|
||||
"model_index": model_idx,
|
||||
},
|
||||
)
|
||||
elif isinstance(item, Packet):
|
||||
continue
|
||||
|
||||
if isinstance(item, Packet):
|
||||
# Track the highest turn_index seen so OverallStop can follow it.
|
||||
if item.placement:
|
||||
last_turn_index = max(
|
||||
last_turn_index, item.placement.turn_index
|
||||
)
|
||||
# model_index already embedded by the model's Emitter in _run_model
|
||||
yield item
|
||||
|
||||
@@ -1210,8 +1174,6 @@ def _run_models(
|
||||
# sessions, but the main-thread db_session is unshared and safe to use.
|
||||
for i in range(n_models):
|
||||
if not model_succeeded[i]:
|
||||
# Model errored — delete its orphaned reserved message.
|
||||
_delete_orphaned_message(i, "normal")
|
||||
continue
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
@@ -1224,60 +1186,34 @@ def _run_models(
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"normal completion failed for model %d (%s)",
|
||||
i,
|
||||
setup.model_display_names[i],
|
||||
f"Failed completion for model {i} ({setup.model_display_names[i]})"
|
||||
)
|
||||
completion_persisted = True
|
||||
_completion_done = True
|
||||
|
||||
finally:
|
||||
if completion_persisted:
|
||||
if _completion_done:
|
||||
# Normal exit or stop-button exit: completion already persisted.
|
||||
# Threads are done (normal path) or can finish in the background (stop-button).
|
||||
executor.shutdown(wait=False)
|
||||
else:
|
||||
# Early exit (GeneratorExit from raw HTTP disconnect, or unhandled
|
||||
# exception in the drain loop).
|
||||
# 1. Signal emitters to stop — future emit() calls return immediately,
|
||||
# so workers exit their LLM loops promptly.
|
||||
# 1. Signal emitters to stop blocking — future emit() calls return immediately.
|
||||
drain_done.set()
|
||||
# 2. Wait for all workers to finish. Once drain_done is set the Emitter
|
||||
# short-circuits, so workers should exit quickly.
|
||||
executor.shutdown(wait=True)
|
||||
# 3. All workers are done — complete from the main thread only.
|
||||
for i in range(n_models):
|
||||
if model_succeeded[i]:
|
||||
try:
|
||||
llm_loop_completion_handle(
|
||||
state_container=state_containers[i],
|
||||
# Model already finished — persist full response.
|
||||
is_connected=lambda: True,
|
||||
db_session=db_session,
|
||||
assistant_message=setup.reserved_messages[i],
|
||||
llm=setup.llms[i],
|
||||
reserved_tokens=setup.reserved_token_count,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"disconnect completion failed for model %d (%s)",
|
||||
i,
|
||||
setup.model_display_names[i],
|
||||
)
|
||||
elif model_errored[i]:
|
||||
_delete_orphaned_message(i, "disconnect")
|
||||
# 4. Drain buffered packets from memory — no consumer is running.
|
||||
# 2. Drain buffered packets from memory — no consumer is running.
|
||||
while not merged_queue.empty():
|
||||
try:
|
||||
merged_queue.get_nowait()
|
||||
except queue.Empty:
|
||||
break
|
||||
# 3. Don't block the server thread — workers self-complete via drain_done.
|
||||
executor.shutdown(wait=False)
|
||||
|
||||
|
||||
def _stream_chat_turn(
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
llm_overrides: list[LLMOverride] | None = None,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
@@ -1286,23 +1222,17 @@ def _stream_chat_turn(
|
||||
slack_context: SlackContext | None = None,
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
"""Private implementation for single-model and multi-model chat turn streaming.
|
||||
"""Single-model streaming entrypoint.
|
||||
|
||||
Builds the turn context via ``build_chat_turn``, then streams packets from
|
||||
``_run_models`` back to the caller. Handles setup errors, LLM errors, and
|
||||
cancellation uniformly, saving whatever partial state has been accumulated
|
||||
before re-raising or yielding a terminal error packet.
|
||||
|
||||
Not called directly — use the public wrappers:
|
||||
- ``handle_stream_message_objects`` for single-model (N=1) requests.
|
||||
- ``handle_multi_model_stream`` for side-by-side multi-model comparison (N>1).
|
||||
|
||||
Args:
|
||||
new_msg_req: The incoming chat request from the user.
|
||||
user: Authenticated user; may be anonymous for public personas.
|
||||
db_session: Database session for this request.
|
||||
llm_overrides: ``None`` → single-model (persona default LLM).
|
||||
Non-empty list → multi-model (one LLM per override, 2–3 items).
|
||||
litellm_additional_headers: Extra headers forwarded to the LLM provider.
|
||||
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
|
||||
mcp_headers: Extra headers for MCP tool calls.
|
||||
@@ -1331,7 +1261,6 @@ def _stream_chat_turn(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
@@ -1440,94 +1369,6 @@ def _stream_chat_turn(
|
||||
logger.exception("Error in setting processing status")
|
||||
|
||||
|
||||
def handle_stream_message_objects(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
bypass_acl: bool = False,
|
||||
additional_context: str | None = None,
|
||||
slack_context: SlackContext | None = None,
|
||||
external_state_container: ChatStateContainer | None = None,
|
||||
) -> AnswerStream:
|
||||
"""Single-model streaming entrypoint. For multi-model comparison, use ``handle_multi_model_stream``."""
|
||||
yield from _stream_chat_turn(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=None,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
bypass_acl=bypass_acl,
|
||||
additional_context=additional_context,
|
||||
slack_context=slack_context,
|
||||
external_state_container=external_state_container,
|
||||
)
|
||||
|
||||
|
||||
def _build_model_display_name(override: LLMOverride | None) -> str:
|
||||
"""Build a human-readable display name from an LLM override."""
|
||||
if override is None:
|
||||
return "unknown"
|
||||
return override.display_name or override.model_version or "unknown"
|
||||
|
||||
|
||||
def handle_multi_model_stream(
|
||||
new_msg_req: SendMessageRequest,
|
||||
user: User,
|
||||
db_session: Session,
|
||||
llm_overrides: list[LLMOverride],
|
||||
litellm_additional_headers: dict[str, str] | None = None,
|
||||
custom_tool_additional_headers: dict[str, str] | None = None,
|
||||
mcp_headers: dict[str, str] | None = None,
|
||||
) -> AnswerStream:
|
||||
"""Thin wrapper for side-by-side multi-model comparison (2–3 models).
|
||||
|
||||
Validates the override list and delegates to ``_stream_chat_turn``,
|
||||
which handles both single-model and multi-model execution via the same path.
|
||||
|
||||
Args:
|
||||
new_msg_req: The incoming chat request. ``deep_research`` must be ``False``.
|
||||
user: Authenticated user making the request.
|
||||
db_session: Database session for this request.
|
||||
llm_overrides: Exactly 2 or 3 ``LLMOverride`` objects — one per model to run.
|
||||
litellm_additional_headers: Extra headers forwarded to each LLM provider.
|
||||
custom_tool_additional_headers: Extra headers for custom tool HTTP calls.
|
||||
mcp_headers: Extra headers for MCP tool calls.
|
||||
|
||||
Returns:
|
||||
Generator yielding interleaved ``Packet`` objects from all models, each tagged
|
||||
with ``model_index`` in its placement.
|
||||
"""
|
||||
n_models = len(llm_overrides)
|
||||
if n_models < 2 or n_models > 3:
|
||||
yield StreamingError(
|
||||
error=f"Multi-model requires 2-3 overrides, got {n_models}",
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=False,
|
||||
)
|
||||
return
|
||||
if new_msg_req.deep_research:
|
||||
yield StreamingError(
|
||||
error="Multi-model is not supported with deep research",
|
||||
error_code="VALIDATION_ERROR",
|
||||
is_retryable=False,
|
||||
)
|
||||
return
|
||||
yield from _stream_chat_turn(
|
||||
new_msg_req=new_msg_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=litellm_additional_headers,
|
||||
custom_tool_additional_headers=custom_tool_additional_headers,
|
||||
mcp_headers=mcp_headers,
|
||||
)
|
||||
|
||||
|
||||
def llm_loop_completion_handle(
|
||||
state_container: ChatStateContainer,
|
||||
is_connected: Callable[[], bool],
|
||||
|
||||
@@ -12,11 +12,6 @@ SLACK_USER_TOKEN_PREFIX = "xoxp-"
|
||||
SLACK_BOT_TOKEN_PREFIX = "xoxb-"
|
||||
ONYX_EMAILABLE_LOGO_MAX_DIM = 512
|
||||
|
||||
# The mask_string() function in encryption.py uses "•" (U+2022 BULLET) to mask secrets.
|
||||
MASK_CREDENTIAL_CHAR = "\u2022"
|
||||
# Pattern produced by mask_string for strings >= 14 chars: "abcd...wxyz" (exactly 11 chars)
|
||||
MASK_CREDENTIAL_LONG_RE = re.compile(r"^.{4}\.{3}.{4}$")
|
||||
|
||||
SOURCE_TYPE = "source_type"
|
||||
# stored in the `metadata` of a chunk. Used to signify that this chunk should
|
||||
# not be used for QA. For example, Google Drive file types which can't be parsed
|
||||
@@ -396,6 +391,10 @@ class MilestoneRecordType(str, Enum):
|
||||
REQUESTED_CONNECTOR = "requested_connector"
|
||||
|
||||
|
||||
class PostgresAdvisoryLocks(Enum):
|
||||
KOMBU_MESSAGE_CLEANUP_LOCK_ID = auto()
|
||||
|
||||
|
||||
class OnyxCeleryQueues:
|
||||
# "celery" is the default queue defined by celery and also the queue
|
||||
# we are running in the primary worker to run system tasks
|
||||
@@ -578,6 +577,7 @@ class OnyxCeleryTask:
|
||||
MONITOR_PROCESS_MEMORY = "monitor_process_memory"
|
||||
CELERY_BEAT_HEARTBEAT = "celery_beat_heartbeat"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
)
|
||||
|
||||
@@ -44,7 +44,7 @@ _NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
_MAX_PAGES = 1000
|
||||
|
||||
|
||||
# TODO: Pages need to have their metadata ingested
|
||||
# TODO: Tables need to be ingested, Pages need to have their metadata ingested
|
||||
|
||||
|
||||
class NotionPage(BaseModel):
|
||||
@@ -452,19 +452,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
sub_inner_dict: dict[str, Any] | list[Any] | str = inner_dict
|
||||
while isinstance(sub_inner_dict, dict) and "type" in sub_inner_dict:
|
||||
type_name = sub_inner_dict["type"]
|
||||
|
||||
# Notion user objects (people properties, created_by, etc.) have
|
||||
# "name" at the same level as "type": "person"/"bot". If we drill
|
||||
# into the person/bot sub-dict we lose the name. Capture it here
|
||||
# before descending, but skip "title"-type properties where "name"
|
||||
# is not the display value we want.
|
||||
if (
|
||||
"name" in sub_inner_dict
|
||||
and isinstance(sub_inner_dict["name"], str)
|
||||
and type_name not in ("title",)
|
||||
):
|
||||
return sub_inner_dict["name"]
|
||||
|
||||
sub_inner_dict = sub_inner_dict[type_name]
|
||||
|
||||
# If the innermost layer is None, the value is not set
|
||||
@@ -676,19 +663,6 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
text = rich_text["text"]["content"]
|
||||
cur_result_text_arr.append(text)
|
||||
|
||||
# table_row blocks store content in "cells" (list of lists
|
||||
# of rich text objects) rather than "rich_text"
|
||||
if "cells" in result_obj:
|
||||
row_cells: list[str] = []
|
||||
for cell in result_obj["cells"]:
|
||||
cell_texts = [
|
||||
rt.get("plain_text", "")
|
||||
for rt in cell
|
||||
if isinstance(rt, dict)
|
||||
]
|
||||
row_cells.append(" ".join(cell_texts))
|
||||
cur_result_text_arr.append("\t".join(row_cells))
|
||||
|
||||
if result["has_children"]:
|
||||
if result_type == "child_page":
|
||||
# Child pages will not be included at this top level, it will be a separate document.
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -11,23 +10,14 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -95,7 +85,6 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -108,18 +97,7 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
@@ -146,33 +124,7 @@ def update_api_key(
|
||||
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
|
||||
@@ -190,23 +190,16 @@ def delete_messages_and_files_from_chat_session(
|
||||
chat_session_id: UUID, db_session: Session
|
||||
) -> None:
|
||||
# Select messages older than cutoff_time with files
|
||||
messages_with_files = (
|
||||
db_session.execute(
|
||||
select(ChatMessage.id, ChatMessage.files).where(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
)
|
||||
messages_with_files = db_session.execute(
|
||||
select(ChatMessage.id, ChatMessage.files).where(
|
||||
ChatMessage.chat_session_id == chat_session_id,
|
||||
)
|
||||
.tuples()
|
||||
.all()
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
file_store = get_default_file_store()
|
||||
for _, files in messages_with_files:
|
||||
file_store = get_default_file_store()
|
||||
for file_info in files or []:
|
||||
if file_info.get("user_file_id"):
|
||||
# user files are managed by the user file lifecycle
|
||||
continue
|
||||
file_store.delete_file(file_id=file_info["id"], error_on_missing=False)
|
||||
file_store.delete_file(file_id=file_info.get("id"))
|
||||
|
||||
# Delete ChatMessage records - CASCADE constraints will automatically handle:
|
||||
# - ChatMessage__StandardAnswer relationship records
|
||||
@@ -638,91 +631,6 @@ def reserve_message_id(
|
||||
return empty_message
|
||||
|
||||
|
||||
def reserve_multi_model_message_ids(
|
||||
db_session: Session,
|
||||
chat_session_id: UUID,
|
||||
parent_message_id: int,
|
||||
model_display_names: list[str],
|
||||
) -> list[ChatMessage]:
|
||||
"""Reserve N assistant message placeholders for multi-model parallel streaming.
|
||||
|
||||
All messages share the same parent (the user message). The parent's
|
||||
latest_child_message_id points to the LAST reserved message so that the
|
||||
default history-chain walker picks it up.
|
||||
"""
|
||||
reserved: list[ChatMessage] = []
|
||||
for display_name in model_display_names:
|
||||
msg = ChatMessage(
|
||||
chat_session_id=chat_session_id,
|
||||
parent_message_id=parent_message_id,
|
||||
latest_child_message_id=None,
|
||||
message="Response was terminated prior to completion, try regenerating.",
|
||||
token_count=15, # placeholder; updated on completion by llm_loop_completion_handle
|
||||
message_type=MessageType.ASSISTANT,
|
||||
model_display_name=display_name,
|
||||
)
|
||||
db_session.add(msg)
|
||||
reserved.append(msg)
|
||||
|
||||
# Flush to assign IDs without committing yet
|
||||
db_session.flush()
|
||||
|
||||
# Point parent's latest_child to the last reserved message
|
||||
parent = (
|
||||
db_session.query(ChatMessage)
|
||||
.filter(ChatMessage.id == parent_message_id)
|
||||
.first()
|
||||
)
|
||||
if parent:
|
||||
parent.latest_child_message_id = reserved[-1].id
|
||||
|
||||
db_session.commit()
|
||||
return reserved
|
||||
|
||||
|
||||
def set_preferred_response(
|
||||
db_session: Session,
|
||||
user_message_id: int,
|
||||
preferred_assistant_message_id: int,
|
||||
) -> None:
|
||||
"""Mark one assistant response as the user's preferred choice in a multi-model turn.
|
||||
|
||||
Also advances ``latest_child_message_id`` so the preferred response becomes
|
||||
the active branch for any subsequent messages in the conversation.
|
||||
|
||||
Args:
|
||||
db_session: Active database session.
|
||||
user_message_id: Primary key of the ``USER``-type ``ChatMessage`` whose
|
||||
preferred response is being set.
|
||||
preferred_assistant_message_id: Primary key of the ``ASSISTANT``-type
|
||||
``ChatMessage`` to prefer. Must be a direct child of ``user_message_id``.
|
||||
|
||||
Raises:
|
||||
ValueError: If either message is not found, if ``user_message_id`` does not
|
||||
refer to a USER message, or if the assistant message is not a direct child
|
||||
of the user message.
|
||||
"""
|
||||
user_msg = db_session.get(ChatMessage, user_message_id)
|
||||
if user_msg is None:
|
||||
raise ValueError(f"User message {user_message_id} not found")
|
||||
if user_msg.message_type != MessageType.USER:
|
||||
raise ValueError(f"Message {user_message_id} is not a user message")
|
||||
|
||||
assistant_msg = db_session.get(ChatMessage, preferred_assistant_message_id)
|
||||
if assistant_msg is None:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} not found"
|
||||
)
|
||||
if assistant_msg.parent_message_id != user_message_id:
|
||||
raise ValueError(
|
||||
f"Assistant message {preferred_assistant_message_id} is not a child of user message {user_message_id}"
|
||||
)
|
||||
|
||||
user_msg.preferred_response_id = preferred_assistant_message_id
|
||||
user_msg.latest_child_message_id = preferred_assistant_message_id
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def create_new_chat_message(
|
||||
chat_session_id: UUID,
|
||||
parent_message: ChatMessage,
|
||||
@@ -945,8 +853,6 @@ def translate_db_message_to_chat_message_detail(
|
||||
error=chat_message.error,
|
||||
current_feedback=current_feedback,
|
||||
processing_duration_seconds=chat_message.processing_duration_seconds,
|
||||
preferred_response_id=chat_message.preferred_response_id,
|
||||
model_display_name=chat_message.model_display_name,
|
||||
)
|
||||
|
||||
return chat_msg_detail
|
||||
|
||||
@@ -13,26 +13,19 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
"""Whether this account type supports interactive web login."""
|
||||
return self not in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
)
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
|
||||
@@ -8,8 +8,6 @@ from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.constants import FederatedConnectorSource
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_LONG_RE
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector
|
||||
@@ -47,23 +45,6 @@ def fetch_all_federated_connectors_parallel() -> list[FederatedConnector]:
|
||||
return fetch_all_federated_connectors(db_session)
|
||||
|
||||
|
||||
def _reject_masked_credentials(credentials: dict[str, Any]) -> None:
|
||||
"""Raise if any credential string value contains mask placeholder characters.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
Both must be rejected.
|
||||
"""
|
||||
for key, val in credentials.items():
|
||||
if isinstance(val, str) and (
|
||||
MASK_CREDENTIAL_CHAR in val or MASK_CREDENTIAL_LONG_RE.match(val)
|
||||
):
|
||||
raise ValueError(
|
||||
f"Credential field '{key}' contains masked placeholder characters. Please provide the actual credential value."
|
||||
)
|
||||
|
||||
|
||||
def validate_federated_connector_credentials(
|
||||
source: FederatedConnectorSource,
|
||||
credentials: dict[str, Any],
|
||||
@@ -85,8 +66,6 @@ def create_federated_connector(
|
||||
config: dict[str, Any] | None = None,
|
||||
) -> FederatedConnector:
|
||||
"""Create a new federated connector with credential and config validation."""
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before creating
|
||||
if not validate_federated_connector_credentials(source, credentials):
|
||||
raise ValueError(
|
||||
@@ -298,8 +277,6 @@ def update_federated_connector(
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
_reject_masked_credentials(credentials)
|
||||
|
||||
# Validate credentials before updating
|
||||
if not validate_federated_connector_credentials(
|
||||
federated_connector.source, credentials
|
||||
|
||||
@@ -305,11 +305,8 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -356,13 +353,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4026,12 +4016,7 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
@@ -324,15 +324,6 @@ def mark_migration_completed_time_if_not_set_with_commit(
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def is_migration_completed(db_session: Session) -> bool:
|
||||
"""Returns True if the migration is completed.
|
||||
|
||||
Can be run even if the migration record does not exist.
|
||||
"""
|
||||
record = db_session.query(OpenSearchTenantMigrationRecord).first()
|
||||
return record is not None and record.migration_completed_at is not None
|
||||
|
||||
|
||||
def build_sanitized_to_original_doc_id_mapping(
|
||||
db_session: Session,
|
||||
) -> dict[str, str]:
|
||||
|
||||
@@ -1,95 +0,0 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(
|
||||
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for one or more users.
|
||||
|
||||
Accepts a single UUID or a list. Uses a single query regardless of
|
||||
how many users are passed, avoiding N+1 issues.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
if isinstance(user_ids, (UUID, str)):
|
||||
uid_list = [user_ids]
|
||||
else:
|
||||
uid_list = list(user_ids)
|
||||
|
||||
if not uid_list:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(uid_list),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
|
||||
for uid in uid_list:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid) # type: ignore[arg-type]
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
recompute_user_permissions__no_commit(user_ids, db_session)
|
||||
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -9,17 +9,12 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -28,53 +23,13 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
"""Update a user's role in the database."""
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if new_role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -92,16 +47,8 @@ def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True.
|
||||
|
||||
Also reconciles default-group membership — the user may have been
|
||||
created while inactive or deactivated before the backfill migration.
|
||||
"""
|
||||
"""Activate a user by setting is_active to True."""
|
||||
user.is_active = True
|
||||
if user.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -17,9 +17,8 @@ from sqlalchemy.sql.expression import or_
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -28,17 +27,11 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole,
|
||||
current_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
@@ -48,18 +41,19 @@ def validate_user_role_update(
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current role is a slack user
|
||||
- current role is an external permissioned user
|
||||
- current role is a limited user
|
||||
"""
|
||||
|
||||
if current_account_type == AccountType.BOT:
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
@@ -304,7 +298,6 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -313,9 +306,8 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -352,7 +344,6 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -384,81 +375,6 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -505,14 +421,13 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
stmt = (
|
||||
rows = db_session.execute(
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -520,11 +435,7 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import hashlib
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
@@ -21,13 +20,9 @@ from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.constants import EF_CONSTRUCTION
|
||||
from onyx.document_index.opensearch.constants import EF_SEARCH
|
||||
from onyx.document_index.opensearch.constants import M
|
||||
from onyx.document_index.opensearch.string_filtering import DocumentIDTooLongError
|
||||
from onyx.document_index.opensearch.string_filtering import (
|
||||
filter_and_validate_document_id,
|
||||
)
|
||||
from onyx.document_index.opensearch.string_filtering import (
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH,
|
||||
)
|
||||
from onyx.utils.tenant import get_tenant_id_short_string
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
@@ -80,50 +75,17 @@ def get_opensearch_doc_chunk_id(
|
||||
|
||||
This will be the string used to identify the chunk in OpenSearch. Any direct
|
||||
chunk queries should use this function.
|
||||
|
||||
If the document ID is too long, a hash of the ID is used instead.
|
||||
"""
|
||||
opensearch_doc_chunk_id_suffix: str = f"__{max_chunk_size}__{chunk_index}"
|
||||
encoded_suffix_length: int = len(opensearch_doc_chunk_id_suffix.encode("utf-8"))
|
||||
max_encoded_permissible_doc_id_length: int = (
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH - encoded_suffix_length
|
||||
sanitized_document_id = filter_and_validate_document_id(document_id)
|
||||
opensearch_doc_chunk_id = (
|
||||
f"{sanitized_document_id}__{max_chunk_size}__{chunk_index}"
|
||||
)
|
||||
opensearch_doc_chunk_id_tenant_prefix: str = ""
|
||||
if tenant_state.multitenant:
|
||||
short_tenant_id: str = get_tenant_id_short_string(tenant_state.tenant_id)
|
||||
# Use tenant ID because in multitenant mode each tenant has its own
|
||||
# Documents table, so there is a very small chance that doc IDs are not
|
||||
# actually unique across all tenants.
|
||||
opensearch_doc_chunk_id_tenant_prefix = f"{short_tenant_id}__"
|
||||
encoded_prefix_length: int = len(
|
||||
opensearch_doc_chunk_id_tenant_prefix.encode("utf-8")
|
||||
)
|
||||
max_encoded_permissible_doc_id_length -= encoded_prefix_length
|
||||
|
||||
try:
|
||||
sanitized_document_id: str = filter_and_validate_document_id(
|
||||
document_id, max_encoded_length=max_encoded_permissible_doc_id_length
|
||||
)
|
||||
except DocumentIDTooLongError:
|
||||
# If the document ID is too long, use a hash instead.
|
||||
# We use blake2b because it is faster and equally secure as SHA256, and
|
||||
# accepts digest_size which controls the number of bytes returned in the
|
||||
# hash.
|
||||
# digest_size is the size of the returned hash in bytes. Since we're
|
||||
# decoding the hash bytes as a hex string, the digest_size should be
|
||||
# half the max target size of the hash string.
|
||||
# Subtract 1 because filter_and_validate_document_id compares on >= on
|
||||
# max_encoded_length.
|
||||
# 64 is the max digest_size blake2b returns.
|
||||
digest_size: int = min((max_encoded_permissible_doc_id_length - 1) // 2, 64)
|
||||
sanitized_document_id = hashlib.blake2b(
|
||||
document_id.encode("utf-8"), digest_size=digest_size
|
||||
).hexdigest()
|
||||
|
||||
opensearch_doc_chunk_id: str = (
|
||||
f"{opensearch_doc_chunk_id_tenant_prefix}{sanitized_document_id}{opensearch_doc_chunk_id_suffix}"
|
||||
)
|
||||
|
||||
short_tenant_id = get_tenant_id_short_string(tenant_state.tenant_id)
|
||||
opensearch_doc_chunk_id = f"{short_tenant_id}__{opensearch_doc_chunk_id}"
|
||||
# Do one more validation to ensure we haven't exceeded the max length.
|
||||
opensearch_doc_chunk_id = filter_and_validate_document_id(opensearch_doc_chunk_id)
|
||||
return opensearch_doc_chunk_id
|
||||
|
||||
@@ -1,15 +1,7 @@
|
||||
import re
|
||||
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH: int = 512
|
||||
|
||||
|
||||
class DocumentIDTooLongError(ValueError):
|
||||
"""Raised when a document ID is too long for OpenSearch after filtering."""
|
||||
|
||||
|
||||
def filter_and_validate_document_id(
|
||||
document_id: str, max_encoded_length: int = MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
) -> str:
|
||||
def filter_and_validate_document_id(document_id: str) -> str:
|
||||
"""
|
||||
Filters and validates a document ID such that it can be used as an ID in
|
||||
OpenSearch.
|
||||
@@ -27,13 +19,9 @@ def filter_and_validate_document_id(
|
||||
|
||||
Args:
|
||||
document_id: The document ID to filter and validate.
|
||||
max_encoded_length: The maximum length of the document ID after
|
||||
filtering in bytes. Compared with >= for extra resilience, so
|
||||
encoded values of this length will fail.
|
||||
|
||||
Raises:
|
||||
DocumentIDTooLongError: If the document ID is too long after filtering.
|
||||
ValueError: If the document ID is empty after filtering.
|
||||
ValueError: If the document ID is empty or too long after filtering.
|
||||
|
||||
Returns:
|
||||
str: The filtered document ID.
|
||||
@@ -41,8 +29,6 @@ def filter_and_validate_document_id(
|
||||
filtered_document_id = re.sub(r"[^A-Za-z0-9_.\-~]", "", document_id)
|
||||
if not filtered_document_id:
|
||||
raise ValueError(f"Document ID {document_id} is empty after filtering.")
|
||||
if len(filtered_document_id.encode("utf-8")) >= max_encoded_length:
|
||||
raise DocumentIDTooLongError(
|
||||
f"Document ID {document_id} is too long after filtering."
|
||||
)
|
||||
if len(filtered_document_id.encode("utf-8")) >= 512:
|
||||
raise ValueError(f"Document ID {document_id} is too long after filtering.")
|
||||
return filtered_document_id
|
||||
|
||||
@@ -136,14 +136,12 @@ class FileStore(ABC):
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def delete_file(self, file_id: str, error_on_missing: bool = True) -> None:
|
||||
def delete_file(self, file_id: str) -> None:
|
||||
"""
|
||||
Delete a file by its ID.
|
||||
|
||||
Parameters:
|
||||
- file_id: ID of file to delete
|
||||
- error_on_missing: If False, silently return when the file record
|
||||
does not exist instead of raising.
|
||||
- file_name: Name of file to delete
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
@@ -454,23 +452,12 @@ class S3BackedFileStore(FileStore):
|
||||
logger.warning(f"Error getting file size for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
error_on_missing: bool = True,
|
||||
db_session: Session | None = None,
|
||||
) -> None:
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as db_session:
|
||||
try:
|
||||
file_record = get_filerecord_by_file_id_optional(
|
||||
file_record = get_filerecord_by_file_id(
|
||||
file_id=file_id, db_session=db_session
|
||||
)
|
||||
if file_record is None:
|
||||
if error_on_missing:
|
||||
raise RuntimeError(
|
||||
f"File by id {file_id} does not exist or was deleted"
|
||||
)
|
||||
return
|
||||
if not file_record.bucket_name:
|
||||
logger.error(
|
||||
f"File record {file_id} with key {file_record.object_key} "
|
||||
|
||||
@@ -222,23 +222,12 @@ class PostgresBackedFileStore(FileStore):
|
||||
logger.warning(f"Error getting file size for {file_id}: {e}")
|
||||
return None
|
||||
|
||||
def delete_file(
|
||||
self,
|
||||
file_id: str,
|
||||
error_on_missing: bool = True,
|
||||
db_session: Session | None = None,
|
||||
) -> None:
|
||||
def delete_file(self, file_id: str, db_session: Session | None = None) -> None:
|
||||
with get_session_with_current_tenant_if_none(db_session) as session:
|
||||
try:
|
||||
file_content = get_file_content_by_file_id_optional(
|
||||
file_content = get_file_content_by_file_id(
|
||||
file_id=file_id, db_session=session
|
||||
)
|
||||
if file_content is None:
|
||||
if error_on_missing:
|
||||
raise RuntimeError(
|
||||
f"File content for file_id {file_id} does not exist or was deleted"
|
||||
)
|
||||
return
|
||||
raw_conn = _get_raw_connection(session)
|
||||
|
||||
try:
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
import httpx
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.mcp_server.api import mcp_server
|
||||
from onyx.mcp_server.utils import get_http_client
|
||||
@@ -17,21 +15,6 @@ from onyx.utils.variable_functionality import global_version
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _extract_error_detail(response: httpx.Response) -> str:
|
||||
"""Extract a human-readable error message from a failed backend response.
|
||||
|
||||
The backend returns OnyxError responses as
|
||||
``{"error_code": "...", "detail": "..."}``.
|
||||
"""
|
||||
try:
|
||||
body = response.json()
|
||||
if detail := body.get("detail"):
|
||||
return str(detail)
|
||||
except Exception:
|
||||
pass
|
||||
return f"Request failed with status {response.status_code}"
|
||||
|
||||
|
||||
@mcp_server.tool()
|
||||
async def search_indexed_documents(
|
||||
query: str,
|
||||
@@ -175,14 +158,7 @@ async def search_indexed_documents(
|
||||
json=search_request,
|
||||
headers=auth_headers,
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"documents": [],
|
||||
"total_results": 0,
|
||||
"query": query,
|
||||
"error": error_detail,
|
||||
}
|
||||
response.raise_for_status()
|
||||
result = response.json()
|
||||
|
||||
# Check for error in response
|
||||
@@ -258,13 +234,7 @@ async def search_web(
|
||||
json=request_payload,
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
"query": query,
|
||||
}
|
||||
response.raise_for_status()
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
@@ -310,12 +280,7 @@ async def open_urls(
|
||||
json={"urls": urls},
|
||||
headers={"Authorization": f"Bearer {access_token.token}"},
|
||||
)
|
||||
if not response.is_success:
|
||||
error_detail = _extract_error_detail(response)
|
||||
return {
|
||||
"error": error_detail,
|
||||
"results": [],
|
||||
}
|
||||
response.raise_for_status()
|
||||
response_payload = response.json()
|
||||
results = response_payload.get("results", [])
|
||||
return {
|
||||
|
||||
@@ -3,10 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
@@ -247,7 +247,7 @@ def handle_message(
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.account_type == AccountType.BOT
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
|
||||
@@ -59,6 +59,9 @@ from onyx.db.permission_sync_attempt import (
|
||||
from onyx.redis.redis_connector import RedisConnector
|
||||
from onyx.redis.redis_connector_utils import get_deletion_attempt_snapshot
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.mock_connector_data import get_mock_cc_pair_full_info
|
||||
from onyx.server.documents.mock_connector_data import get_mock_index_attempts
|
||||
from onyx.server.documents.mock_connector_data import load_mock_data
|
||||
from onyx.server.documents.models import CCPairFullInfo
|
||||
from onyx.server.documents.models import CCPropertyUpdateRequest
|
||||
from onyx.server.documents.models import CCStatusUpdateRequest
|
||||
@@ -85,6 +88,18 @@ def get_cc_pair_index_attempts(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> PaginatedReturn[IndexAttemptSnapshot]:
|
||||
mock_data = load_mock_data()
|
||||
if mock_data is not None:
|
||||
mock_attempts = get_mock_index_attempts(mock_data, cc_pair_id)
|
||||
if mock_attempts is not None:
|
||||
all_items = [IndexAttemptSnapshot(**a) for a in mock_attempts]
|
||||
start = page_num * page_size
|
||||
page_items = all_items[start : start + page_size]
|
||||
return PaginatedReturn(
|
||||
items=page_items,
|
||||
total_items=len(all_items),
|
||||
)
|
||||
|
||||
if user:
|
||||
user_has_access = verify_user_has_access_to_cc_pair(
|
||||
cc_pair_id, db_session, user, get_editable=False
|
||||
@@ -157,6 +172,12 @@ def get_cc_pair_full_info(
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CCPairFullInfo:
|
||||
mock_data = load_mock_data()
|
||||
if mock_data is not None:
|
||||
mock_info = get_mock_cc_pair_full_info(mock_data, cc_pair_id)
|
||||
if mock_info is not None:
|
||||
return CCPairFullInfo(**mock_info)
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
cc_pair = get_connector_credential_pair_from_id_for_user(
|
||||
|
||||
@@ -32,7 +32,6 @@ from onyx.background.celery.tasks.pruning.tasks import (
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.app_configs import EMAIL_CONFIGURED
|
||||
from onyx.configs.app_configs import ENABLED_CONNECTOR_TYPES
|
||||
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
@@ -125,6 +124,8 @@ from onyx.file_store.file_store import FileStore
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.server.documents.mock_connector_data import get_mock_indexing_statuses
|
||||
from onyx.server.documents.mock_connector_data import load_mock_data
|
||||
from onyx.server.documents.models import AuthStatus
|
||||
from onyx.server.documents.models import AuthUrl
|
||||
from onyx.server.documents.models import ConnectorBase
|
||||
@@ -1115,28 +1116,40 @@ def get_connector_indexing_status(
|
||||
# sqlalchemy-method-connection-for-bind-is-already-in-progress
|
||||
# for why we can't pass in the current db_session to these functions
|
||||
|
||||
if MOCK_CONNECTOR_FILE_PATH:
|
||||
import json
|
||||
mock_data = load_mock_data()
|
||||
if mock_data is not None:
|
||||
mock_statuses = get_mock_indexing_statuses(mock_data)
|
||||
if mock_statuses is not None:
|
||||
# Group statuses by source, mirroring the real code path.
|
||||
source_to_statuses: dict[
|
||||
DocumentSource, list[ConnectorIndexingStatusLite]
|
||||
] = {}
|
||||
for raw in mock_statuses:
|
||||
status = ConnectorIndexingStatusLite(**raw)
|
||||
source_to_statuses.setdefault(status.source, []).append(status)
|
||||
|
||||
with open(MOCK_CONNECTOR_FILE_PATH, "r") as f:
|
||||
raw_data = json.load(f)
|
||||
connector_indexing_statuses = [
|
||||
ConnectorIndexingStatusLite(**status) for status in raw_data
|
||||
]
|
||||
return [
|
||||
ConnectorIndexingStatusLiteResponse(
|
||||
source=DocumentSource.FILE,
|
||||
summary=SourceSummary(
|
||||
total_connectors=100,
|
||||
active_connectors=100,
|
||||
public_connectors=100,
|
||||
total_docs_indexed=100000,
|
||||
),
|
||||
current_page=1,
|
||||
total_pages=1,
|
||||
indexing_statuses=connector_indexing_statuses,
|
||||
)
|
||||
]
|
||||
response_list: list[ConnectorIndexingStatusLiteResponse] = []
|
||||
for source in sorted(source_to_statuses):
|
||||
statuses = source_to_statuses[source]
|
||||
total_docs = sum(s.docs_indexed for s in statuses)
|
||||
public_count = sum(
|
||||
1 for s in statuses if s.access_type == AccessType.PUBLIC
|
||||
)
|
||||
response_list.append(
|
||||
ConnectorIndexingStatusLiteResponse(
|
||||
source=source,
|
||||
summary=SourceSummary(
|
||||
total_connectors=len(statuses),
|
||||
active_connectors=len(statuses),
|
||||
public_connectors=public_count,
|
||||
total_docs_indexed=total_docs,
|
||||
),
|
||||
current_page=1,
|
||||
total_pages=1,
|
||||
indexing_statuses=statuses,
|
||||
)
|
||||
)
|
||||
return response_list
|
||||
|
||||
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
|
||||
# Get editable connector/credential pairs
|
||||
|
||||
133
backend/onyx/server/documents/mock_connector_data.py
Normal file
133
backend/onyx/server/documents/mock_connector_data.py
Normal file
@@ -0,0 +1,133 @@
|
||||
"""Utilities for loading mock connector data from a JSON file.
|
||||
|
||||
When MOCK_CONNECTOR_FILE_PATH is set, the backend serves connector listing,
|
||||
detail, and index-attempt endpoints from a static JSON file instead of hitting
|
||||
the database. This is useful for frontend development and demos.
|
||||
|
||||
Time-offset support
|
||||
-------------------
|
||||
Any datetime string field in the JSON can be replaced with an *offset string*
|
||||
of the form ``"<offset_seconds>"``, e.g. ``"-3600"`` means "1 hour ago" and
|
||||
``"-86400"`` means "24 hours ago". Positive values point to the future.
|
||||
The offset is resolved to an absolute ISO-8601 datetime at load time, so
|
||||
each request gets a fresh "now".
|
||||
"""
|
||||
|
||||
import json
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import MOCK_CONNECTOR_FILE_PATH
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# ---- JSON schema top-level keys ------------------------------------------------
|
||||
_KEY_INDEXING_STATUSES = "indexing_statuses"
|
||||
_KEY_CC_PAIR_FULL_INFO = "cc_pair_full_info"
|
||||
_KEY_INDEX_ATTEMPTS = "index_attempts"
|
||||
|
||||
# Fields across the relevant Pydantic models that hold datetimes.
|
||||
_DATETIME_FIELDS: set[str] = {
|
||||
# ConnectorIndexingStatusLite
|
||||
"last_success",
|
||||
# CCPairFullInfo
|
||||
"last_indexed",
|
||||
"last_pruned",
|
||||
"last_full_permission_sync",
|
||||
"last_permission_sync_attempt_finished",
|
||||
# ConnectorSnapshot / CredentialSnapshot
|
||||
"time_created",
|
||||
"time_updated",
|
||||
"indexing_start",
|
||||
# IndexAttemptSnapshot
|
||||
"time_started",
|
||||
"time_updated",
|
||||
"poll_range_start",
|
||||
"poll_range_end",
|
||||
# IndexAttemptErrorPydantic
|
||||
"failed_time_range_start",
|
||||
"failed_time_range_end",
|
||||
"time_created",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_time_offsets(obj: Any) -> Any:
|
||||
"""Walk a JSON-like structure and resolve offset strings to ISO datetimes.
|
||||
|
||||
An offset string is a string that, after stripping whitespace, is parseable
|
||||
as an integer or float. It represents seconds relative to *now*.
|
||||
"""
|
||||
now = datetime.now(tz=timezone.utc)
|
||||
|
||||
if isinstance(obj, dict):
|
||||
return {k: _resolve_value(k, v, now) for k, v in obj.items()}
|
||||
if isinstance(obj, list):
|
||||
return [_resolve_time_offsets(item) for item in obj]
|
||||
return obj
|
||||
|
||||
|
||||
def _resolve_value(key: str, value: Any, now: datetime) -> Any:
|
||||
if isinstance(value, dict):
|
||||
return {k: _resolve_value(k, v, now) for k, v in value.items()}
|
||||
if isinstance(value, list):
|
||||
return [_resolve_time_offsets(item) for item in value]
|
||||
if key in _DATETIME_FIELDS and isinstance(value, str):
|
||||
try:
|
||||
offset_seconds = float(value)
|
||||
return (now + timedelta(seconds=offset_seconds)).isoformat()
|
||||
except ValueError:
|
||||
# Not a numeric string – leave it as-is (already an ISO datetime).
|
||||
pass
|
||||
return value
|
||||
|
||||
|
||||
def _load_raw() -> dict[str, Any] | None:
|
||||
"""Load and return the raw JSON from MOCK_CONNECTOR_FILE_PATH, or None."""
|
||||
if not MOCK_CONNECTOR_FILE_PATH:
|
||||
return None
|
||||
with open(MOCK_CONNECTOR_FILE_PATH) as f:
|
||||
return json.load(f) # type: ignore[no-any-return]
|
||||
|
||||
|
||||
def load_mock_data() -> dict[str, Any] | None:
|
||||
"""Load mock data with time offsets resolved. Returns None when mocking is
|
||||
disabled."""
|
||||
raw = _load_raw()
|
||||
if raw is None:
|
||||
return None
|
||||
|
||||
# Support both the old format (bare list of indexing statuses) and the new
|
||||
# format (dict with explicit keys).
|
||||
if isinstance(raw, list):
|
||||
raw = {_KEY_INDEXING_STATUSES: raw}
|
||||
|
||||
return _resolve_time_offsets(raw) # type: ignore[return-value]
|
||||
|
||||
|
||||
def get_mock_indexing_statuses(
|
||||
data: dict[str, Any],
|
||||
) -> list[dict[str, Any]] | None:
|
||||
return data.get(_KEY_INDEXING_STATUSES)
|
||||
|
||||
|
||||
def get_mock_cc_pair_full_info(
|
||||
data: dict[str, Any],
|
||||
cc_pair_id: int,
|
||||
) -> dict[str, Any] | None:
|
||||
by_id = data.get(_KEY_CC_PAIR_FULL_INFO)
|
||||
if not by_id:
|
||||
return None
|
||||
return by_id.get(str(cc_pair_id))
|
||||
|
||||
|
||||
def get_mock_index_attempts(
|
||||
data: dict[str, Any],
|
||||
cc_pair_id: int,
|
||||
) -> list[dict[str, Any]] | None:
|
||||
by_id = data.get(_KEY_INDEX_ATTEMPTS)
|
||||
if not by_id:
|
||||
return None
|
||||
return by_id.get(str(cc_pair_id))
|
||||
@@ -1,5 +1,6 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
@@ -8,8 +9,6 @@ from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.models import User
|
||||
from onyx.db.web_search import fetch_active_web_content_provider
|
||||
from onyx.db.web_search import fetch_active_web_search_provider
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolRequest
|
||||
from onyx.server.features.web_search.models import OpenUrlsToolResponse
|
||||
from onyx.server.features.web_search.models import WebSearchToolRequest
|
||||
@@ -62,10 +61,9 @@ def _get_active_search_provider(
|
||||
) -> tuple[WebSearchProviderView, WebSearchProvider]:
|
||||
provider_model = fetch_active_web_search_provider(db_session)
|
||||
if provider_model is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"No web search provider configured. Please configure one in "
|
||||
"Admin > Web Search settings.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No web search provider configured.",
|
||||
)
|
||||
|
||||
provider_view = WebSearchProviderView(
|
||||
@@ -78,10 +76,9 @@ def _get_active_search_provider(
|
||||
)
|
||||
|
||||
if provider_model.api_key is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Web search provider requires an API key. Please configure one in "
|
||||
"Admin > Web Search settings.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Web search provider requires an API key.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -91,7 +88,7 @@ def _get_active_search_provider(
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
return provider_view, provider
|
||||
|
||||
@@ -113,9 +110,9 @@ def _get_active_content_provider(
|
||||
|
||||
if provider_model.api_key is None:
|
||||
# TODO - this is not a great error, in fact, this key should not be nullable.
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Web content provider requires an API key.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Web content provider requires an API key.",
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -128,12 +125,12 @@ def _get_active_content_provider(
|
||||
config=config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(exc)) from exc
|
||||
raise HTTPException(status_code=400, detail=str(exc)) from exc
|
||||
|
||||
if provider is None:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Unable to initialize the configured web content provider.",
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Unable to initialize the configured web content provider.",
|
||||
)
|
||||
|
||||
provider_view = WebContentProviderView(
|
||||
@@ -157,13 +154,12 @@ def _run_web_search(
|
||||
for query in request.queries:
|
||||
try:
|
||||
search_results = provider.search(query)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Web search provider failed for query '%s'", query)
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Web search provider failed to execute query.",
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Web search provider failed to execute query."
|
||||
) from exc
|
||||
|
||||
filtered_results = filter_web_search_results_with_no_title_or_snippet(
|
||||
@@ -196,13 +192,12 @@ def _open_urls(
|
||||
docs = filter_web_contents_with_no_title_or_content(
|
||||
list(provider.contents(urls))
|
||||
)
|
||||
except OnyxError:
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
logger.exception("Web content provider failed to fetch URLs")
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.BAD_GATEWAY,
|
||||
"Web content provider failed to fetch URLs.",
|
||||
raise HTTPException(
|
||||
status_code=502, detail="Web content provider failed to fetch URLs."
|
||||
) from exc
|
||||
|
||||
results: list[LlmOpenUrlResult] = []
|
||||
|
||||
@@ -27,7 +27,6 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -51,7 +50,6 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -144,7 +142,6 @@ def set_user_role(
|
||||
validate_user_role_update(
|
||||
requested_role=requested_role,
|
||||
current_role=current_role,
|
||||
current_account_type=user_to_update.account_type,
|
||||
explicit_override=user_role_update_request.explicit_override,
|
||||
)
|
||||
|
||||
@@ -330,8 +327,8 @@ def list_all_users(
|
||||
if (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
slack_users = [user for user in users if user.account_type == AccountType.BOT]
|
||||
accepted_users = [user for user in users if user.account_type != AccountType.BOT]
|
||||
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
|
||||
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
|
||||
|
||||
accepted_emails = {user.email for user in accepted_users}
|
||||
slack_users_emails = {user.email for user in slack_users}
|
||||
@@ -674,7 +671,7 @@ def list_all_users_basic_info(
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.account_type != AccountType.BOT
|
||||
if user.role != UserRole.SLACK_USER
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
@@ -777,13 +774,6 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
|
||||
@@ -7,7 +7,6 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -42,7 +41,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -62,7 +60,6 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.chat.chat_utils import extract_headers
|
||||
from onyx.chat.models import ChatFullResponse
|
||||
from onyx.chat.models import CreateChatSessionID
|
||||
from onyx.chat.process_message import gather_stream_full
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
from onyx.chat.process_message import handle_stream_message_objects
|
||||
from onyx.chat.prompt_utils import get_default_base_system_prompt
|
||||
from onyx.chat.stop_signal_checker import set_fence
|
||||
@@ -47,7 +46,6 @@ from onyx.db.chat import get_chat_messages_by_session
|
||||
from onyx.db.chat import get_chat_session_by_id
|
||||
from onyx.db.chat import get_chat_sessions_by_user
|
||||
from onyx.db.chat import set_as_latest_chat_message
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.db.chat import translate_db_message_to_chat_message_detail
|
||||
from onyx.db.chat import update_chat_session
|
||||
from onyx.db.chat_search import search_chat_sessions
|
||||
@@ -62,8 +60,6 @@ from onyx.db.persona import get_persona_by_id
|
||||
from onyx.db.usage import increment_usage
|
||||
from onyx.db.usage import UsageType
|
||||
from onyx.db.user_file import get_file_id_by_user_file_id
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.factory import get_default_llm
|
||||
@@ -85,7 +81,6 @@ from onyx.server.query_and_chat.models import ChatSessionUpdateRequest
|
||||
from onyx.server.query_and_chat.models import MessageOrigin
|
||||
from onyx.server.query_and_chat.models import RenameChatSessionResponse
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.models import SetPreferredResponseRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionTemperatureRequest
|
||||
from onyx.server.query_and_chat.models import UpdateChatSessionThreadRequest
|
||||
from onyx.server.query_and_chat.session_loading import (
|
||||
@@ -575,46 +570,6 @@ def handle_send_chat_message(
|
||||
if get_hashed_api_key_from_request(request) or get_hashed_pat_from_request(request):
|
||||
chat_message_req.origin = MessageOrigin.API
|
||||
|
||||
# Multi-model streaming path: 2-3 LLMs in parallel (streaming only)
|
||||
is_multi_model = (
|
||||
chat_message_req.llm_overrides is not None
|
||||
and len(chat_message_req.llm_overrides) > 1
|
||||
)
|
||||
if is_multi_model and chat_message_req.stream:
|
||||
# Narrowed here; is_multi_model already checked llm_overrides is not None
|
||||
llm_overrides = chat_message_req.llm_overrides or []
|
||||
|
||||
def multi_model_stream_generator() -> Generator[str, None, None]:
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for obj in handle_multi_model_stream(
|
||||
new_msg_req=chat_message_req,
|
||||
user=user,
|
||||
db_session=db_session,
|
||||
llm_overrides=llm_overrides,
|
||||
litellm_additional_headers=extract_headers(
|
||||
request.headers, LITELLM_PASS_THROUGH_HEADERS
|
||||
),
|
||||
custom_tool_additional_headers=get_custom_tool_additional_request_headers(
|
||||
request.headers
|
||||
),
|
||||
mcp_headers=chat_message_req.mcp_headers,
|
||||
):
|
||||
yield get_json_line(obj.model_dump())
|
||||
except Exception as e:
|
||||
logger.exception("Error in multi-model streaming")
|
||||
yield json.dumps({"error": str(e)})
|
||||
|
||||
return StreamingResponse(
|
||||
multi_model_stream_generator(), media_type="text/event-stream"
|
||||
)
|
||||
|
||||
if is_multi_model and not chat_message_req.stream:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INVALID_INPUT,
|
||||
"Multi-model mode (llm_overrides with >1 entry) requires stream=True.",
|
||||
)
|
||||
|
||||
# Non-streaming path: consume all packets and return complete response
|
||||
if not chat_message_req.stream:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -705,30 +660,6 @@ def set_message_as_latest(
|
||||
)
|
||||
|
||||
|
||||
@router.put("/set-preferred-response")
|
||||
def set_preferred_response_endpoint(
|
||||
request_body: SetPreferredResponseRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
"""Set the preferred assistant response for a multi-model turn."""
|
||||
try:
|
||||
# Ownership check: get_chat_message raises ValueError if the message
|
||||
# doesn't belong to this user, preventing cross-user mutation.
|
||||
get_chat_message(
|
||||
chat_message_id=request_body.user_message_id,
|
||||
user_id=user.id if user else None,
|
||||
db_session=db_session,
|
||||
)
|
||||
set_preferred_response(
|
||||
db_session=db_session,
|
||||
user_message_id=request_body.user_message_id,
|
||||
preferred_assistant_message_id=request_body.preferred_response_id,
|
||||
)
|
||||
except ValueError as e:
|
||||
raise OnyxError(OnyxErrorCode.INVALID_INPUT, str(e))
|
||||
|
||||
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
|
||||
@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.account_type.is_web_login():
|
||||
if not user.role.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -53,12 +52,7 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -74,8 +68,7 @@ def create_test_user(
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,29 +13,16 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -113,9 +100,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -188,9 +175,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -285,8 +272,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -370,7 +357,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -426,7 +413,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -465,8 +452,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -53,7 +52,6 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,7 +8,6 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -47,7 +46,6 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -126,15 +126,6 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,30 +104,13 @@ class UserGroupManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser,
|
||||
include_default: bool = False,
|
||||
) -> list[UserGroup]:
|
||||
params: dict[str, str] = {}
|
||||
if include_default:
|
||||
params["include_default"] = "true"
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@@ -9,7 +9,6 @@ This test verifies the full flow: provisioning failure → rollback → schema c
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy import text
|
||||
@@ -56,28 +55,18 @@ class TestTenantProvisioningRollback:
|
||||
created_tenant_id = tenant_id
|
||||
return create_schema_if_not_exists(tenant_id)
|
||||
|
||||
# Mock setup_tenant to fail after schema creation.
|
||||
# Also mock the Redis lock so the test doesn't compete with a live
|
||||
# monitoring worker that may already hold the provision lock.
|
||||
mock_lock = MagicMock()
|
||||
mock_lock.acquire.return_value = True
|
||||
|
||||
# Mock setup_tenant to fail after schema creation
|
||||
with patch(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.get_redis_client"
|
||||
) as mock_redis:
|
||||
mock_redis.return_value.lock.return_value = mock_lock
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.setup_tenant"
|
||||
) as mock_setup:
|
||||
mock_setup.side_effect = Exception("Simulated provisioning failure")
|
||||
|
||||
with patch(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.setup_tenant"
|
||||
) as mock_setup:
|
||||
mock_setup.side_effect = Exception("Simulated provisioning failure")
|
||||
|
||||
with patch(
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.create_schema_if_not_exists",
|
||||
side_effect=track_schema_creation,
|
||||
):
|
||||
# Run pre-provisioning - it should fail and trigger rollback
|
||||
pre_provision_tenant()
|
||||
"ee.onyx.background.celery.tasks.tenant_provisioning.tasks.create_schema_if_not_exists",
|
||||
side_effect=track_schema_creation,
|
||||
):
|
||||
# Run pre-provisioning - it should fail and trigger rollback
|
||||
pre_provision_tenant()
|
||||
|
||||
# Verify that the schema was created and then cleaned up
|
||||
assert created_tenant_id is not None, "Schema should have been created"
|
||||
|
||||
@@ -1,13 +1,9 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -37,120 +33,3 @@ def test_limited(reset: None) -> None: # noqa: ARG001
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def _get_service_account_account_type(
|
||||
admin_user: DATestUser,
|
||||
api_key_user_id: UUID,
|
||||
) -> AccountType:
|
||||
"""Fetch the account_type of a service account user via the user listing API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/users",
|
||||
headers=admin_user.headers,
|
||||
params={"include_api_keys": "true"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
user_id_str = str(api_key_user_id)
|
||||
for user in data["accepted"]:
|
||||
if user["id"] == user_id_str:
|
||||
return AccountType(user["account_type"])
|
||||
raise AssertionError(
|
||||
f"Service account user {user_id_str} not found in user listing"
|
||||
)
|
||||
|
||||
|
||||
def _get_default_group_user_ids(
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
admin_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_ids = {str(u.id) for u in basic_group.users}
|
||||
return admin_ids, basic_ids
|
||||
|
||||
|
||||
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify no group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "LIMITED API key should NOT be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "LIMITED API key should NOT be in Basic default group"
|
||||
|
||||
|
||||
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.BASIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Basic group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "BASIC API key should NOT be in Admin default group"
|
||||
|
||||
|
||||
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.ADMIN,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Admin group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "ADMIN API key should NOT be in Basic default group"
|
||||
|
||||
@@ -4,32 +4,11 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _simulate_saml_login(email: str, admin_user: DATestUser) -> dict:
|
||||
"""Simulate a SAML login by calling the test upsert endpoint."""
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
"""Get the set of emails of all members in the Basic default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
|
||||
assert basic_default, "Basic default group not found"
|
||||
return {u.email for u in basic_default[0].users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
@@ -70,9 +49,15 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login by calling the test endpoint
|
||||
user_data = _simulate_saml_login(test_user_email, admin_user)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_user_email},
|
||||
headers=admin_user.headers, # Use admin headers for authorization
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify user role was changed in the database
|
||||
@@ -97,237 +82,16 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
|
||||
|
||||
# Simulate SAML login again
|
||||
user_data = _simulate_saml_login(slack_user_email, admin_user)
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": slack_user_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD when converting a
|
||||
non-web user (EXT_PERM_USER) and that the user receives the correct role
|
||||
(BASIC) after conversion.
|
||||
|
||||
This validates the permissions-migration-phase2 changes which ensure that:
|
||||
1. account_type is updated to 'standard' on SAML conversion
|
||||
2. The converted user is assigned to the Basic default group
|
||||
"""
|
||||
# Create an admin user (first user is automatically admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# Create a user and set them as EXT_PERM_USER
|
||||
test_email = "ext_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type is set to standard after conversion
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
|
||||
|
||||
# Verify role is BASIC after conversion
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user was assigned to the Basic default group
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_normal_signin_assigns_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that a brand-new user signing in via SAML for the first time
|
||||
is created with the correct role, account_type, and group membership.
|
||||
|
||||
This validates that normal SAML sign-in (not an upgrade from
|
||||
SLACK_USER/EXT_PERM_USER) correctly:
|
||||
1. Creates the user with role=BASIC and account_type=STANDARD
|
||||
2. Assigns the user to the Basic default group
|
||||
"""
|
||||
# First user becomes admin
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# New user signs in via SAML (no prior account)
|
||||
new_email = "new_saml_user@example.com"
|
||||
user_data = _simulate_saml_login(new_email, admin_user)
|
||||
|
||||
# Verify role and account_type
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert user_data["account_type"] == AccountType.STANDARD.value
|
||||
|
||||
# Verify user is in the Basic default group
|
||||
assert new_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"New SAML user '{new_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_restores_group_membership(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login restores Basic group membership when converting
|
||||
a non-authenticated user (EXT_PERM_USER or SLACK_USER) to BASIC.
|
||||
|
||||
Group membership implies 'basic' permission (verified by
|
||||
test_new_group_gets_basic_permission).
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perm_perms@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
assert ext_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(ext_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert ext_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "EXT_PERM_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(slack_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert slack_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "SLACK_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_round_trip_group_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test the full round-trip: BASIC -> EXT_PERM -> SAML(BASIC) -> EXT_PERM -> SAML(BASIC).
|
||||
|
||||
Verifies group membership is correctly removed and restored at each transition.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "roundtrip@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
# Step 1: BASIC user is in Basic group
|
||||
assert test_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 2: Downgrade to EXT_PERM_USER — loses Basic group
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 3: SAML login — converts back to BASIC, regains Basic group
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after first SAML conversion"
|
||||
|
||||
# Step 4: Downgrade again
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 5: SAML login again — should still restore correctly
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after second SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_slack_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD and assigns Basic group
|
||||
when converting a SLACK_USER (BOT account_type).
|
||||
|
||||
Mirrors test_saml_user_conversion_sets_account_type_and_group but for
|
||||
SLACK_USER instead of EXT_PERM_USER, and additionally verifies permissions.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "slack_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.SLACK_USER)
|
||||
|
||||
# SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type and role
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected STANDARD, got {user_data['account_type']}"
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify Basic group membership (implies 'basic' permission)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted SLACK_USER '{test_email}' not found in Basic default group"
|
||||
|
||||
@@ -1,82 +0,0 @@
|
||||
"""Integration tests for permission propagation across auth-triggered group changes.
|
||||
|
||||
These tests verify that effective permissions (via /me/permissions) actually
|
||||
propagate when users are added/removed from default groups through role changes.
|
||||
Custom permission grant tests will be added once the permission grant API is built.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.is_default and g.name == "Basic"), None
|
||||
)
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
return {u.email for u in basic_group.users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_basic_permission_granted_on_registration(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""New users should get 'basic' permission through default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
basic_user: DATestUser = UserManager.create(email="basic@example.com")
|
||||
|
||||
# Admin should have permissions from Admin group
|
||||
admin_perms = UserManager.get_permissions(admin_user)
|
||||
assert "basic" in admin_perms
|
||||
|
||||
# Basic user should have 'basic' from Basic default group
|
||||
basic_perms = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in basic_perms
|
||||
|
||||
# Verify group membership matches
|
||||
assert basic_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_role_downgrade_removes_basic_group_and_permission(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Downgrading to EXT_PERM_USER or SLACK_USER should remove from Basic group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER ---
|
||||
ext_user: DATestUser = UserManager.create(email="ext@example.com")
|
||||
assert ext_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# --- SLACK_USER ---
|
||||
slack_user: DATestUser = UserManager.create(email="slack@example.com")
|
||||
assert slack_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
@@ -21,15 +21,8 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
@@ -51,6 +44,13 @@ def scim_token(idp_style: str) -> str:
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
@@ -550,145 +550,3 @@ def test_patch_add_duplicate_member_is_idempotent(
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["members"]) == 1 # still just one member
|
||||
|
||||
|
||||
def test_create_group_reserved_name_admin(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Admin' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_create_group_reserved_name_basic(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Basic' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_to_reserved(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""PUT /Groups/{id} renaming a group to 'Admin' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Rename To Reserved {idp_style}",
|
||||
external_id=f"ext-rtr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="Admin", external_id=f"ext-rtr-{idp_style}"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a group to 'Basic' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Rename Reserved {idp_style}",
|
||||
external_id=f"ext-prr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "Basic"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_delete_reserved_group_rejected(scim_token: str) -> None:
|
||||
"""DELETE /Groups/{id} on a reserved group ('Admin') returns 409."""
|
||||
# Look up the reserved 'Admin' group via SCIM filter
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_scim_created_group_has_basic_permission(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""POST /Groups assigns the 'basic' permission to the group itself."""
|
||||
# Create a SCIM group (no members needed — we check the group's permissions)
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Basic Perm Group {idp_style}",
|
||||
external_id=f"ext-basic-perm-{idp_style}",
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
group_id = resp.json()["id"]
|
||||
|
||||
# Log in as the admin user (created by the scim_token fixture).
|
||||
admin = DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
admin = UserManager.login_as_user(admin)
|
||||
|
||||
# Verify the group itself was granted the basic permission
|
||||
perms_resp = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions",
|
||||
headers=admin.headers,
|
||||
)
|
||||
perms_resp.raise_for_status()
|
||||
perms = perms_resp.json()
|
||||
assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}"
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None:
|
||||
"""PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="RenamedAdmin", external_id="ext-rename-from-reserved"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
@@ -35,16 +35,9 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
@@ -218,49 +211,6 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_create_user_default_group_and_account_type(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
|
||||
email = f"scim_defaults_{idp_style}@example.com"
|
||||
ext_id = f"ext-defaults-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
user_id = resp.json()["id"]
|
||||
|
||||
# --- Verify group assignment via SCIM GET ---
|
||||
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
|
||||
assert get_resp.status_code == 200
|
||||
groups = get_resp.json().get("groups", [])
|
||||
group_names = {g["display"] for g in groups}
|
||||
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
|
||||
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
|
||||
|
||||
# --- Verify account_type via admin API ---
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
page = UserManager.get_user_page(
|
||||
user_performing_action=admin,
|
||||
search_query=email,
|
||||
)
|
||||
assert page.total_items >= 1
|
||||
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
|
||||
assert (
|
||||
scim_user_snapshot is not None
|
||||
), f"SCIM user {email} not found in user listing"
|
||||
assert (
|
||||
scim_user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
|
||||
@@ -1,118 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_gets_permissions_when_added_to_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
|
||||
|
||||
# basic_user starts with only "basic" from the default group
|
||||
initial_permissions = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in initial_permissions
|
||||
assert "add:agents" not in initial_permissions
|
||||
|
||||
# Create a new group and add basic_user
|
||||
group = UserGroupManager.create(
|
||||
name="perm-test-group",
|
||||
user_ids=[admin_user.id, basic_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Grant a non-basic permission to the group and recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_group = db_session.get(UserGroupModel, group.id)
|
||||
assert db_group is not None
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
recompute_user_permissions__no_commit(basic_user.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the user gained the new permission (expanded includes read:agents)
|
||||
updated_permissions = UserManager.get_permissions(basic_user)
|
||||
assert (
|
||||
"add:agents" in updated_permissions
|
||||
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
|
||||
assert (
|
||||
"read:agents" in updated_permissions
|
||||
), f"User should have implied 'read:agents', got: {updated_permissions}"
|
||||
assert "basic" in updated_permissions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_group_permission_change_propagates_to_all_members(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_propagate")
|
||||
user_a: DATestUser = UserManager.create(name="user_a_propagate")
|
||||
user_b: DATestUser = UserManager.create(name="user_b_propagate")
|
||||
|
||||
group = UserGroupManager.create(
|
||||
name="propagate-test-group",
|
||||
user_ids=[admin_user.id, user_a.id, user_b.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Neither user should have add:agents yet
|
||||
for u in (user_a, user_b):
|
||||
assert "add:agents" not in UserManager.get_permissions(u)
|
||||
|
||||
# Grant add:agents to the group, then batch-recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
grant = PermissionGrant(
|
||||
group_id=group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
db_session.add(grant)
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Both users should now have the permission (plus implied read:agents)
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
|
||||
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
|
||||
|
||||
# Soft-delete the grant and recompute — permission should be removed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_grant = (
|
||||
db_session.query(PermissionGrant)
|
||||
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
|
||||
.first()
|
||||
)
|
||||
assert db_grant is not None
|
||||
db_grant.is_deleted = True
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
|
||||
@@ -1,30 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
|
||||
|
||||
user_group = UserGroupManager.create(
|
||||
name="basic-perm-test-group",
|
||||
user_ids=[admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
permissions = UserGroupManager.get_permissions(
|
||||
user_group=user_group,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
"basic" in permissions
|
||||
), f"New group should have 'basic' permission, got: {permissions}"
|
||||
@@ -1,78 +0,0 @@
|
||||
"""Integration tests for default group assignment on user registration.
|
||||
|
||||
Verifies that:
|
||||
- The first registered user is assigned to the Admin default group
|
||||
- Subsequent registered users are assigned to the Basic default group
|
||||
- account_type is set to STANDARD for email/password registrations
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
|
||||
# Register first user — should become admin
|
||||
admin_user: DATestUser = UserManager.create(name="first_user")
|
||||
assert admin_user.role == UserRole.ADMIN
|
||||
|
||||
# Register second user — should become basic
|
||||
basic_user: DATestUser = UserManager.create(name="second_user")
|
||||
assert basic_user.role == UserRole.BASIC
|
||||
|
||||
# Fetch all groups including default ones
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
|
||||
# Find the default Admin and Basic groups
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
# Verify admin user is in Admin group and NOT in Basic group
|
||||
admin_group_user_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_group_user_ids = {str(u.id) for u in basic_group.users}
|
||||
|
||||
assert (
|
||||
admin_user.id in admin_group_user_ids
|
||||
), "First user should be in Admin default group"
|
||||
assert (
|
||||
admin_user.id not in basic_group_user_ids
|
||||
), "First user should NOT be in Basic default group"
|
||||
|
||||
# Verify basic user is in Basic group and NOT in Admin group
|
||||
assert (
|
||||
basic_user.id in basic_group_user_ids
|
||||
), "Second user should be in Basic default group"
|
||||
assert (
|
||||
basic_user.id not in admin_group_user_ids
|
||||
), "Second user should NOT be in Admin default group"
|
||||
|
||||
# Verify account_type is STANDARD for both users via user listing API
|
||||
paginated_result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
users_by_id = {str(u.id): u for u in paginated_result.items}
|
||||
|
||||
admin_snapshot = users_by_id.get(admin_user.id)
|
||||
basic_snapshot = users_by_id.get(basic_user.id)
|
||||
assert admin_snapshot is not None, "Admin user not found in user listing"
|
||||
assert basic_snapshot is not None, "Basic user not found in user listing"
|
||||
|
||||
assert (
|
||||
admin_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
|
||||
assert (
|
||||
basic_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
|
||||
@@ -1,135 +0,0 @@
|
||||
"""Integration tests for password signup upgrade paths.
|
||||
|
||||
Verifies that when a BOT or EXT_PERM_USER user signs up via email/password:
|
||||
- Their account_type is upgraded to STANDARD
|
||||
- They are assigned to the Basic default group
|
||||
- They gain the correct effective permissions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"target_role",
|
||||
[UserRole.EXT_PERM_USER, UserRole.SLACK_USER],
|
||||
ids=["ext_perm_user", "slack_user"],
|
||||
)
|
||||
def test_password_signup_upgrade(
|
||||
reset: None, # noqa: ARG001
|
||||
target_role: UserRole,
|
||||
) -> None:
|
||||
"""When a non-web user signs up via email/password, they should be
|
||||
upgraded to STANDARD account_type and assigned to the Basic default group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = f"{target_role.value}_upgrade@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
test_user = UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=target_role,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Verify user was removed from Basic group after downgrade
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email not in basic_emails
|
||||
), f"{target_role.value} should not be in Basic default group"
|
||||
|
||||
# Re-register with the same email — triggers the password signup upgrade
|
||||
upgraded_user = UserManager.create(email=test_email)
|
||||
|
||||
assert upgraded_user.role == UserRole.BASIC
|
||||
|
||||
paginated = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
user_snapshot = next(
|
||||
(u for u in paginated.items if str(u.id) == upgraded_user.id), None
|
||||
)
|
||||
assert user_snapshot is not None
|
||||
assert (
|
||||
user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {user_snapshot.account_type}"
|
||||
|
||||
# Verify user is now in the Basic default group
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email in basic_emails
|
||||
), f"Upgraded user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
def test_password_signup_upgrade_propagates_permissions(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER or SLACK_USER signs up via password, they should
|
||||
gain the 'basic' permission through the Basic default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perms_check@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
|
||||
initial_perms = UserManager.get_permissions(ext_user)
|
||||
assert "basic" in initial_perms
|
||||
|
||||
ext_user = UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert ext_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=ext_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms_check@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
slack_user = UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert slack_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=slack_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}"
|
||||
@@ -1,54 +0,0 @@
|
||||
"""Integration tests for default group reconciliation on user reactivation.
|
||||
|
||||
Verifies that:
|
||||
- A deactivated user retains default group membership after reactivation
|
||||
- Reactivation via the admin API reconciles missing group membership
|
||||
"""
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
def test_reactivated_user_retains_default_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deactivating and reactivating a user should preserve their
|
||||
default group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user")
|
||||
|
||||
# Verify user is in Basic group initially
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert basic_user.email in basic_emails
|
||||
|
||||
# Deactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Reactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify user is still in Basic group after reactivation
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
basic_user.email in basic_emails
|
||||
), "Reactivated user should still be in Basic default group"
|
||||
@@ -1,58 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.configs.constants import MASK_CREDENTIAL_CHAR
|
||||
from onyx.db.federated import _reject_masked_credentials
|
||||
|
||||
|
||||
class TestRejectMaskedCredentials:
|
||||
"""Verify that masked credential values are never accepted for DB writes.
|
||||
|
||||
mask_string() has two output formats:
|
||||
- Short strings (< 14 chars): "••••••••••••" (U+2022 BULLET)
|
||||
- Long strings (>= 14 chars): "abcd...wxyz" (first4 + "..." + last4)
|
||||
_reject_masked_credentials must catch both.
|
||||
"""
|
||||
|
||||
def test_rejects_fully_masked_value(self) -> None:
|
||||
masked = MASK_CREDENTIAL_CHAR * 12 # "••••••••••••"
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": masked})
|
||||
|
||||
def test_rejects_long_string_masked_value(self) -> None:
|
||||
"""mask_string returns 'first4...last4' for long strings — the real
|
||||
format used for OAuth credentials like client_id and client_secret."""
|
||||
with pytest.raises(ValueError, match="masked placeholder"):
|
||||
_reject_masked_credentials({"client_id": "1234...7890"})
|
||||
|
||||
def test_rejects_when_any_field_is_masked(self) -> None:
|
||||
"""Even if client_id is real, a masked client_secret must be caught."""
|
||||
with pytest.raises(ValueError, match="client_secret"):
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": MASK_CREDENTIAL_CHAR * 12,
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_real_credentials(self) -> None:
|
||||
# Should not raise
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "1234567890.1234567890",
|
||||
"client_secret": "test_client_secret_value",
|
||||
}
|
||||
)
|
||||
|
||||
def test_accepts_empty_dict(self) -> None:
|
||||
# Should not raise — empty credentials are handled elsewhere
|
||||
_reject_masked_credentials({})
|
||||
|
||||
def test_ignores_non_string_values(self) -> None:
|
||||
# Non-string values (None, bool, int) should pass through
|
||||
_reject_masked_credentials(
|
||||
{
|
||||
"client_id": "real_value",
|
||||
"redirect_uri": None,
|
||||
"some_flag": True,
|
||||
}
|
||||
)
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.permissions import ALL_PERMISSIONS
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.permissions import resolve_effective_permissions
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_effective_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveEffectivePermissions:
|
||||
def test_empty_set(self) -> None:
|
||||
assert resolve_effective_permissions(set()) == set()
|
||||
|
||||
def test_basic_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"basic"})
|
||||
assert result == {"basic"}
|
||||
|
||||
def test_single_implication(self) -> None:
|
||||
result = resolve_effective_permissions({"add:agents"})
|
||||
assert result == {"add:agents", "read:agents"}
|
||||
|
||||
def test_manage_agents_implies_add_and_read(self) -> None:
|
||||
"""manage:agents directly maps to {add:agents, read:agents}."""
|
||||
result = resolve_effective_permissions({"manage:agents"})
|
||||
assert result == {"manage:agents", "add:agents", "read:agents"}
|
||||
|
||||
def test_manage_connectors_chain(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:connectors"})
|
||||
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
|
||||
|
||||
def test_manage_document_sets(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:document_sets"})
|
||||
assert result == {
|
||||
"manage:document_sets",
|
||||
"read:document_sets",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_manage_user_groups_implies_all_reads(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:user_groups"})
|
||||
assert result == {
|
||||
"manage:user_groups",
|
||||
"read:connectors",
|
||||
"read:document_sets",
|
||||
"read:agents",
|
||||
"read:users",
|
||||
}
|
||||
|
||||
def test_admin_override(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_admin_with_others(self) -> None:
|
||||
result = resolve_effective_permissions({"admin", "basic"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_multi_group_union(self) -> None:
|
||||
result = resolve_effective_permissions(
|
||||
{"add:agents", "manage:connectors", "basic"}
|
||||
)
|
||||
assert result == {
|
||||
"basic",
|
||||
"add:agents",
|
||||
"read:agents",
|
||||
"manage:connectors",
|
||||
"add:connectors",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_toggle_permission_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"read:agent_analytics"})
|
||||
assert result == {"read:agent_analytics"}
|
||||
|
||||
def test_all_permissions_for_admin(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert len(result) == len(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_permissions (expands implied at read time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEffectivePermissions:
|
||||
def test_expands_implied_permissions(self) -> None:
|
||||
"""Column stores only granted; get_effective_permissions expands implied."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["add:agents"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
|
||||
|
||||
def test_admin_expands_to_all(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set(Permission)
|
||||
|
||||
def test_basic_stays_basic(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.BASIC_ACCESS}
|
||||
|
||||
def test_empty_column(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_permission (FastAPI dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass(self) -> None:
|
||||
"""Admin stored in column should pass any permission check."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_required_permission(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implied_permission_passes(self) -> None:
|
||||
"""manage:connectors implies read:connectors at read time."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.READ_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_permission_raises(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await dep(user=user)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_permissions_fails(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
|
||||
dep = require_permission(Permission.BASIC_ACCESS)
|
||||
with pytest.raises(OnyxError):
|
||||
await dep(user=user)
|
||||
@@ -1,29 +0,0 @@
|
||||
"""
|
||||
Unit tests for UserCreate schema dict methods.
|
||||
|
||||
Verifies that account_type is always included in create_update_dict
|
||||
and create_update_dict_superuser.
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
def test_create_update_dict_includes_default_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_includes_explicit_account_type() -> None:
|
||||
uc = UserCreate(
|
||||
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
|
||||
)
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_superuser_includes_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict_superuser()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
@@ -1,754 +0,0 @@
|
||||
"""Unit tests for multi-model streaming validation and DB helpers.
|
||||
|
||||
These are pure unit tests — no real database or LLM calls required.
|
||||
The validation logic in handle_multi_model_stream fires before any external
|
||||
calls, so we can trigger it with lightweight mocks.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import set_preferred_response
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.server.query_and_chat.models import SendMessageRequest
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import OverallStop
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import ReasoningStart
|
||||
from onyx.utils.variable_functionality import global_version
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _restore_ee_version() -> Generator[None, None, None]:
|
||||
"""Reset EE global state after each test.
|
||||
|
||||
Importing onyx.chat.process_message triggers set_is_ee_based_on_env_variable()
|
||||
(via the celery client import chain). Without this fixture, the EE flag stays
|
||||
True for the rest of the session and breaks unrelated tests that mock Confluence
|
||||
or other connectors and assume EE is disabled.
|
||||
"""
|
||||
original = global_version._is_ee
|
||||
yield
|
||||
global_version._is_ee = original
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_request(**kwargs: Any) -> SendMessageRequest:
|
||||
defaults: dict[str, Any] = {
|
||||
"message": "hello",
|
||||
"chat_session_id": uuid4(),
|
||||
}
|
||||
defaults.update(kwargs)
|
||||
return SendMessageRequest(**defaults)
|
||||
|
||||
|
||||
def _make_override(provider: str = "openai", version: str = "gpt-4") -> LLMOverride:
|
||||
return LLMOverride(model_provider=provider, model_version=version)
|
||||
|
||||
|
||||
def _first_from_stream(req: SendMessageRequest, overrides: list[LLMOverride]) -> Any:
|
||||
"""Return the first item yielded by handle_multi_model_stream."""
|
||||
from onyx.chat.process_message import handle_multi_model_stream
|
||||
|
||||
user = MagicMock()
|
||||
user.is_anonymous = False
|
||||
user.email = "test@example.com"
|
||||
db = MagicMock()
|
||||
|
||||
gen = handle_multi_model_stream(req, user, db, overrides)
|
||||
return next(gen)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# handle_multi_model_stream — validation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRunMultiModelStreamValidation:
|
||||
def test_single_override_yields_error(self) -> None:
|
||||
"""Exactly 1 override is not multi-model — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_four_overrides_yields_error(self) -> None:
|
||||
"""4 overrides exceeds maximum — yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(
|
||||
req,
|
||||
[
|
||||
_make_override("openai", "gpt-4"),
|
||||
_make_override("anthropic", "claude-3"),
|
||||
_make_override("google", "gemini-pro"),
|
||||
_make_override("cohere", "command-r"),
|
||||
],
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_zero_overrides_yields_error(self) -> None:
|
||||
"""Empty override list yields StreamingError."""
|
||||
req = _make_request()
|
||||
result = _first_from_stream(req, [])
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "2-3" in result.error
|
||||
|
||||
def test_deep_research_yields_error(self) -> None:
|
||||
"""deep_research=True is incompatible with multi-model — yields StreamingError."""
|
||||
req = _make_request(deep_research=True)
|
||||
result = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
assert isinstance(result, StreamingError)
|
||||
assert "not supported" in result.error
|
||||
|
||||
def test_exactly_two_overrides_is_minimum(self) -> None:
|
||||
"""Boundary: 1 override yields error, 2 overrides passes validation."""
|
||||
req = _make_request()
|
||||
# 1 override must yield a StreamingError
|
||||
result = _first_from_stream(req, [_make_override()])
|
||||
assert isinstance(
|
||||
result, StreamingError
|
||||
), "1 override should yield StreamingError"
|
||||
# 2 overrides must NOT yield a validation StreamingError (may raise later due to
|
||||
# missing session, that's OK — validation itself passed)
|
||||
try:
|
||||
result2 = _first_from_stream(
|
||||
req, [_make_override(), _make_override("anthropic", "claude-3")]
|
||||
)
|
||||
if isinstance(result2, StreamingError) and "2-3" in result2.error:
|
||||
pytest.fail(
|
||||
f"2 overrides should pass validation, got StreamingError: {result2.error}"
|
||||
)
|
||||
except Exception:
|
||||
pass # Any non-validation error means validation passed
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# set_preferred_response — validation (mocked db)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSetPreferredResponseValidation:
|
||||
def test_user_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
db.get.return_value = None
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=999, preferred_assistant_message_id=1
|
||||
)
|
||||
|
||||
def test_wrong_message_type(self) -> None:
|
||||
"""Cannot set preferred response on a non-USER message."""
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.ASSISTANT # wrong type
|
||||
|
||||
db.get.return_value = user_msg
|
||||
|
||||
with pytest.raises(ValueError, match="not a user message"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_message_not_found(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
# First call returns user_msg, second call (for assistant) returns None
|
||||
db.get.side_effect = [user_msg, None]
|
||||
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_assistant_not_child_of_user(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 999 # different parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
with pytest.raises(ValueError, match="not a child"):
|
||||
set_preferred_response(
|
||||
db, user_message_id=1, preferred_assistant_message_id=2
|
||||
)
|
||||
|
||||
def test_valid_call_sets_preferred_response_id(self) -> None:
|
||||
db = MagicMock()
|
||||
user_msg = MagicMock()
|
||||
user_msg.message_type = MessageType.USER
|
||||
|
||||
assistant_msg = MagicMock()
|
||||
assistant_msg.parent_message_id = 1 # correct parent
|
||||
|
||||
db.get.side_effect = [user_msg, assistant_msg]
|
||||
|
||||
set_preferred_response(db, user_message_id=1, preferred_assistant_message_id=2)
|
||||
|
||||
assert user_msg.preferred_response_id == 2
|
||||
assert user_msg.latest_child_message_id == 2
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# LLMOverride — display_name field
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLLMOverrideDisplayName:
|
||||
def test_display_name_defaults_none(self) -> None:
|
||||
override = LLMOverride(model_provider="openai", model_version="gpt-4")
|
||||
assert override.display_name is None
|
||||
|
||||
def test_display_name_set(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="openai",
|
||||
model_version="gpt-4",
|
||||
display_name="GPT-4 Turbo",
|
||||
)
|
||||
assert override.display_name == "GPT-4 Turbo"
|
||||
|
||||
def test_display_name_serializes(self) -> None:
|
||||
override = LLMOverride(
|
||||
model_provider="anthropic",
|
||||
model_version="claude-opus-4-6",
|
||||
display_name="Claude Opus",
|
||||
)
|
||||
d = override.model_dump()
|
||||
assert d["display_name"] == "Claude Opus"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# _run_models — drain loop behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_setup(n_models: int = 1) -> MagicMock:
|
||||
"""Minimal ChatTurnSetup mock whose fields pass Pydantic validation in _run_model."""
|
||||
setup = MagicMock()
|
||||
setup.llms = [MagicMock() for _ in range(n_models)]
|
||||
setup.model_display_names = [f"model-{i}" for i in range(n_models)]
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
setup.reserved_messages = [MagicMock() for _ in range(n_models)]
|
||||
setup.reserved_token_count = 100
|
||||
# Fields consumed by SearchToolConfig / CustomToolConfig / FileReaderToolConfig
|
||||
# constructors inside _run_model — must be typed correctly for Pydantic.
|
||||
setup.new_msg_req.deep_research = False
|
||||
setup.new_msg_req.internal_search_filters = None
|
||||
setup.new_msg_req.allowed_tool_ids = None
|
||||
setup.new_msg_req.include_citations = True
|
||||
setup.search_params.project_id_filter = None
|
||||
setup.search_params.persona_id_filter = None
|
||||
setup.bypass_acl = False
|
||||
setup.slack_context = None
|
||||
setup.available_files.user_file_ids = []
|
||||
setup.available_files.chat_file_ids = []
|
||||
setup.forced_tool_id = None
|
||||
setup.simple_chat_history = []
|
||||
setup.chat_session.id = uuid4()
|
||||
setup.user_message.id = None
|
||||
setup.custom_tool_additional_headers = None
|
||||
setup.mcp_headers = None
|
||||
return setup
|
||||
|
||||
|
||||
def _run_models_collect(setup: MagicMock) -> list:
|
||||
"""Drive _run_models to completion and return all yielded items."""
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
return list(_run_models(setup, MagicMock(), MagicMock()))
|
||||
|
||||
|
||||
class TestRunModels:
|
||||
"""Tests for the _run_models worker-thread drain loop.
|
||||
|
||||
All external dependencies (LLM, DB, tools) are patched out. Worker threads
|
||||
still run but return immediately since run_llm_loop is mocked.
|
||||
"""
|
||||
|
||||
def test_n1_overall_stop_from_llm_loop_passes_through(self) -> None:
|
||||
"""OverallStop emitted by run_llm_loop is passed through the drain loop unchanged."""
|
||||
|
||||
def emit_stop(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(
|
||||
placement=Placement(turn_index=0),
|
||||
obj=OverallStop(stop_reason="complete"),
|
||||
)
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_stop),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert len(stops) == 1
|
||||
stop_obj = stops[0].obj
|
||||
assert isinstance(stop_obj, OverallStop)
|
||||
assert stop_obj.stop_reason == "complete"
|
||||
|
||||
def test_n1_emitted_packet_has_model_index_zero(self) -> None:
|
||||
"""Single-model path: model_index is 0 (Emitter defaults model_idx=0)."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 0
|
||||
|
||||
def test_n2_each_model_packet_tagged_with_its_index(self) -> None:
|
||||
"""Multi-model path: packets from model 0 get index=0, model 1 gets index=1."""
|
||||
|
||||
def emit_one(**kwargs: Any) -> None:
|
||||
# _model_idx is set by _run_model based on position in setup.llms
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=emit_one),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=2))
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 2
|
||||
indices = {p.placement.model_index for p in reasoning}
|
||||
assert indices == {0, 1}
|
||||
|
||||
def test_model_error_yields_streaming_error(self) -> None:
|
||||
"""An exception inside a worker thread is surfaced as a StreamingError."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("intentional test failure")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
assert errors[0].error_code == "MODEL_ERROR"
|
||||
assert "intentional test failure" in errors[0].error
|
||||
|
||||
def test_one_model_error_does_not_stop_other_models(self) -> None:
|
||||
"""A failing model yields StreamingError; the surviving model's packets still arrive."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
def fail_model_0_succeed_model_1(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 failed")
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=fail_model_0_succeed_model_1,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
errors = [p for p in packets if isinstance(p, StreamingError)]
|
||||
assert len(errors) == 1
|
||||
|
||||
reasoning = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, ReasoningStart)
|
||||
]
|
||||
assert len(reasoning) == 1
|
||||
assert reasoning[0].placement.model_index == 1
|
||||
|
||||
def test_cancellation_yields_user_cancelled_stop(self) -> None:
|
||||
"""If check_is_connected returns False, drain loop emits user_cancelled."""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3) # Outlasts the 50 ms queue-poll interval
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
packets = _run_models_collect(setup)
|
||||
|
||||
stops = [
|
||||
p
|
||||
for p in packets
|
||||
if isinstance(p, Packet) and isinstance(p.obj, OverallStop)
|
||||
]
|
||||
assert any(
|
||||
isinstance(s.obj, OverallStop) and s.obj.stop_reason == "user_cancelled"
|
||||
for s in stops
|
||||
)
|
||||
|
||||
def test_stop_button_calls_completion_for_all_models(self) -> None:
|
||||
"""llm_loop_completion_handle must be called for all models when the stop button fires.
|
||||
|
||||
Regression test for the disconnect-cleanup bug: the old
|
||||
run_chat_loop_with_state_containers always called completion_callback in
|
||||
its finally block (even on disconnect) so the DB message was updated from
|
||||
the TERMINATED placeholder to a partial answer. The new _run_models must
|
||||
replicate this — otherwise the integration test
|
||||
test_send_message_disconnect_and_cleanup fails because the message stays
|
||||
as "Response was terminated prior to completion, try regenerating."
|
||||
"""
|
||||
|
||||
def slow_llm(**_kwargs: Any) -> None:
|
||||
time.sleep(0.3)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
setup.check_is_connected = MagicMock(return_value=False)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=slow_llm),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Must be called once per model, not zero times
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_called_for_each_successful_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be called once per model that succeeded."""
|
||||
setup = _make_setup(n_models=2)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop"),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
assert mock_handle.call_count == 2
|
||||
|
||||
def test_completion_handle_not_called_for_failed_model(self) -> None:
|
||||
"""llm_loop_completion_handle must be skipped for a model that raised."""
|
||||
|
||||
def always_fail(**_kwargs: Any) -> None:
|
||||
raise RuntimeError("fail")
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=always_fail),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(_make_setup(n_models=1))
|
||||
|
||||
mock_handle.assert_not_called()
|
||||
|
||||
def test_http_disconnect_completion_via_generator_exit(self) -> None:
|
||||
"""GeneratorExit from HTTP disconnect triggers main-thread completion.
|
||||
|
||||
When the HTTP client closes the connection, Starlette throws GeneratorExit
|
||||
into the stream generator. The finally block sets drain_done (signalling
|
||||
emitters to stop blocking), waits for workers via executor.shutdown(wait=True),
|
||||
then calls llm_loop_completion_handle for each successful model from the main
|
||||
thread.
|
||||
|
||||
This is the primary regression for test_send_message_disconnect_and_cleanup:
|
||||
the integration test disconnects mid-stream and expects the DB message to be
|
||||
updated from the TERMINATED placeholder to the real response.
|
||||
"""
|
||||
import threading
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_then_block_until_drain(**kwargs: Any) -> None:
|
||||
"""Emit one packet (to give the drain loop a yield point), then block
|
||||
until drain_done is set — simulating a mid-stream LLM call that exits
|
||||
promptly once the emitter signals shutdown.
|
||||
"""
|
||||
emitter = kwargs["emitter"]
|
||||
emitter.emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
# Block until drain_done is set by gen.close(). The Emitter's _drain_done
|
||||
# is the same Event that _run_models sets, so this unblocks promptly.
|
||||
emitter._drain_done.wait(timeout=5)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
# is_connected() always True — HTTP disconnect does NOT set the Redis stop fence.
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_then_block_until_drain,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
# Simulate Starlette closing the stream on HTTP client disconnect.
|
||||
# gen.close() → GeneratorExit → finally → drain_done.set() →
|
||||
# executor.shutdown(wait=True) → main thread completes models.
|
||||
gen.close()
|
||||
|
||||
assert (
|
||||
completion_called.is_set()
|
||||
), "main thread must call completion for the successful model"
|
||||
assert mock_handle.call_count == 1
|
||||
|
||||
def test_b1_race_disconnect_handler_completes_already_finished_model(self) -> None:
|
||||
"""B1 regression: model finishes BEFORE GeneratorExit fires.
|
||||
|
||||
The worker exits _run_model before drain_done is set. When gen.close()
|
||||
fires afterward, the finally block sets drain_done, waits for workers
|
||||
(already done), then the main thread calls llm_loop_completion_handle.
|
||||
|
||||
Contrast with test_http_disconnect_completion_via_generator_exit, which
|
||||
tests the opposite ordering (worker finishes AFTER disconnect).
|
||||
"""
|
||||
import threading
|
||||
import time
|
||||
|
||||
completion_called = threading.Event()
|
||||
|
||||
def emit_and_return_immediately(**kwargs: Any) -> None:
|
||||
# Emit one packet so the drain loop has something to yield, then return
|
||||
# immediately — no blocking. The worker will be done in microseconds.
|
||||
kwargs["emitter"].emit(
|
||||
Packet(placement=Placement(turn_index=0), obj=ReasoningStart())
|
||||
)
|
||||
|
||||
setup = _make_setup(n_models=1)
|
||||
setup.check_is_connected = MagicMock(return_value=True)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.chat.process_message.run_llm_loop",
|
||||
side_effect=emit_and_return_immediately,
|
||||
),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle",
|
||||
side_effect=lambda *_, **__: completion_called.set(),
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
gen = cast(Generator, _run_models(setup, MagicMock(), MagicMock()))
|
||||
first = next(gen)
|
||||
assert isinstance(first, Packet)
|
||||
|
||||
# Give the worker thread time to finish completely (emit + return +
|
||||
# finally + self-completion check). It does almost no work, so 100 ms
|
||||
# is far more than enough while still keeping the test fast.
|
||||
time.sleep(0.1)
|
||||
|
||||
# Now close — worker is already done, so else-branch handles completion.
|
||||
gen.close()
|
||||
|
||||
assert completion_called.wait(
|
||||
timeout=5
|
||||
), "disconnect handler must call completion for a model that already finished"
|
||||
assert mock_handle.call_count == 1, "completion must be called exactly once"
|
||||
|
||||
def test_stop_button_does_not_call_completion_for_errored_model(self) -> None:
|
||||
"""B2 regression: stop-button must NOT call completion for an errored model.
|
||||
|
||||
When model 0 raises an exception, its reserved ChatMessage must not be
|
||||
saved with 'stopped by user' — that message is wrong for a model that
|
||||
errored. llm_loop_completion_handle must only be called for non-errored
|
||||
models when the stop button fires.
|
||||
"""
|
||||
|
||||
def fail_model_0(**kwargs: Any) -> None:
|
||||
if kwargs["llm"] is setup.llms[0]:
|
||||
raise RuntimeError("model 0 errored")
|
||||
# Model 1: run forever (stop button fires before it finishes)
|
||||
time.sleep(10)
|
||||
|
||||
setup = _make_setup(n_models=2)
|
||||
# Return False immediately so the stop-button path fires while model 1
|
||||
# is still sleeping (model 0 has already errored by then).
|
||||
setup.check_is_connected = lambda: False
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop", side_effect=fail_model_0),
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch(
|
||||
"onyx.chat.process_message.llm_loop_completion_handle"
|
||||
) as mock_handle,
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
_run_models_collect(setup)
|
||||
|
||||
# Completion must NOT be called for model 0 (it errored).
|
||||
# It MAY be called for model 1 (still in-flight when stop fired).
|
||||
for call in mock_handle.call_args_list:
|
||||
assert (
|
||||
call.kwargs.get("llm") is not setup.llms[0]
|
||||
), "llm_loop_completion_handle must not be called for the errored model"
|
||||
|
||||
def test_external_state_container_used_for_model_zero(self) -> None:
|
||||
"""When provided, external_state_container is used as state_containers[0]."""
|
||||
from onyx.chat.chat_state import ChatStateContainer
|
||||
from onyx.chat.process_message import _run_models
|
||||
|
||||
external = ChatStateContainer()
|
||||
setup = _make_setup(n_models=1)
|
||||
|
||||
with (
|
||||
patch("onyx.chat.process_message.run_llm_loop") as mock_llm,
|
||||
patch("onyx.chat.process_message.run_deep_research_llm_loop"),
|
||||
patch("onyx.chat.process_message.construct_tools", return_value={}),
|
||||
patch("onyx.chat.process_message.get_session_with_current_tenant"),
|
||||
patch("onyx.chat.process_message.llm_loop_completion_handle"),
|
||||
patch(
|
||||
"onyx.chat.process_message.get_llm_token_counter",
|
||||
return_value=lambda _: 0,
|
||||
),
|
||||
):
|
||||
list(
|
||||
_run_models(
|
||||
setup, MagicMock(), MagicMock(), external_state_container=external
|
||||
)
|
||||
)
|
||||
|
||||
# The state_container kwarg passed to run_llm_loop must be the external one
|
||||
call_kwargs = mock_llm.call_args.kwargs
|
||||
assert call_kwargs["state_container"] is external
|
||||
@@ -1,318 +0,0 @@
|
||||
"""Unit tests for Notion connector handling of people properties and table blocks.
|
||||
|
||||
Reproduces two bugs:
|
||||
1. ENG-3970: People-type database properties (user mentions) are not extracted —
|
||||
the user's "name" field is lost when _recurse_properties drills into the
|
||||
"person" sub-dict.
|
||||
2. ENG-3971: Inline table blocks (table/table_row) are not indexed — table_row
|
||||
blocks store content in "cells" rather than "rich_text", so no text is extracted.
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
from onyx.connectors.notion.connector import NotionConnector
|
||||
|
||||
|
||||
def _make_connector() -> NotionConnector:
|
||||
connector = NotionConnector()
|
||||
connector.load_credentials({"notion_integration_token": "fake-token"})
|
||||
return connector
|
||||
|
||||
|
||||
class TestPeoplePropertyExtraction:
|
||||
"""ENG-3970: Verifies that 'people' type database properties extract user names."""
|
||||
|
||||
def test_single_person_property(self) -> None:
|
||||
"""A database cell with a single @mention should extract the user name."""
|
||||
properties = {
|
||||
"Team Lead": {
|
||||
"id": "abc",
|
||||
"type": "people",
|
||||
"people": [
|
||||
{
|
||||
"object": "user",
|
||||
"id": "user-uuid-1",
|
||||
"name": "Arturo Martinez",
|
||||
"type": "person",
|
||||
"person": {"email": "arturo@example.com"},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
result = NotionConnector._properties_to_str(properties)
|
||||
assert (
|
||||
"Arturo Martinez" in result
|
||||
), f"Expected 'Arturo Martinez' in extracted text, got: {result!r}"
|
||||
|
||||
def test_multiple_people_property(self) -> None:
|
||||
"""A database cell with multiple @mentions should extract all user names."""
|
||||
properties = {
|
||||
"Members": {
|
||||
"id": "def",
|
||||
"type": "people",
|
||||
"people": [
|
||||
{
|
||||
"object": "user",
|
||||
"id": "user-uuid-1",
|
||||
"name": "Arturo Martinez",
|
||||
"type": "person",
|
||||
"person": {"email": "arturo@example.com"},
|
||||
},
|
||||
{
|
||||
"object": "user",
|
||||
"id": "user-uuid-2",
|
||||
"name": "Jane Smith",
|
||||
"type": "person",
|
||||
"person": {"email": "jane@example.com"},
|
||||
},
|
||||
],
|
||||
}
|
||||
}
|
||||
result = NotionConnector._properties_to_str(properties)
|
||||
assert (
|
||||
"Arturo Martinez" in result
|
||||
), f"Expected 'Arturo Martinez' in extracted text, got: {result!r}"
|
||||
assert (
|
||||
"Jane Smith" in result
|
||||
), f"Expected 'Jane Smith' in extracted text, got: {result!r}"
|
||||
|
||||
def test_bot_user_property(self) -> None:
|
||||
"""Bot users (integrations) have 'type': 'bot' — name should still be extracted."""
|
||||
properties = {
|
||||
"Created By": {
|
||||
"id": "ghi",
|
||||
"type": "people",
|
||||
"people": [
|
||||
{
|
||||
"object": "user",
|
||||
"id": "bot-uuid-1",
|
||||
"name": "Onyx Integration",
|
||||
"type": "bot",
|
||||
"bot": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
result = NotionConnector._properties_to_str(properties)
|
||||
assert (
|
||||
"Onyx Integration" in result
|
||||
), f"Expected 'Onyx Integration' in extracted text, got: {result!r}"
|
||||
|
||||
def test_person_without_person_details(self) -> None:
|
||||
"""Some user objects may have an empty/null person sub-dict."""
|
||||
properties = {
|
||||
"Assignee": {
|
||||
"id": "jkl",
|
||||
"type": "people",
|
||||
"people": [
|
||||
{
|
||||
"object": "user",
|
||||
"id": "user-uuid-3",
|
||||
"name": "Ghost User",
|
||||
"type": "person",
|
||||
"person": {},
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
result = NotionConnector._properties_to_str(properties)
|
||||
assert (
|
||||
"Ghost User" in result
|
||||
), f"Expected 'Ghost User' in extracted text, got: {result!r}"
|
||||
|
||||
def test_people_mixed_with_other_properties(self) -> None:
|
||||
"""People property should work alongside other property types."""
|
||||
properties = {
|
||||
"Name": {
|
||||
"id": "aaa",
|
||||
"type": "title",
|
||||
"title": [
|
||||
{
|
||||
"plain_text": "Project Alpha",
|
||||
"type": "text",
|
||||
"text": {"content": "Project Alpha"},
|
||||
}
|
||||
],
|
||||
},
|
||||
"Lead": {
|
||||
"id": "bbb",
|
||||
"type": "people",
|
||||
"people": [
|
||||
{
|
||||
"object": "user",
|
||||
"id": "user-uuid-1",
|
||||
"name": "Arturo Martinez",
|
||||
"type": "person",
|
||||
"person": {"email": "arturo@example.com"},
|
||||
}
|
||||
],
|
||||
},
|
||||
"Status": {
|
||||
"id": "ccc",
|
||||
"type": "status",
|
||||
"status": {"name": "In Progress", "id": "status-1"},
|
||||
},
|
||||
}
|
||||
result = NotionConnector._properties_to_str(properties)
|
||||
assert "Arturo Martinez" in result
|
||||
assert "In Progress" in result
|
||||
|
||||
|
||||
class TestTableBlockExtraction:
|
||||
"""ENG-3971: Verifies that inline table blocks (table/table_row) are indexed."""
|
||||
|
||||
def _make_blocks_response(self, results: list) -> dict:
|
||||
return {"results": results, "next_cursor": None}
|
||||
|
||||
def test_table_row_cells_are_extracted(self) -> None:
|
||||
"""table_row blocks store content in 'cells', not 'rich_text'.
|
||||
The connector should extract text from cells."""
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
|
||||
table_block = {
|
||||
"id": "table-block-1",
|
||||
"type": "table",
|
||||
"table": {
|
||||
"has_column_header": True,
|
||||
"has_row_header": False,
|
||||
"table_width": 3,
|
||||
},
|
||||
"has_children": True,
|
||||
}
|
||||
|
||||
header_row = {
|
||||
"id": "row-1",
|
||||
"type": "table_row",
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Name"},
|
||||
"plain_text": "Name",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Role"},
|
||||
"plain_text": "Role",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Team"},
|
||||
"plain_text": "Team",
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"has_children": False,
|
||||
}
|
||||
|
||||
data_row = {
|
||||
"id": "row-2",
|
||||
"type": "table_row",
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Arturo Martinez"},
|
||||
"plain_text": "Arturo Martinez",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Engineer"},
|
||||
"plain_text": "Engineer",
|
||||
}
|
||||
],
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Platform"},
|
||||
"plain_text": "Platform",
|
||||
}
|
||||
],
|
||||
]
|
||||
},
|
||||
"has_children": False,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_fetch_child_blocks",
|
||||
side_effect=[
|
||||
self._make_blocks_response([table_block]),
|
||||
self._make_blocks_response([header_row, data_row]),
|
||||
],
|
||||
):
|
||||
output = connector._read_blocks("page-1")
|
||||
|
||||
all_text = " ".join(block.text for block in output.blocks)
|
||||
assert "Arturo Martinez" in all_text, (
|
||||
f"Expected 'Arturo Martinez' in table row text, got blocks: "
|
||||
f"{[(b.id, b.text) for b in output.blocks]}"
|
||||
)
|
||||
assert "Engineer" in all_text, (
|
||||
f"Expected 'Engineer' in table row text, got blocks: "
|
||||
f"{[(b.id, b.text) for b in output.blocks]}"
|
||||
)
|
||||
assert "Platform" in all_text, (
|
||||
f"Expected 'Platform' in table row text, got blocks: "
|
||||
f"{[(b.id, b.text) for b in output.blocks]}"
|
||||
)
|
||||
|
||||
def test_table_with_empty_cells(self) -> None:
|
||||
"""Table rows with some empty cells should still extract non-empty content."""
|
||||
connector = _make_connector()
|
||||
connector.workspace_id = "ws-1"
|
||||
|
||||
table_block = {
|
||||
"id": "table-block-2",
|
||||
"type": "table",
|
||||
"table": {
|
||||
"has_column_header": False,
|
||||
"has_row_header": False,
|
||||
"table_width": 2,
|
||||
},
|
||||
"has_children": True,
|
||||
}
|
||||
|
||||
row_with_empty = {
|
||||
"id": "row-3",
|
||||
"type": "table_row",
|
||||
"table_row": {
|
||||
"cells": [
|
||||
[
|
||||
{
|
||||
"type": "text",
|
||||
"text": {"content": "Has Value"},
|
||||
"plain_text": "Has Value",
|
||||
}
|
||||
],
|
||||
[], # empty cell
|
||||
]
|
||||
},
|
||||
"has_children": False,
|
||||
}
|
||||
|
||||
with patch.object(
|
||||
connector,
|
||||
"_fetch_child_blocks",
|
||||
side_effect=[
|
||||
self._make_blocks_response([table_block]),
|
||||
self._make_blocks_response([row_with_empty]),
|
||||
],
|
||||
):
|
||||
output = connector._read_blocks("page-2")
|
||||
|
||||
all_text = " ".join(block.text for block in output.blocks)
|
||||
assert "Has Value" in all_text, (
|
||||
f"Expected 'Has Value' in table row text, got blocks: "
|
||||
f"{[(b.id, b.text) for b in output.blocks]}"
|
||||
)
|
||||
@@ -1,176 +0,0 @@
|
||||
"""
|
||||
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
|
||||
|
||||
Covers:
|
||||
1. Standard/service-account users get assigned to the correct default group
|
||||
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
|
||||
3. Missing default group raises RuntimeError
|
||||
4. Already-in-group is a no-op
|
||||
5. IntegrityError race condition is handled gracefully
|
||||
6. The function never commits the session
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
|
||||
def _mock_user(
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
email: str = "test@example.com",
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = email
|
||||
user.account_type = account_type
|
||||
return user
|
||||
|
||||
|
||||
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
group.is_default = True
|
||||
return group
|
||||
|
||||
|
||||
def _make_query_chain(first_return: object = None) -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.first.return_value = first_return
|
||||
return chain
|
||||
|
||||
|
||||
def _setup_db_session(
|
||||
group_result: object = None,
|
||||
membership_result: object = None,
|
||||
) -> MagicMock:
|
||||
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
|
||||
db_session = MagicMock()
|
||||
|
||||
group_chain = _make_query_chain(group_result)
|
||||
membership_chain = _make_query_chain(membership_result)
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model is UserGroup:
|
||||
return group_chain
|
||||
if model is User__UserGroup:
|
||||
return membership_chain
|
||||
return MagicMock()
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
return db_session
|
||||
|
||||
|
||||
def test_standard_user_assigned_to_basic_group() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_id == user.id
|
||||
assert added.user_group_id == group.id
|
||||
db_session.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_user_assigned_to_admin_group() -> None:
|
||||
group = _mock_group("Admin", group_id=2)
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_group_id == group.id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"account_type",
|
||||
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
|
||||
)
|
||||
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
|
||||
db_session = MagicMock()
|
||||
user = _mock_user(account_type)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.query.assert_not_called()
|
||||
db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_service_account_not_skipped() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.SERVICE_ACCOUNT)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
|
||||
|
||||
def test_missing_default_group_raises_error() -> None:
|
||||
db_session = _setup_db_session(group_result=None)
|
||||
user = _mock_user()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Default group .* not found"):
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
|
||||
def test_already_in_group_is_noop() -> None:
|
||||
group = _mock_group("Basic")
|
||||
existing_membership = MagicMock()
|
||||
db_session = _setup_db_session(
|
||||
group_result=group, membership_result=existing_membership
|
||||
)
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
db_session.begin_nested.assert_not_called()
|
||||
|
||||
|
||||
def test_integrity_error_race_condition_handled() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
|
||||
user = _mock_user()
|
||||
|
||||
# Should not raise
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
savepoint.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_no_commit_called_on_successful_assignment() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.commit.assert_not_called()
|
||||
@@ -1,100 +0,0 @@
|
||||
"""Regression tests for delete_messages_and_files_from_chat_session.
|
||||
|
||||
Verifies that user-owned files (those with user_file_id) are never deleted
|
||||
during chat session cleanup — only chat-only files should be removed.
|
||||
"""
|
||||
|
||||
from unittest.mock import call
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.db.chat import delete_messages_and_files_from_chat_session
|
||||
|
||||
_MODULE = "onyx.db.chat"
|
||||
|
||||
|
||||
def _make_db_session(
|
||||
rows: list[tuple[int, list[dict[str, str]] | None]],
|
||||
) -> MagicMock:
|
||||
db_session = MagicMock()
|
||||
db_session.execute.return_value.tuples.return_value.all.return_value = rows
|
||||
return db_session
|
||||
|
||||
|
||||
@patch(f"{_MODULE}.delete_orphaned_search_docs")
|
||||
@patch(f"{_MODULE}.get_default_file_store")
|
||||
def test_user_files_are_not_deleted(
|
||||
mock_get_file_store: MagicMock,
|
||||
_mock_orphan_cleanup: MagicMock,
|
||||
) -> None:
|
||||
"""User files (with user_file_id) must be skipped during cleanup."""
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
db_session = _make_db_session(
|
||||
[
|
||||
(
|
||||
1,
|
||||
[
|
||||
{"id": "chat-file-1", "type": "image"},
|
||||
{"id": "user-file-1", "type": "document", "user_file_id": "uf-1"},
|
||||
{"id": "chat-file-2", "type": "image"},
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
delete_messages_and_files_from_chat_session(uuid4(), db_session)
|
||||
|
||||
assert file_store.delete_file.call_count == 2
|
||||
file_store.delete_file.assert_has_calls(
|
||||
[
|
||||
call(file_id="chat-file-1", error_on_missing=False),
|
||||
call(file_id="chat-file-2", error_on_missing=False),
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@patch(f"{_MODULE}.delete_orphaned_search_docs")
|
||||
@patch(f"{_MODULE}.get_default_file_store")
|
||||
def test_only_user_files_means_no_deletions(
|
||||
mock_get_file_store: MagicMock,
|
||||
_mock_orphan_cleanup: MagicMock,
|
||||
) -> None:
|
||||
"""When every file in the session is a user file, nothing should be deleted."""
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
db_session = _make_db_session(
|
||||
[
|
||||
(1, [{"id": "uf-a", "type": "document", "user_file_id": "uf-1"}]),
|
||||
(2, [{"id": "uf-b", "type": "document", "user_file_id": "uf-2"}]),
|
||||
]
|
||||
)
|
||||
|
||||
delete_messages_and_files_from_chat_session(uuid4(), db_session)
|
||||
|
||||
file_store.delete_file.assert_not_called()
|
||||
|
||||
|
||||
@patch(f"{_MODULE}.delete_orphaned_search_docs")
|
||||
@patch(f"{_MODULE}.get_default_file_store")
|
||||
def test_messages_with_no_files(
|
||||
mock_get_file_store: MagicMock,
|
||||
_mock_orphan_cleanup: MagicMock,
|
||||
) -> None:
|
||||
"""Messages with None or empty file lists should not trigger any deletions."""
|
||||
file_store = MagicMock()
|
||||
mock_get_file_store.return_value = file_store
|
||||
|
||||
db_session = _make_db_session(
|
||||
[
|
||||
(1, None),
|
||||
(2, []),
|
||||
]
|
||||
)
|
||||
|
||||
delete_messages_and_files_from_chat_session(uuid4(), db_session)
|
||||
|
||||
file_store.delete_file.assert_not_called()
|
||||
@@ -1,203 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from onyx.document_index.interfaces_new import TenantState
|
||||
from onyx.document_index.opensearch.constants import DEFAULT_MAX_CHUNK_SIZE
|
||||
from onyx.document_index.opensearch.schema import get_opensearch_doc_chunk_id
|
||||
from onyx.document_index.opensearch.string_filtering import (
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE
|
||||
|
||||
|
||||
SINGLE_TENANT_STATE = TenantState(
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA_STANDARD_VALUE, multitenant=False
|
||||
)
|
||||
MULTI_TENANT_STATE = TenantState(
|
||||
tenant_id="tenant_abcdef12-3456-7890-abcd-ef1234567890", multitenant=True
|
||||
)
|
||||
EXPECTED_SHORT_TENANT = "abcdef12"
|
||||
|
||||
|
||||
class TestGetOpensearchDocChunkIdSingleTenant:
|
||||
def test_basic(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "my-doc-id", chunk_index=0
|
||||
)
|
||||
assert result == f"my-doc-id__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
|
||||
def test_custom_chunk_size(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "doc1", chunk_index=3, max_chunk_size=1024
|
||||
)
|
||||
assert result == "doc1__1024__3"
|
||||
|
||||
def test_special_chars_are_stripped(self) -> None:
|
||||
"""Tests characters not matching [A-Za-z0-9_.-~] are removed."""
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "doc/with?special#chars&more%stuff", chunk_index=0
|
||||
)
|
||||
assert "/" not in result
|
||||
assert "?" not in result
|
||||
assert "#" not in result
|
||||
assert result == f"docwithspecialcharsmorestuff__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
|
||||
def test_short_doc_id_not_hashed(self) -> None:
|
||||
"""
|
||||
Tests that a short doc ID should appear directly in the result, not as a
|
||||
hash.
|
||||
"""
|
||||
doc_id = "short-id"
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert "short-id" in result
|
||||
|
||||
def test_long_doc_id_is_hashed(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID exceeding the max length should be replaced with a
|
||||
blake2b hash.
|
||||
"""
|
||||
# Create a doc ID that will exceed max length after the suffix is
|
||||
# appended.
|
||||
doc_id = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
# The original doc ID should NOT appear in the result.
|
||||
assert doc_id not in result
|
||||
# The suffix should still be present.
|
||||
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
|
||||
|
||||
def test_long_doc_id_hash_is_deterministic(self) -> None:
|
||||
doc_id = "x" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result1 = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=5
|
||||
)
|
||||
result2 = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=5
|
||||
)
|
||||
assert result1 == result2
|
||||
|
||||
def test_long_doc_id_different_inputs_produce_different_hashes(self) -> None:
|
||||
doc_id_a = "a" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
doc_id_b = "b" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result_a = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id_a, chunk_index=0
|
||||
)
|
||||
result_b = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id_b, chunk_index=0
|
||||
)
|
||||
assert result_a != result_b
|
||||
|
||||
def test_result_never_exceeds_max_length(self) -> None:
|
||||
"""
|
||||
Tests that the final result should always be under
|
||||
MAX_DOCUMENT_ID_ENCODED_LENGTH bytes.
|
||||
"""
|
||||
doc_id = "z" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
|
||||
)
|
||||
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
|
||||
def test_no_tenant_prefix_in_single_tenant(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "mydoc", chunk_index=0
|
||||
)
|
||||
assert not result.startswith(SINGLE_TENANT_STATE.tenant_id)
|
||||
|
||||
|
||||
class TestGetOpensearchDocChunkIdMultiTenant:
|
||||
def test_includes_tenant_prefix(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, "mydoc", chunk_index=0)
|
||||
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
|
||||
|
||||
def test_format(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
MULTI_TENANT_STATE, "mydoc", chunk_index=2, max_chunk_size=256
|
||||
)
|
||||
assert result == f"{EXPECTED_SHORT_TENANT}__mydoc__256__2"
|
||||
|
||||
def test_long_doc_id_is_hashed_multitenant(self) -> None:
|
||||
doc_id = "d" * MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
|
||||
# Should still have tenant prefix.
|
||||
assert result.startswith(f"{EXPECTED_SHORT_TENANT}__")
|
||||
# The original doc ID should NOT appear in the result.
|
||||
assert doc_id not in result
|
||||
# The suffix should still be present.
|
||||
assert f"__{DEFAULT_MAX_CHUNK_SIZE}__0" in result
|
||||
|
||||
def test_result_never_exceeds_max_length_multitenant(self) -> None:
|
||||
doc_id = "q" * (MAX_DOCUMENT_ID_ENCODED_LENGTH * 2)
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
MULTI_TENANT_STATE, doc_id, chunk_index=999, max_chunk_size=99999
|
||||
)
|
||||
assert len(result.encode("utf-8")) < MAX_DOCUMENT_ID_ENCODED_LENGTH
|
||||
|
||||
def test_different_tenants_produce_different_ids(self) -> None:
|
||||
tenant_a = TenantState(
|
||||
tenant_id="tenant_aaaaaaaa-0000-0000-0000-000000000000", multitenant=True
|
||||
)
|
||||
tenant_b = TenantState(
|
||||
tenant_id="tenant_bbbbbbbb-0000-0000-0000-000000000000", multitenant=True
|
||||
)
|
||||
result_a = get_opensearch_doc_chunk_id(tenant_a, "same-doc", chunk_index=0)
|
||||
result_b = get_opensearch_doc_chunk_id(tenant_b, "same-doc", chunk_index=0)
|
||||
assert result_a != result_b
|
||||
|
||||
|
||||
class TestGetOpensearchDocChunkIdEdgeCases:
|
||||
def test_chunk_index_zero(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "doc", chunk_index=0)
|
||||
assert result.endswith("__0")
|
||||
|
||||
def test_large_chunk_index(self) -> None:
|
||||
result = get_opensearch_doc_chunk_id(
|
||||
SINGLE_TENANT_STATE, "doc", chunk_index=99999
|
||||
)
|
||||
assert result.endswith("__99999")
|
||||
|
||||
def test_doc_id_with_only_special_chars_raises(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID that becomes empty after filtering should raise
|
||||
ValueError.
|
||||
"""
|
||||
with pytest.raises(ValueError, match="empty after filtering"):
|
||||
get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, "###???///", chunk_index=0)
|
||||
|
||||
def test_doc_id_at_boundary_length(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID right at the boundary should not be hashed.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
# Max doc ID length that won't trigger hashing (must be <
|
||||
# max_encoded_length).
|
||||
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - 1
|
||||
doc_id = "a" * max_doc_len
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id in result
|
||||
|
||||
def test_doc_id_at_boundary_length_multitenant(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID right at the boundary should not be hashed in
|
||||
multitenant mode.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
prefix = f"{EXPECTED_SHORT_TENANT}__"
|
||||
prefix_len = len(prefix.encode("utf-8"))
|
||||
# Max doc ID length that won't trigger hashing (must be <
|
||||
# max_encoded_length).
|
||||
max_doc_len = MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len - prefix_len - 1
|
||||
doc_id = "a" * max_doc_len
|
||||
result = get_opensearch_doc_chunk_id(MULTI_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id in result
|
||||
|
||||
def test_doc_id_one_over_boundary_is_hashed(self) -> None:
|
||||
"""
|
||||
Tests that a doc ID one byte over the boundary should be hashed.
|
||||
"""
|
||||
suffix = f"__{DEFAULT_MAX_CHUNK_SIZE}__0"
|
||||
suffix_len = len(suffix.encode("utf-8"))
|
||||
# This length will trigger the >= check in filter_and_validate_document_id
|
||||
doc_id = "a" * (MAX_DOCUMENT_ID_ENCODED_LENGTH - suffix_len)
|
||||
result = get_opensearch_doc_chunk_id(SINGLE_TENANT_STATE, doc_id, chunk_index=0)
|
||||
assert doc_id not in result
|
||||
@@ -1,91 +0,0 @@
|
||||
"""Tests for FileStore.delete_file error_on_missing behavior."""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
_S3_MODULE = "onyx.file_store.file_store"
|
||||
_PG_MODULE = "onyx.file_store.postgres_file_store"
|
||||
|
||||
|
||||
def _mock_db_session() -> MagicMock:
|
||||
session = MagicMock()
|
||||
session.__enter__ = MagicMock(return_value=session)
|
||||
session.__exit__ = MagicMock(return_value=False)
|
||||
return session
|
||||
|
||||
|
||||
# ── S3BackedFileStore ────────────────────────────────────────────────
|
||||
|
||||
|
||||
@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
|
||||
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
|
||||
def test_s3_delete_missing_file_raises_by_default(
|
||||
_mock_get_record: MagicMock,
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
|
||||
mock_ctx.return_value = _mock_db_session()
|
||||
store = S3BackedFileStore(bucket_name="b")
|
||||
|
||||
with pytest.raises(RuntimeError, match="does not exist"):
|
||||
store.delete_file("nonexistent")
|
||||
|
||||
|
||||
@patch(f"{_S3_MODULE}.get_session_with_current_tenant_if_none")
|
||||
@patch(f"{_S3_MODULE}.get_filerecord_by_file_id_optional", return_value=None)
|
||||
@patch(f"{_S3_MODULE}.delete_filerecord_by_file_id")
|
||||
def test_s3_delete_missing_file_silent_when_error_on_missing_false(
|
||||
mock_delete_record: MagicMock,
|
||||
_mock_get_record: MagicMock,
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
from onyx.file_store.file_store import S3BackedFileStore
|
||||
|
||||
mock_ctx.return_value = _mock_db_session()
|
||||
store = S3BackedFileStore(bucket_name="b")
|
||||
|
||||
store.delete_file("nonexistent", error_on_missing=False)
|
||||
|
||||
mock_delete_record.assert_not_called()
|
||||
|
||||
|
||||
# ── PostgresBackedFileStore ──────────────────────────────────────────
|
||||
|
||||
|
||||
@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
|
||||
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
|
||||
def test_pg_delete_missing_file_raises_by_default(
|
||||
_mock_get_content: MagicMock,
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
from onyx.file_store.postgres_file_store import PostgresBackedFileStore
|
||||
|
||||
mock_ctx.return_value = _mock_db_session()
|
||||
store = PostgresBackedFileStore()
|
||||
|
||||
with pytest.raises(RuntimeError, match="does not exist"):
|
||||
store.delete_file("nonexistent")
|
||||
|
||||
|
||||
@patch(f"{_PG_MODULE}.get_session_with_current_tenant_if_none")
|
||||
@patch(f"{_PG_MODULE}.get_file_content_by_file_id_optional", return_value=None)
|
||||
@patch(f"{_PG_MODULE}.delete_file_content_by_file_id")
|
||||
@patch(f"{_PG_MODULE}.delete_filerecord_by_file_id")
|
||||
def test_pg_delete_missing_file_silent_when_error_on_missing_false(
|
||||
mock_delete_record: MagicMock,
|
||||
mock_delete_content: MagicMock,
|
||||
_mock_get_content: MagicMock,
|
||||
mock_ctx: MagicMock,
|
||||
) -> None:
|
||||
from onyx.file_store.postgres_file_store import PostgresBackedFileStore
|
||||
|
||||
mock_ctx.return_value = _mock_db_session()
|
||||
store = PostgresBackedFileStore()
|
||||
|
||||
store.delete_file("nonexistent", error_on_missing=False)
|
||||
|
||||
mock_delete_record.assert_not_called()
|
||||
mock_delete_content.assert_not_called()
|
||||
@@ -113,7 +113,6 @@ def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
group.name = kwargs.get("name", "Engineering")
|
||||
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
|
||||
group.is_up_to_date = kwargs.get("is_up_to_date", True)
|
||||
group.is_default = kwargs.get("is_default", False)
|
||||
return group
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
@@ -26,7 +25,6 @@ def _mock_user(
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.account_type = AccountType.STANDARD
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -98,7 +98,6 @@ Useful hardening flags:
|
||||
| `serve` | Serve the interactive chat TUI over SSH |
|
||||
| `configure` | Configure server URL and API key |
|
||||
| `validate-config` | Validate configuration and test connection |
|
||||
| `install-skill` | Install the agent skill file into a project |
|
||||
|
||||
## Slash Commands (in TUI)
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -17,23 +16,16 @@ func newAgentsCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "agents",
|
||||
Short: "List available agents",
|
||||
Long: `List all visible agents configured on the Onyx server.
|
||||
|
||||
By default, output is a human-readable table with ID, name, and description.
|
||||
Use --json for machine-readable output.`,
|
||||
Example: ` onyx-cli agents
|
||||
onyx-cli agents --json
|
||||
onyx-cli agents --json | jq '.[].name'`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
agents, err := client.ListAgents(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list agents: %w\n Check your connection with: onyx-cli validate-config", err)
|
||||
return fmt.Errorf("failed to list agents: %w", err)
|
||||
}
|
||||
|
||||
if agentsJSON {
|
||||
|
||||
140
cli/cmd/ask.go
140
cli/cmd/ask.go
@@ -4,65 +4,33 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/overflow"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const defaultMaxOutputBytes = 4096
|
||||
|
||||
func newAskCmd() *cobra.Command {
|
||||
var (
|
||||
askAgentID int
|
||||
askJSON bool
|
||||
askQuiet bool
|
||||
askPrompt string
|
||||
maxOutput int
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "ask [question]",
|
||||
Short: "Ask a one-shot question (non-interactive)",
|
||||
Long: `Send a one-shot question to an Onyx agent and print the response.
|
||||
|
||||
The question can be provided as a positional argument, via --prompt, or piped
|
||||
through stdin. When stdin contains piped data, it is sent as context along
|
||||
with the question from --prompt (or used as the question itself).
|
||||
|
||||
When stdout is not a TTY (e.g., called by a script or AI agent), output is
|
||||
automatically truncated to --max-output bytes and the full response is saved
|
||||
to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
Example: ` onyx-cli ask "What connectors are available?"
|
||||
onyx-cli ask --agent-id 3 "Summarize our Q4 revenue"
|
||||
onyx-cli ask --json "List all users" | jq '.event.content'
|
||||
cat error.log | onyx-cli ask --prompt "Find the root cause"
|
||||
echo "what is onyx?" | onyx-cli ask`,
|
||||
Args: cobra.ExactArgs(1),
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
if askJSON && askQuiet {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--json and --quiet cannot be used together")
|
||||
}
|
||||
|
||||
question, err := resolveQuestion(args, askPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
}
|
||||
|
||||
question := args[0]
|
||||
agentID := cfg.DefaultAgentID
|
||||
if cmd.Flags().Changed("agent-id") {
|
||||
agentID = askAgentID
|
||||
@@ -82,23 +50,9 @@ to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
nil,
|
||||
)
|
||||
|
||||
// Determine truncation threshold.
|
||||
isTTY := term.IsTerminal(int(os.Stdout.Fd()))
|
||||
truncateAt := 0 // 0 means no truncation
|
||||
if cmd.Flags().Changed("max-output") {
|
||||
truncateAt = maxOutput
|
||||
} else if !isTTY {
|
||||
truncateAt = defaultMaxOutputBytes
|
||||
}
|
||||
|
||||
var sessionID string
|
||||
var lastErr error
|
||||
gotStop := false
|
||||
|
||||
// Overflow writer: tees to stdout and optionally to a temp file.
|
||||
// In quiet mode, buffer everything and print once at the end.
|
||||
ow := &overflow.Writer{Limit: truncateAt, Quiet: askQuiet}
|
||||
|
||||
for event := range ch {
|
||||
if e, ok := event.(models.SessionCreatedEvent); ok {
|
||||
sessionID = e.ChatSessionID
|
||||
@@ -128,50 +82,22 @@ to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
|
||||
switch e := event.(type) {
|
||||
case models.MessageDeltaEvent:
|
||||
ow.Write(e.Content)
|
||||
case models.SearchStartEvent:
|
||||
if isTTY && !askQuiet {
|
||||
if e.IsInternetSearch {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mSearching the web...\033[0m\n")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mSearching documents...\033[0m\n")
|
||||
}
|
||||
}
|
||||
case models.SearchQueriesEvent:
|
||||
if isTTY && !askQuiet {
|
||||
for _, q := range e.Queries {
|
||||
fmt.Fprintf(os.Stderr, "\033[2m → %s\033[0m\n", q)
|
||||
}
|
||||
}
|
||||
case models.SearchDocumentsEvent:
|
||||
if isTTY && !askQuiet && len(e.Documents) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mFound %d documents\033[0m\n", len(e.Documents))
|
||||
}
|
||||
case models.ReasoningStartEvent:
|
||||
if isTTY && !askQuiet {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mThinking...\033[0m\n")
|
||||
}
|
||||
case models.ToolStartEvent:
|
||||
if isTTY && !askQuiet && e.ToolName != "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mUsing %s...\033[0m\n", e.ToolName)
|
||||
}
|
||||
fmt.Print(e.Content)
|
||||
case models.ErrorEvent:
|
||||
ow.Finish()
|
||||
return fmt.Errorf("%s", e.Error)
|
||||
case models.StopEvent:
|
||||
ow.Finish()
|
||||
fmt.Println()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !askJSON {
|
||||
ow.Finish()
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
if sessionID != "" {
|
||||
client.StopChatSession(context.Background(), sessionID)
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -179,56 +105,20 @@ to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
return lastErr
|
||||
}
|
||||
if !gotStop {
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return fmt.Errorf("stream ended unexpectedly")
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
|
||||
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
|
||||
cmd.Flags().BoolVarP(&askQuiet, "quiet", "q", false, "Buffer output and print once at end (no streaming)")
|
||||
cmd.Flags().StringVar(&askPrompt, "prompt", "", "Question text (use with piped stdin context)")
|
||||
cmd.Flags().IntVar(&maxOutput, "max-output", defaultMaxOutputBytes,
|
||||
"Max bytes to print before truncating (0 to disable, auto-enabled for non-TTY)")
|
||||
// Suppress cobra's default error/usage on RunE errors
|
||||
return cmd
|
||||
}
|
||||
|
||||
// resolveQuestion builds the final question string from args, --prompt, and stdin.
|
||||
func resolveQuestion(args []string, prompt string) (string, error) {
|
||||
hasArg := len(args) > 0
|
||||
hasPrompt := prompt != ""
|
||||
hasStdin := !term.IsTerminal(int(os.Stdin.Fd()))
|
||||
|
||||
if hasArg && hasPrompt {
|
||||
return "", exitcodes.New(exitcodes.BadRequest, "specify the question as an argument or --prompt, not both")
|
||||
}
|
||||
|
||||
var stdinContent string
|
||||
if hasStdin {
|
||||
const maxStdinBytes = 10 * 1024 * 1024 // 10MB
|
||||
data, err := io.ReadAll(io.LimitReader(os.Stdin, maxStdinBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read stdin: %w", err)
|
||||
}
|
||||
stdinContent = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
switch {
|
||||
case hasArg && stdinContent != "":
|
||||
// arg is the question, stdin is context
|
||||
return args[0] + "\n\n" + stdinContent, nil
|
||||
case hasArg:
|
||||
return args[0], nil
|
||||
case hasPrompt && stdinContent != "":
|
||||
// --prompt is the question, stdin is context
|
||||
return prompt + "\n\n" + stdinContent, nil
|
||||
case hasPrompt:
|
||||
return prompt, nil
|
||||
case stdinContent != "":
|
||||
return stdinContent, nil
|
||||
default:
|
||||
return "", exitcodes.New(exitcodes.BadRequest, "no question provided\n Usage: onyx-cli ask \"your question\"\n Or: echo \"context\" | onyx-cli ask --prompt \"your question\"")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,22 +4,14 @@ import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/starprompt"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newChatCmd() *cobra.Command {
|
||||
var noStreamMarkdown bool
|
||||
|
||||
cmd := &cobra.Command{
|
||||
return &cobra.Command{
|
||||
Use: "chat",
|
||||
Short: "Launch the interactive chat TUI (default)",
|
||||
Long: `Launch the interactive terminal UI for chatting with your Onyx agent.
|
||||
This is the default command when no subcommand is specified. On first run,
|
||||
an interactive setup wizard will guide you through configuration.`,
|
||||
Example: ` onyx-cli chat
|
||||
onyx-cli`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
|
||||
@@ -32,22 +24,10 @@ an interactive setup wizard will guide you through configuration.`,
|
||||
cfg = *result
|
||||
}
|
||||
|
||||
// CLI flag overrides config/env
|
||||
if cmd.Flags().Changed("no-stream-markdown") {
|
||||
v := !noStreamMarkdown
|
||||
cfg.Features.StreamMarkdown = &v
|
||||
}
|
||||
|
||||
starprompt.MaybePrompt()
|
||||
|
||||
m := tui.NewModel(cfg)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||
_, err := p.Run()
|
||||
return err
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVar(&noStreamMarkdown, "no-stream-markdown", false, "Disable progressive markdown rendering during streaming")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
@@ -1,126 +1,19 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func newConfigureCmd() *cobra.Command {
|
||||
var (
|
||||
serverURL string
|
||||
apiKey string
|
||||
apiKeyStdin bool
|
||||
dryRun bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
return &cobra.Command{
|
||||
Use: "configure",
|
||||
Short: "Configure server URL and API key",
|
||||
Long: `Set up the Onyx CLI with your server URL and API key.
|
||||
|
||||
When --server-url and --api-key are both provided, the configuration is saved
|
||||
non-interactively (useful for scripts and AI agents). Otherwise, an interactive
|
||||
setup wizard is launched.
|
||||
|
||||
If --api-key is omitted but stdin has piped data, the API key is read from
|
||||
stdin automatically. You can also use --api-key-stdin to make this explicit.
|
||||
This avoids leaking the key in shell history.
|
||||
|
||||
Use --dry-run to test the connection without saving the configuration.`,
|
||||
Example: ` onyx-cli configure
|
||||
onyx-cli configure --server-url https://my-onyx.com --api-key sk-...
|
||||
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com
|
||||
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com --api-key-stdin
|
||||
onyx-cli configure --server-url https://my-onyx.com --api-key sk-... --dry-run`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Read API key from stdin if piped (implicit) or --api-key-stdin (explicit)
|
||||
if apiKeyStdin && apiKey != "" {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--api-key and --api-key-stdin cannot be used together")
|
||||
}
|
||||
if (apiKey == "" && !term.IsTerminal(int(os.Stdin.Fd()))) || apiKeyStdin {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read API key from stdin: %w", err)
|
||||
}
|
||||
apiKey = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
if serverURL != "" && apiKey != "" {
|
||||
return configureNonInteractive(serverURL, apiKey, dryRun)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--dry-run requires --server-url and --api-key")
|
||||
}
|
||||
|
||||
if serverURL != "" || apiKey != "" {
|
||||
return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n Run 'onyx-cli configure' without flags for interactive setup")
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
onboarding.Run(&cfg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&serverURL, "server-url", "", "Onyx server URL (e.g., https://cloud.onyx.app)")
|
||||
cmd.Flags().StringVar(&apiKey, "api-key", "", "API key for authentication (or pipe via stdin)")
|
||||
cmd.Flags().BoolVar(&apiKeyStdin, "api-key-stdin", false, "Read API key from stdin (explicit; also happens automatically when stdin is piped)")
|
||||
cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Test connection without saving config (requires --server-url and --api-key)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func configureNonInteractive(serverURL, apiKey string, dryRun bool) error {
|
||||
cfg := config.OnyxCliConfig{
|
||||
ServerURL: serverURL,
|
||||
APIKey: apiKey,
|
||||
DefaultAgentID: 0,
|
||||
}
|
||||
|
||||
// Preserve existing default agent ID from disk (not env overrides)
|
||||
if existing := config.LoadFromDisk(); existing.DefaultAgentID != 0 {
|
||||
cfg.DefaultAgentID = existing.DefaultAgentID
|
||||
}
|
||||
|
||||
// Test connection
|
||||
client := api.NewClient(cfg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
var authErr *api.AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Check your API key", err)
|
||||
}
|
||||
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Check your server URL", err)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
fmt.Printf("Server: %s\n", serverURL)
|
||||
fmt.Println("Status: connected and authenticated")
|
||||
fmt.Println("Dry run: config was NOT saved")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.Save(cfg); err != nil {
|
||||
return fmt.Errorf("could not save config: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Config: %s\n", config.ConfigFilePath())
|
||||
fmt.Printf("Server: %s\n", serverURL)
|
||||
fmt.Println("Status: connected and authenticated")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1,20 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func newExperimentsCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "experiments",
|
||||
Short: "List experimental features and their status",
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), config.ExperimentsText(cfg.Features))
|
||||
return nil
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -1,176 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/embedded"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/fsutil"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
// agentSkillDirs maps agent names to their skill directory paths (relative to
|
||||
// the project or home root). "Universal" agents like Cursor and Codex read
|
||||
// from .agents/skills directly, so they don't need their own entry here.
|
||||
var agentSkillDirs = map[string]string{
|
||||
"claude-code": filepath.Join(".claude", "skills"),
|
||||
}
|
||||
|
||||
const (
|
||||
canonicalDir = ".agents/skills"
|
||||
skillName = "onyx-cli"
|
||||
)
|
||||
|
||||
func newInstallSkillCmd() *cobra.Command {
|
||||
var (
|
||||
global bool
|
||||
copyMode bool
|
||||
agents []string
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "install-skill",
|
||||
Short: "Install the Onyx CLI agent skill file",
|
||||
Long: `Install the bundled SKILL.md so that AI coding agents can discover and use
|
||||
the Onyx CLI as a tool.
|
||||
|
||||
Files are written to the canonical .agents/skills/onyx-cli/ directory. For
|
||||
agents that use their own skill directory (e.g. Claude Code uses .claude/skills/),
|
||||
a symlink is created pointing back to the canonical copy.
|
||||
|
||||
By default the skill is installed at the project level (current directory).
|
||||
Use --global to install under your home directory instead.
|
||||
|
||||
Use --copy to write independent copies instead of symlinks.
|
||||
Use --agent to target specific agents (can be repeated).`,
|
||||
Example: ` onyx-cli install-skill
|
||||
onyx-cli install-skill --global
|
||||
onyx-cli install-skill --agent claude-code
|
||||
onyx-cli install-skill --copy`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
base, err := installBase(global)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Write the canonical copy.
|
||||
canonicalSkillDir := filepath.Join(base, canonicalDir, skillName)
|
||||
dest := filepath.Join(canonicalSkillDir, "SKILL.md")
|
||||
content := []byte(embedded.SkillMD)
|
||||
|
||||
status, err := fsutil.CompareFile(dest, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
switch status {
|
||||
case fsutil.StatusUpToDate:
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Up to date %s\n", dest)
|
||||
case fsutil.StatusDiffers:
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Warning: overwriting modified %s\n", dest)
|
||||
if err := os.WriteFile(dest, content, 0o644); err != nil {
|
||||
return fmt.Errorf("could not write skill file: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
|
||||
default: // statusMissing
|
||||
if err := os.MkdirAll(canonicalSkillDir, 0o755); err != nil {
|
||||
return fmt.Errorf("could not create directory: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(dest, content, 0o644); err != nil {
|
||||
return fmt.Errorf("could not write skill file: %w", err)
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Installed %s\n", dest)
|
||||
}
|
||||
|
||||
// Determine which agents to link.
|
||||
targets := agentSkillDirs
|
||||
if len(agents) > 0 {
|
||||
targets = make(map[string]string)
|
||||
for _, a := range agents {
|
||||
dir, ok := agentSkillDirs[a]
|
||||
if !ok {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), "Unknown agent %q (skipped) — known agents:", a)
|
||||
for name := range agentSkillDirs {
|
||||
_, _ = fmt.Fprintf(cmd.ErrOrStderr(), " %s", name)
|
||||
}
|
||||
_, _ = fmt.Fprintln(cmd.ErrOrStderr())
|
||||
continue
|
||||
}
|
||||
targets[a] = dir
|
||||
}
|
||||
}
|
||||
|
||||
// Create symlinks (or copies) from agent-specific dirs to canonical.
|
||||
for name, skillsDir := range targets {
|
||||
agentSkillDir := filepath.Join(base, skillsDir, skillName)
|
||||
|
||||
if copyMode {
|
||||
copyDest := filepath.Join(agentSkillDir, "SKILL.md")
|
||||
if err := fsutil.EnsureDirForCopy(agentSkillDir); err != nil {
|
||||
return fmt.Errorf("could not prepare %s directory: %w", name, err)
|
||||
}
|
||||
if err := os.MkdirAll(agentSkillDir, 0o755); err != nil {
|
||||
return fmt.Errorf("could not create %s directory: %w", name, err)
|
||||
}
|
||||
if err := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); err != nil {
|
||||
return fmt.Errorf("could not write %s skill file: %w", name, err)
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s\n", copyDest)
|
||||
continue
|
||||
}
|
||||
|
||||
// Compute relative symlink target. Symlinks resolve relative to
|
||||
// the parent directory of the link, not the link itself.
|
||||
rel, err := filepath.Rel(filepath.Dir(agentSkillDir), canonicalSkillDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("could not compute relative path for %s: %w", name, err)
|
||||
}
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(agentSkillDir), 0o755); err != nil {
|
||||
return fmt.Errorf("could not create %s directory: %w", name, err)
|
||||
}
|
||||
|
||||
// Remove existing symlink/dir before creating.
|
||||
_ = os.Remove(agentSkillDir)
|
||||
|
||||
if err := os.Symlink(rel, agentSkillDir); err != nil {
|
||||
// Fall back to copy if symlink fails (e.g. Windows without dev mode).
|
||||
copyDest := filepath.Join(agentSkillDir, "SKILL.md")
|
||||
if mkErr := os.MkdirAll(agentSkillDir, 0o755); mkErr != nil {
|
||||
return fmt.Errorf("could not create %s directory: %w", name, mkErr)
|
||||
}
|
||||
if wErr := os.WriteFile(copyDest, []byte(embedded.SkillMD), 0o644); wErr != nil {
|
||||
return fmt.Errorf("could not write %s skill file: %w", name, wErr)
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Copied %s (symlink failed)\n", copyDest)
|
||||
continue
|
||||
}
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Linked %s -> %s\n", agentSkillDir, rel)
|
||||
}
|
||||
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().BoolVarP(&global, "global", "g", false, "Install to home directory instead of project")
|
||||
cmd.Flags().BoolVar(©Mode, "copy", false, "Copy files instead of symlinking")
|
||||
cmd.Flags().StringSliceVarP(&agents, "agent", "a", nil, "Target specific agents (e.g. claude-code)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func installBase(global bool) (string, error) {
|
||||
if global {
|
||||
home, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine home directory: %w", err)
|
||||
}
|
||||
return home, nil
|
||||
}
|
||||
cwd, err := os.Getwd()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("could not determine working directory: %w", err)
|
||||
}
|
||||
return cwd, nil
|
||||
}
|
||||
|
||||
@@ -97,8 +97,6 @@ func Execute() error {
|
||||
rootCmd.AddCommand(newConfigureCmd())
|
||||
rootCmd.AddCommand(newValidateConfigCmd())
|
||||
rootCmd.AddCommand(newServeCmd())
|
||||
rootCmd.AddCommand(newInstallSkillCmd())
|
||||
rootCmd.AddCommand(newExperimentsCmd())
|
||||
|
||||
// Default command is chat, but intercept --version first
|
||||
rootCmd.RunE = func(cmd *cobra.Command, args []string) error {
|
||||
|
||||
@@ -23,7 +23,6 @@ import (
|
||||
"github.com/charmbracelet/wish/ratelimiter"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/time/rate"
|
||||
@@ -296,15 +295,15 @@ provided via the ONYX_API_KEY environment variable to skip the prompt:
|
||||
The server URL is taken from the server operator's config. The server
|
||||
auto-generates an Ed25519 host key on first run if the key file does not
|
||||
already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
|
||||
environment variable (the --host-key flag takes precedence).`,
|
||||
Example: ` onyx-cli serve --port 2222
|
||||
ssh localhost -p 2222
|
||||
onyx-cli serve --host 0.0.0.0 --port 2222
|
||||
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
|
||||
environment variable (the --host-key flag takes precedence).
|
||||
|
||||
Example:
|
||||
onyx-cli serve --port 2222
|
||||
ssh localhost -p 2222`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
serverCfg := config.Load()
|
||||
if serverCfg.ServerURL == "" {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
|
||||
return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
|
||||
}
|
||||
if !cmd.Flags().Changed("host-key") {
|
||||
if v := os.Getenv(config.EnvSSHHostKey); v != "" {
|
||||
|
||||
@@ -2,13 +2,11 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -18,21 +16,17 @@ func newValidateConfigCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "validate-config",
|
||||
Short: "Validate configuration and test server connection",
|
||||
Long: `Check that the CLI is configured, the server is reachable, and the API key
|
||||
is valid. Also reports the server version and warns if it is below the
|
||||
minimum required.`,
|
||||
Example: ` onyx-cli validate-config`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check config file
|
||||
if !config.ConfigExists() {
|
||||
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
|
||||
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
|
||||
// Check API key
|
||||
if !cfg.IsConfigured() {
|
||||
return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
|
||||
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
|
||||
@@ -41,11 +35,7 @@ minimum required.`,
|
||||
// Test connection
|
||||
client := api.NewClient(cfg)
|
||||
if err := client.TestConnection(cmd.Context()); err != nil {
|
||||
var authErr *api.AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Reconfigure with: onyx-cli configure", err)
|
||||
}
|
||||
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Reconfigure with: onyx-cli configure", err)
|
||||
return fmt.Errorf("connection failed: %w", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
|
||||
|
||||
@@ -149,12 +149,12 @@ func (c *Client) TestConnection(ctx context.Context) error {
|
||||
|
||||
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
|
||||
if isHTML || strings.Contains(respServer, "awselb") {
|
||||
return &AuthError{Message: fmt.Sprintf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)}
|
||||
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
|
||||
}
|
||||
if resp2.StatusCode == 401 {
|
||||
return &AuthError{Message: fmt.Sprintf("invalid API key or token.\n %s", body)}
|
||||
return fmt.Errorf("invalid API key or token.\n %s", body)
|
||||
}
|
||||
return &AuthError{Message: fmt.Sprintf("access denied — check that the API key is valid.\n %s", body)}
|
||||
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
|
||||
}
|
||||
|
||||
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
|
||||
|
||||
@@ -11,12 +11,3 @@ type OnyxAPIError struct {
|
||||
func (e *OnyxAPIError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
|
||||
}
|
||||
|
||||
// AuthError is returned when authentication or authorization fails.
|
||||
type AuthError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
@@ -9,47 +9,28 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
EnvServerURL = "ONYX_SERVER_URL"
|
||||
EnvAPIKey = "ONYX_API_KEY"
|
||||
EnvAgentID = "ONYX_PERSONA_ID"
|
||||
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
|
||||
EnvStreamMarkdown = "ONYX_STREAM_MARKDOWN"
|
||||
EnvServerURL = "ONYX_SERVER_URL"
|
||||
EnvAPIKey = "ONYX_API_KEY"
|
||||
EnvAgentID = "ONYX_PERSONA_ID"
|
||||
EnvSSHHostKey = "ONYX_SSH_HOST_KEY"
|
||||
)
|
||||
|
||||
// Features holds experimental feature flags for the CLI.
|
||||
type Features struct {
|
||||
// StreamMarkdown enables progressive markdown rendering during streaming,
|
||||
// so output is formatted as it arrives rather than after completion.
|
||||
// nil means use the app default (true).
|
||||
StreamMarkdown *bool `json:"stream_markdown,omitempty"`
|
||||
}
|
||||
|
||||
// OnyxCliConfig holds the CLI configuration.
|
||||
type OnyxCliConfig struct {
|
||||
ServerURL string `json:"server_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
DefaultAgentID int `json:"default_persona_id"`
|
||||
Features Features `json:"features,omitempty"`
|
||||
ServerURL string `json:"server_url"`
|
||||
APIKey string `json:"api_key"`
|
||||
DefaultAgentID int `json:"default_persona_id"`
|
||||
}
|
||||
|
||||
// DefaultConfig returns a config with default values.
|
||||
func DefaultConfig() OnyxCliConfig {
|
||||
return OnyxCliConfig{
|
||||
ServerURL: "https://cloud.onyx.app",
|
||||
APIKey: "",
|
||||
ServerURL: "https://cloud.onyx.app",
|
||||
APIKey: "",
|
||||
DefaultAgentID: 0,
|
||||
}
|
||||
}
|
||||
|
||||
// StreamMarkdownEnabled returns whether stream markdown is enabled,
|
||||
// defaulting to true when the user hasn't set an explicit preference.
|
||||
func (f Features) StreamMarkdownEnabled() bool {
|
||||
if f.StreamMarkdown != nil {
|
||||
return *f.StreamMarkdown
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// IsConfigured returns true if the config has an API key.
|
||||
func (c OnyxCliConfig) IsConfigured() bool {
|
||||
return c.APIKey != ""
|
||||
@@ -78,10 +59,8 @@ func ConfigExists() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// LoadFromDisk reads config from the file only, without applying environment
|
||||
// variable overrides. Use this when you need the persisted config values
|
||||
// (e.g., to preserve them during a save operation).
|
||||
func LoadFromDisk() OnyxCliConfig {
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
data, err := os.ReadFile(ConfigFilePath())
|
||||
@@ -91,13 +70,6 @@ func LoadFromDisk() OnyxCliConfig {
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
cfg := LoadFromDisk()
|
||||
|
||||
// Environment overrides
|
||||
if v := os.Getenv(EnvServerURL); v != "" {
|
||||
cfg.ServerURL = v
|
||||
@@ -110,13 +82,6 @@ func Load() OnyxCliConfig {
|
||||
cfg.DefaultAgentID = id
|
||||
}
|
||||
}
|
||||
if v := os.Getenv(EnvStreamMarkdown); v != "" {
|
||||
if b, err := strconv.ParseBool(v); err == nil {
|
||||
cfg.Features.StreamMarkdown = &b
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "warning: invalid value %q for %s (expected true/false), ignoring\n", v, EnvStreamMarkdown)
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
@@ -9,7 +9,7 @@ import (
|
||||
|
||||
func clearEnvVars(t *testing.T) {
|
||||
t.Helper()
|
||||
for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID, EnvStreamMarkdown} {
|
||||
for _, key := range []string{EnvServerURL, EnvAPIKey, EnvAgentID} {
|
||||
t.Setenv(key, "")
|
||||
if err := os.Unsetenv(key); err != nil {
|
||||
t.Fatal(err)
|
||||
@@ -199,48 +199,6 @@ func TestSaveAndReload(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestDefaultFeaturesStreamMarkdownNil(t *testing.T) {
|
||||
cfg := DefaultConfig()
|
||||
if cfg.Features.StreamMarkdown != nil {
|
||||
t.Error("expected StreamMarkdown to be nil by default")
|
||||
}
|
||||
if !cfg.Features.StreamMarkdownEnabled() {
|
||||
t.Error("expected StreamMarkdownEnabled() to return true when nil")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnvOverrideStreamMarkdownFalse(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
t.Setenv(EnvStreamMarkdown, "false")
|
||||
|
||||
cfg := Load()
|
||||
if cfg.Features.StreamMarkdown == nil || *cfg.Features.StreamMarkdown {
|
||||
t.Error("expected StreamMarkdown=false from env override")
|
||||
}
|
||||
}
|
||||
|
||||
func TestLoadFeaturesFromFile(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
t.Setenv("XDG_CONFIG_HOME", dir)
|
||||
|
||||
data, _ := json.Marshal(map[string]interface{}{
|
||||
"server_url": "https://example.com",
|
||||
"api_key": "key",
|
||||
"features": map[string]interface{}{
|
||||
"stream_markdown": true,
|
||||
},
|
||||
})
|
||||
writeConfig(t, dir, data)
|
||||
|
||||
cfg := Load()
|
||||
if cfg.Features.StreamMarkdown == nil || !*cfg.Features.StreamMarkdown {
|
||||
t.Error("expected StreamMarkdown=true from config file")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSaveCreatesParentDirs(t *testing.T) {
|
||||
clearEnvVars(t)
|
||||
dir := t.TempDir()
|
||||
|
||||
@@ -1,46 +0,0 @@
|
||||
package config
|
||||
|
||||
import "fmt"
|
||||
|
||||
// Experiment describes an experimental feature flag.
|
||||
type Experiment struct {
|
||||
Name string
|
||||
Flag string // CLI flag name
|
||||
EnvVar string // environment variable name
|
||||
Config string // JSON path in config file
|
||||
Enabled bool
|
||||
Desc string
|
||||
}
|
||||
|
||||
// Experiments returns the list of available experimental features
|
||||
// with their current status based on the given feature flags.
|
||||
func Experiments(f Features) []Experiment {
|
||||
return []Experiment{
|
||||
{
|
||||
Name: "Stream Markdown",
|
||||
Flag: "--no-stream-markdown",
|
||||
EnvVar: EnvStreamMarkdown,
|
||||
Config: "features.stream_markdown",
|
||||
Enabled: f.StreamMarkdownEnabled(),
|
||||
Desc: "Render markdown progressively as the response streams in (enabled by default)",
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ExperimentsText formats the experiments list for display.
|
||||
func ExperimentsText(f Features) string {
|
||||
exps := Experiments(f)
|
||||
text := "Experimental Features\n\n"
|
||||
for _, e := range exps {
|
||||
status := "off"
|
||||
if e.Enabled {
|
||||
status = "on"
|
||||
}
|
||||
text += fmt.Sprintf(" %-20s [%s]\n", e.Name, status)
|
||||
text += fmt.Sprintf(" %s\n", e.Desc)
|
||||
text += fmt.Sprintf(" flag: %s env: %s config: %s\n\n", e.Flag, e.EnvVar, e.Config)
|
||||
}
|
||||
text += "Toggle via CLI flag, environment variable, or config file.\n"
|
||||
text += "Example: onyx-cli chat --no-stream-markdown"
|
||||
return text
|
||||
}
|
||||
@@ -1,187 +0,0 @@
|
||||
---
|
||||
name: onyx-cli
|
||||
description: Query the Onyx knowledge base using the onyx-cli command. Use when the user wants to search company documents, ask questions about internal knowledge, query connected data sources, or look up information stored in Onyx.
|
||||
---
|
||||
|
||||
# Onyx CLI — Agent Tool
|
||||
|
||||
Onyx is an enterprise search and Gen-AI platform that connects to company documents, apps, and people. The `onyx-cli` CLI provides non-interactive commands to query the Onyx knowledge base and list available agents.
|
||||
|
||||
## Prerequisites
|
||||
|
||||
### 1. Check if installed
|
||||
|
||||
```bash
|
||||
which onyx-cli
|
||||
```
|
||||
|
||||
### 2. Install (if needed)
|
||||
|
||||
**Primary — pip:**
|
||||
|
||||
```bash
|
||||
pip install onyx-cli
|
||||
```
|
||||
|
||||
**From source (Go):**
|
||||
|
||||
```bash
|
||||
go build -o onyx-cli github.com/onyx-dot-app/onyx/cli && sudo mv onyx-cli /usr/local/bin/
|
||||
```
|
||||
|
||||
### 3. Check if configured
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
This checks the config file exists, API key is present, and tests the server connection via `/api/me`. Exit code 0 on success, non-zero with a descriptive error on failure.
|
||||
|
||||
If unconfigured, you have two options:
|
||||
|
||||
**Option A — Interactive setup (requires user input):**
|
||||
|
||||
```bash
|
||||
onyx-cli configure
|
||||
```
|
||||
|
||||
This prompts for the Onyx server URL and API key, tests the connection, and saves config.
|
||||
|
||||
**Option B — Environment variables (non-interactive, preferred for agents):**
|
||||
|
||||
```bash
|
||||
export ONYX_SERVER_URL="https://your-onyx-server.com" # default: https://cloud.onyx.app
|
||||
export ONYX_API_KEY="your-api-key"
|
||||
```
|
||||
|
||||
Environment variables override the config file. If these are set, no config file is needed.
|
||||
|
||||
| Variable | Required | Description |
|
||||
| ----------------- | -------- | -------------------------------------------------------- |
|
||||
| `ONYX_SERVER_URL` | No | Onyx server base URL (default: `https://cloud.onyx.app`) |
|
||||
| `ONYX_API_KEY` | Yes | API key for authentication |
|
||||
| `ONYX_PERSONA_ID` | No | Default agent/persona ID |
|
||||
|
||||
If neither the config file nor environment variables are set, tell the user that `onyx-cli` needs to be configured and ask them to either:
|
||||
|
||||
- Run `onyx-cli configure` interactively, or
|
||||
- Set `ONYX_SERVER_URL` and `ONYX_API_KEY` environment variables
|
||||
|
||||
## Commands
|
||||
|
||||
### Validate configuration
|
||||
|
||||
```bash
|
||||
onyx-cli validate-config
|
||||
```
|
||||
|
||||
Checks config file exists, API key is present, and tests the server connection. Use this before `ask` or `agents` to confirm the CLI is properly set up.
|
||||
|
||||
### List available agents
|
||||
|
||||
```bash
|
||||
onyx-cli agents
|
||||
```
|
||||
|
||||
Prints a table of agent IDs, names, and descriptions. Use `--json` for structured output:
|
||||
|
||||
```bash
|
||||
onyx-cli agents --json
|
||||
```
|
||||
|
||||
Use agent IDs with `ask --agent-id` to query a specific agent.
|
||||
|
||||
### Basic query (plain text output)
|
||||
|
||||
```bash
|
||||
onyx-cli ask "What is our company's PTO policy?"
|
||||
```
|
||||
|
||||
Streams the answer as plain text to stdout. Exit code 0 on success, non-zero on error.
|
||||
|
||||
### JSON output (structured events)
|
||||
|
||||
```bash
|
||||
onyx-cli ask --json "What authentication methods do we support?"
|
||||
```
|
||||
|
||||
Outputs JSON-encoded parsed stream events (one object per line). Key event objects include message deltas, stop, errors, search-start, and citation payloads.
|
||||
|
||||
Each line is a JSON object with this envelope:
|
||||
|
||||
```json
|
||||
{"type": "<event_type>", "event": { ... }}
|
||||
```
|
||||
|
||||
| Event Type | Description |
|
||||
| ------------------- | -------------------------------------------------------------------- |
|
||||
| `message_delta` | Content token — concatenate all `content` fields for the full answer |
|
||||
| `stop` | Stream complete |
|
||||
| `error` | Error with `error` message field |
|
||||
| `search_tool_start` | Onyx started searching documents |
|
||||
| `citation_info` | Source citation — see shape below |
|
||||
|
||||
`citation_info` event shape:
|
||||
|
||||
```json
|
||||
{
|
||||
"type": "citation_info",
|
||||
"event": {
|
||||
"citation_number": 1,
|
||||
"document_id": "abc123def456",
|
||||
"placement": { "turn_index": 0, "tab_index": 0, "sub_turn_index": null }
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
`placement` is metadata about where in the conversation the citation appeared and can be ignored for most use cases.
|
||||
|
||||
### Specify an agent
|
||||
|
||||
```bash
|
||||
onyx-cli ask --agent-id 5 "Summarize our Q4 roadmap"
|
||||
```
|
||||
|
||||
Uses a specific Onyx agent/persona instead of the default.
|
||||
|
||||
### All flags
|
||||
|
||||
| Flag | Type | Description |
|
||||
| ------------ | ---- | ---------------------------------------------- |
|
||||
| `--agent-id` | int | Agent ID to use (overrides default) |
|
||||
| `--json` | bool | Output raw NDJSON events instead of plain text |
|
||||
|
||||
## Statelessness
|
||||
|
||||
Each `onyx-cli ask` call creates an independent chat session. There is no built-in way to chain context across multiple `ask` invocations — every call starts fresh. If you need multi-turn conversation with memory, use the interactive TUI (`onyx-cli` or `onyx-cli chat`) instead.
|
||||
|
||||
## When to Use
|
||||
|
||||
Use `onyx-cli ask` when:
|
||||
|
||||
- The user asks about company-specific information (policies, docs, processes)
|
||||
- You need to search internal knowledge bases or connected data sources
|
||||
- The user references Onyx, asks you to "search Onyx", or wants to query their documents
|
||||
- You need context from company wikis, Confluence, Google Drive, Slack, or other connected sources
|
||||
|
||||
Do NOT use when:
|
||||
|
||||
- The question is about general programming knowledge (use your own knowledge)
|
||||
- The user is asking about code in the current repository (use grep/read tools)
|
||||
- The user hasn't mentioned Onyx and the question doesn't require internal company data
|
||||
|
||||
## Examples
|
||||
|
||||
```bash
|
||||
# Simple question
|
||||
onyx-cli ask "What are the steps to deploy to production?"
|
||||
|
||||
# Get structured output for parsing
|
||||
onyx-cli ask --json "List all active API integrations"
|
||||
|
||||
# Use a specialized agent
|
||||
onyx-cli ask --agent-id 3 "What were the action items from last week's standup?"
|
||||
|
||||
# Pipe the answer into another command
|
||||
onyx-cli ask "What is the database schema for users?" | head -20
|
||||
```
|
||||
@@ -1,7 +0,0 @@
|
||||
// Package embedded holds files that are compiled into the onyx-cli binary.
|
||||
package embedded
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed SKILL.md
|
||||
var SkillMD string
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user