mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-04-03 14:02:42 +00:00
Compare commits
28 Commits
multi-mode
...
cli-agent-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2c28b57992 | ||
|
|
e424d34c7d | ||
|
|
411dc8e86c | ||
|
|
026298aed7 | ||
|
|
d0c5c6dc66 | ||
|
|
85c5507fc0 | ||
|
|
00e7fe2280 | ||
|
|
f5970f8f7f | ||
|
|
0eaab180dd | ||
|
|
9399cc7548 | ||
|
|
9c3a85d1fc | ||
|
|
2faa475c83 | ||
|
|
52d926f002 | ||
|
|
718227a336 | ||
|
|
73cd88a708 | ||
|
|
b08f50fa53 | ||
|
|
ea8366aa69 | ||
|
|
e6f7c2b45c | ||
|
|
f77128d929 | ||
|
|
1d4ca769e7 | ||
|
|
e002f6c195 | ||
|
|
10d696262f | ||
|
|
608e151443 | ||
|
|
41d1a33093 | ||
|
|
f396ebbdbb | ||
|
|
67c8df002e | ||
|
|
722f7de335 | ||
|
|
df14bbe0e2 |
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
108
backend/alembic/versions/03d085c5c38d_backfill_account_type.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""backfill_account_type
|
||||
|
||||
Revision ID: 03d085c5c38d
|
||||
Revises: 977e834c1427
|
||||
Create Date: 2026-03-25 16:00:00.000000
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "03d085c5c38d"
|
||||
down_revision = "977e834c1427"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
_STANDARD = "STANDARD"
|
||||
_BOT = "BOT"
|
||||
_EXT_PERM_USER = "EXT_PERM_USER"
|
||||
_SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
_ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
# Well-known anonymous user UUID
|
||||
ANONYMOUS_USER_ID = "00000000-0000-0000-0000-000000000002"
|
||||
|
||||
# Email pattern for API key virtual users
|
||||
API_KEY_EMAIL_PATTERN = r"API\_KEY\_\_%"
|
||||
|
||||
# Reflect the table structure for use in DML
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("email", sa.String),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Backfill account_type from role.
|
||||
# Order matters — most-specific matches first so the final catch-all
|
||||
# only touches rows that haven't been classified yet.
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
# 1a. API key virtual users → SERVICE_ACCOUNT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.email.ilike(API_KEY_EMAIL_PATTERN),
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_SERVICE_ACCOUNT)
|
||||
)
|
||||
|
||||
# 1b. Anonymous user → ANONYMOUS
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.id == ANONYMOUS_USER_ID,
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_ANONYMOUS)
|
||||
)
|
||||
|
||||
# 1c. SLACK_USER role → BOT
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "SLACK_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_BOT)
|
||||
)
|
||||
|
||||
# 1d. EXT_PERM_USER role → EXT_PERM_USER
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(
|
||||
user_table.c.role == "EXT_PERM_USER",
|
||||
user_table.c.account_type.is_(None),
|
||||
)
|
||||
.values(account_type=_EXT_PERM_USER)
|
||||
)
|
||||
|
||||
# 1e. Everything else → STANDARD
|
||||
op.execute(
|
||||
sa.update(user_table)
|
||||
.where(user_table.c.account_type.is_(None))
|
||||
.values(account_type=_STANDARD)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Set account_type to NOT NULL now that every row is filled.
|
||||
# ------------------------------------------------------------------
|
||||
op.alter_column(
|
||||
"user",
|
||||
"account_type",
|
||||
nullable=False,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.alter_column("user", "account_type", nullable=True, server_default=None)
|
||||
op.execute(sa.update(user_table).values(account_type=None))
|
||||
@@ -0,0 +1,104 @@
|
||||
"""add_effective_permissions
|
||||
|
||||
Adds a JSONB column `effective_permissions` to the user table to store
|
||||
directly granted permissions (e.g. ["admin"] or ["basic"]). Implied
|
||||
permissions are expanded at read time, not stored.
|
||||
|
||||
Backfill: joins user__user_group → permission_grant to collect each
|
||||
user's granted permissions into a JSON array. Users without group
|
||||
memberships keep the default [].
|
||||
|
||||
Revision ID: 503883791c39
|
||||
Revises: b4b7e1028dfd
|
||||
Create Date: 2026-03-30 14:49:22.261748
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "503883791c39"
|
||||
down_revision = "b4b7e1028dfd"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("effective_permissions", postgresql.JSONB),
|
||||
)
|
||||
|
||||
user_user_group = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_id", sa.Uuid),
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column(
|
||||
"effective_permissions",
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
server_default=sa.text("'[]'::jsonb"),
|
||||
),
|
||||
)
|
||||
|
||||
conn = op.get_bind()
|
||||
|
||||
# Deduplicated permissions per user
|
||||
deduped = (
|
||||
sa.select(
|
||||
user_user_group.c.user_id,
|
||||
permission_grant.c.permission,
|
||||
)
|
||||
.select_from(
|
||||
user_user_group.join(
|
||||
permission_grant,
|
||||
sa.and_(
|
||||
permission_grant.c.group_id == user_user_group.c.user_group_id,
|
||||
permission_grant.c.is_deleted == sa.false(),
|
||||
),
|
||||
)
|
||||
)
|
||||
.distinct()
|
||||
.subquery("deduped")
|
||||
)
|
||||
|
||||
# Aggregate into JSONB array per user (order is not guaranteed;
|
||||
# consumers read this as a set so ordering does not matter)
|
||||
perms_per_user = (
|
||||
sa.select(
|
||||
deduped.c.user_id,
|
||||
sa.func.jsonb_agg(
|
||||
deduped.c.permission,
|
||||
type_=postgresql.JSONB,
|
||||
).label("perms"),
|
||||
)
|
||||
.group_by(deduped.c.user_id)
|
||||
.subquery("sub")
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
user_table.update()
|
||||
.where(user_table.c.id == perms_per_user.c.user_id)
|
||||
.values(effective_permissions=perms_per_user.c.perms)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "effective_permissions")
|
||||
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
139
backend/alembic/versions/977e834c1427_seed_default_groups.py
Normal file
@@ -0,0 +1,139 @@
|
||||
"""seed_default_groups
|
||||
|
||||
Revision ID: 977e834c1427
|
||||
Revises: 8188861f4e92
|
||||
Create Date: 2026-03-25 14:59:41.313091
|
||||
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "977e834c1427"
|
||||
down_revision = "8188861f4e92"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# (group_name, permission_value)
|
||||
DEFAULT_GROUPS = [
|
||||
("Admin", "admin"),
|
||||
("Basic", "basic"),
|
||||
]
|
||||
|
||||
CUSTOM_SUFFIX = "(Custom)"
|
||||
|
||||
MAX_RENAME_ATTEMPTS = 100
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_up_to_date", sa.Boolean),
|
||||
sa.column("is_up_for_deletion", sa.Boolean),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant_table = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def _find_available_name(conn: sa.engine.Connection, base: str) -> str:
|
||||
"""Return a name like 'Admin (Custom)' or 'Admin (Custom 2)' that is not taken."""
|
||||
candidate = f"{base} {CUSTOM_SUFFIX}"
|
||||
attempt = 1
|
||||
while attempt <= MAX_RENAME_ATTEMPTS:
|
||||
exists: Any = conn.execute(
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(user_group_table)
|
||||
.where(user_group_table.c.name == candidate)
|
||||
.limit(1)
|
||||
).fetchone()
|
||||
if exists is None:
|
||||
return candidate
|
||||
attempt += 1
|
||||
candidate = f"{base} (Custom {attempt})"
|
||||
raise RuntimeError(
|
||||
f"Could not find an available name for group '{base}' "
|
||||
f"after {MAX_RENAME_ATTEMPTS} attempts"
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
for group_name, permission_value in DEFAULT_GROUPS:
|
||||
# Step 1: Rename ALL existing groups that clash with the canonical name.
|
||||
conflicting = conn.execute(
|
||||
sa.select(user_group_table.c.id, user_group_table.c.name).where(
|
||||
user_group_table.c.name == group_name
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
for row_id, row_name in conflicting:
|
||||
new_name = _find_available_name(conn, row_name)
|
||||
op.execute(
|
||||
sa.update(user_group_table)
|
||||
.where(user_group_table.c.id == row_id)
|
||||
.values(name=new_name, is_up_to_date=False)
|
||||
)
|
||||
|
||||
# Step 2: Create a fresh default group.
|
||||
result = conn.execute(
|
||||
user_group_table.insert()
|
||||
.values(
|
||||
name=group_name,
|
||||
is_up_to_date=True,
|
||||
is_up_for_deletion=False,
|
||||
is_default=True,
|
||||
)
|
||||
.returning(user_group_table.c.id)
|
||||
).fetchone()
|
||||
assert result is not None
|
||||
group_id = result[0]
|
||||
|
||||
# Step 3: Upsert permission grant.
|
||||
op.execute(
|
||||
pg_insert(permission_grant_table)
|
||||
.values(
|
||||
group_id=group_id,
|
||||
permission=permission_value,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=["group_id", "permission"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Remove the default groups created by this migration.
|
||||
# First remove user-group memberships that reference default groups
|
||||
# to avoid FK violations, then delete the groups themselves.
|
||||
default_group_ids = sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
conn = op.get_bind()
|
||||
conn.execute(
|
||||
sa.delete(user__user_group_table).where(
|
||||
user__user_group_table.c.user_group_id.in_(default_group_ids)
|
||||
)
|
||||
)
|
||||
conn.execute(
|
||||
sa.delete(user_group_table).where(
|
||||
user_group_table.c.is_default == True # noqa: E712
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,84 @@
|
||||
"""grant_basic_to_existing_groups
|
||||
|
||||
Grants the "basic" permission to all existing groups that don't already
|
||||
have it. Every group should have at least "basic" so that its members
|
||||
get basic access when effective_permissions is backfilled.
|
||||
|
||||
Revision ID: b4b7e1028dfd
|
||||
Revises: b7bcc991d722
|
||||
Create Date: 2026-03-30 16:15:17.093498
|
||||
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b4b7e1028dfd"
|
||||
down_revision = "b7bcc991d722"
|
||||
branch_labels: str | None = None
|
||||
depends_on: str | Sequence[str] | None = None
|
||||
|
||||
user_group = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
permission_grant = sa.table(
|
||||
"permission_grant",
|
||||
sa.column("group_id", sa.Integer),
|
||||
sa.column("permission", sa.String),
|
||||
sa.column("grant_source", sa.String),
|
||||
sa.column("is_deleted", sa.Boolean),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
already_has_basic = (
|
||||
sa.select(sa.literal(1))
|
||||
.select_from(permission_grant)
|
||||
.where(
|
||||
permission_grant.c.group_id == user_group.c.id,
|
||||
permission_grant.c.permission == "basic",
|
||||
)
|
||||
.exists()
|
||||
)
|
||||
|
||||
groups_needing_basic = sa.select(
|
||||
user_group.c.id,
|
||||
sa.literal("basic").label("permission"),
|
||||
sa.literal("SYSTEM").label("grant_source"),
|
||||
sa.literal(False).label("is_deleted"),
|
||||
).where(
|
||||
user_group.c.is_default == sa.false(),
|
||||
~already_has_basic,
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.insert().from_select(
|
||||
["group_id", "permission", "grant_source", "is_deleted"],
|
||||
groups_needing_basic,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
non_default_group_ids = sa.select(user_group.c.id).where(
|
||||
user_group.c.is_default == sa.false()
|
||||
)
|
||||
|
||||
conn.execute(
|
||||
permission_grant.delete().where(
|
||||
permission_grant.c.permission == "basic",
|
||||
permission_grant.c.grant_source == "SYSTEM",
|
||||
permission_grant.c.group_id.in_(non_default_group_ids),
|
||||
)
|
||||
)
|
||||
@@ -0,0 +1,125 @@
|
||||
"""assign_users_to_default_groups
|
||||
|
||||
Revision ID: b7bcc991d722
|
||||
Revises: 03d085c5c38d
|
||||
Create Date: 2026-03-25 16:30:39.529301
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import insert as pg_insert
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "b7bcc991d722"
|
||||
down_revision = "03d085c5c38d"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
# The no-auth placeholder user must NOT be assigned to default groups.
|
||||
# A database trigger (migrate_no_auth_data_to_user) will try to DELETE this
|
||||
# user when the first real user registers; group membership rows would cause
|
||||
# an FK violation on that DELETE.
|
||||
NO_AUTH_PLACEHOLDER_USER_UUID = "00000000-0000-0000-0000-000000000001"
|
||||
|
||||
# Reflect table structures for use in DML
|
||||
user_group_table = sa.table(
|
||||
"user_group",
|
||||
sa.column("id", sa.Integer),
|
||||
sa.column("name", sa.String),
|
||||
sa.column("is_default", sa.Boolean),
|
||||
)
|
||||
|
||||
user_table = sa.table(
|
||||
"user",
|
||||
sa.column("id", sa.Uuid),
|
||||
sa.column("role", sa.String),
|
||||
sa.column("account_type", sa.String),
|
||||
sa.column("is_active", sa.Boolean),
|
||||
)
|
||||
|
||||
user__user_group_table = sa.table(
|
||||
"user__user_group",
|
||||
sa.column("user_group_id", sa.Integer),
|
||||
sa.column("user_id", sa.Uuid),
|
||||
)
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
conn = op.get_bind()
|
||||
|
||||
# Look up default group IDs
|
||||
admin_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Admin",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
basic_row = conn.execute(
|
||||
sa.select(user_group_table.c.id).where(
|
||||
user_group_table.c.name == "Basic",
|
||||
user_group_table.c.is_default == True, # noqa: E712
|
||||
)
|
||||
).fetchone()
|
||||
|
||||
if admin_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Admin' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
if basic_row is None:
|
||||
raise RuntimeError(
|
||||
"Default 'Basic' group not found. "
|
||||
"Ensure migration 977e834c1427 (seed_default_groups) ran successfully."
|
||||
)
|
||||
|
||||
# Users with role=admin → Admin group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
# Exclude non-human account types (mirrors assign_user_to_default_groups logic).
|
||||
admin_users = sa.select(
|
||||
sa.literal(admin_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.role == "ADMIN",
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], admin_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
# STANDARD users (non-admin) and SERVICE_ACCOUNT users (role=basic) → Basic group
|
||||
# Include inactive users so reactivation doesn't require reconciliation.
|
||||
basic_users = sa.select(
|
||||
sa.literal(basic_row[0]).label("user_group_id"),
|
||||
user_table.c.id.label("user_id"),
|
||||
).where(
|
||||
user_table.c.account_type.notin_(["BOT", "EXT_PERM_USER", "ANONYMOUS"]),
|
||||
user_table.c.id != NO_AUTH_PLACEHOLDER_USER_UUID,
|
||||
sa.or_(
|
||||
sa.and_(
|
||||
user_table.c.account_type == "STANDARD",
|
||||
user_table.c.role != "ADMIN",
|
||||
),
|
||||
sa.and_(
|
||||
user_table.c.account_type == "SERVICE_ACCOUNT",
|
||||
user_table.c.role == "BASIC",
|
||||
),
|
||||
),
|
||||
)
|
||||
op.execute(
|
||||
pg_insert(user__user_group_table)
|
||||
.from_select(["user_group_id", "user_id"], basic_users)
|
||||
.on_conflict_do_nothing(index_elements=["user_group_id", "user_id"])
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Group memberships are left in place — removing them risks
|
||||
# deleting memberships that existed before this migration.
|
||||
pass
|
||||
@@ -36,13 +36,16 @@ from ee.onyx.server.scim.filtering import ScimFilter
|
||||
from ee.onyx.server.scim.filtering import ScimFilterOperator
|
||||
from ee.onyx.server.scim.models import ScimMappingFields
|
||||
from onyx.db.dal import DAL
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import ScimGroupMapping
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -280,7 +283,9 @@ class ScimDAL(DAL):
|
||||
query = (
|
||||
select(User)
|
||||
.join(ScimUserMapping, ScimUserMapping.user_id == User.id)
|
||||
.where(User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]))
|
||||
.where(
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER])
|
||||
)
|
||||
)
|
||||
|
||||
if scim_filter:
|
||||
@@ -521,6 +526,22 @@ class ScimDAL(DAL):
|
||||
self._session.add(group)
|
||||
self._session.flush()
|
||||
|
||||
def add_permission_grant_to_group(
|
||||
self,
|
||||
group_id: int,
|
||||
permission: Permission,
|
||||
grant_source: GrantSource,
|
||||
) -> None:
|
||||
"""Grant a permission to a group and flush."""
|
||||
self._session.add(
|
||||
PermissionGrant(
|
||||
group_id=group_id,
|
||||
permission=permission,
|
||||
grant_source=grant_source,
|
||||
)
|
||||
)
|
||||
self._session.flush()
|
||||
|
||||
def update_group(
|
||||
self,
|
||||
group: UserGroup,
|
||||
|
||||
@@ -19,6 +19,8 @@ from onyx.configs.app_configs import DISABLE_VECTOR_DB
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
@@ -28,6 +30,7 @@ from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
@@ -36,6 +39,7 @@ from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserGroup__ConnectorCredentialPair
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import fetch_user_by_id
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -255,6 +259,7 @@ def fetch_user_groups(
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -269,6 +274,7 @@ def fetch_user_groups(
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
include_default: If False, excludes system default groups (is_default=True).
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -276,6 +282,8 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -286,6 +294,7 @@ def fetch_user_groups_for_user(
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
include_default: bool = True,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -295,6 +304,8 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
@@ -478,6 +489,16 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
db_session.add(db_user_group)
|
||||
db_session.flush() # give the group an ID
|
||||
|
||||
# Every group gets the "basic" permission by default
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_user_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
|
||||
_add_user__user_group_relationships__no_commit(
|
||||
db_session=db_session,
|
||||
user_group_id=db_user_group.id,
|
||||
@@ -489,6 +510,8 @@ def insert_user_group(db_session: Session, user_group: UserGroupCreate) -> UserG
|
||||
cc_pair_ids=user_group.cc_pair_ids,
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user_group.user_ids, db_session)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -796,6 +819,10 @@ def update_user_group(
|
||||
# update "time_updated" to now
|
||||
db_user_group.time_last_modified_by_user = func.now()
|
||||
|
||||
recompute_user_permissions__no_commit(
|
||||
list(set(added_user_ids) | set(removed_user_ids)), db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
return db_user_group
|
||||
|
||||
@@ -835,6 +862,19 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
|
||||
_check_user_group_is_modifiable(db_user_group)
|
||||
|
||||
# Collect affected user IDs before cleanup deletes the relationships
|
||||
affected_user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == user_group_id
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
_mark_user_group__cc_pair_relationships_outdated__no_commit(
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
@@ -863,6 +903,10 @@ def prepare_user_group_for_deletion(db_session: Session, user_group_id: int) ->
|
||||
db_session=db_session, user_group_id=user_group_id
|
||||
)
|
||||
|
||||
# Recompute permissions for affected users now that their
|
||||
# membership in this group has been removed
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
db_user_group.is_up_to_date = False
|
||||
db_user_group.is_up_for_deletion = True
|
||||
db_session.commit()
|
||||
|
||||
@@ -52,16 +52,25 @@ from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import GrantSource
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Group names reserved for system default groups (seeded by migration).
|
||||
_RESERVED_GROUP_NAMES = frozenset({"Admin", "Basic"})
|
||||
|
||||
|
||||
class ScimJSONResponse(JSONResponse):
|
||||
"""JSONResponse with Content-Type: application/scim+json (RFC 7644 §3.1)."""
|
||||
@@ -486,6 +495,7 @@ def create_user(
|
||||
email=email,
|
||||
hashed_password=_pw_helper.hash(_pw_helper.generate()),
|
||||
role=UserRole.BASIC,
|
||||
account_type=AccountType.STANDARD,
|
||||
is_active=user_resource.active,
|
||||
is_verified=True,
|
||||
personal_name=personal_name,
|
||||
@@ -506,13 +516,25 @@ def create_user(
|
||||
scim_username=scim_username,
|
||||
fields=fields,
|
||||
)
|
||||
dal.commit()
|
||||
except IntegrityError:
|
||||
dal.rollback()
|
||||
return _scim_error_response(
|
||||
409, f"User with email {email} already has a SCIM mapping"
|
||||
)
|
||||
|
||||
# Assign user to default group BEFORE commit so everything is atomic.
|
||||
# If this fails, the entire user creation rolls back and IdP can retry.
|
||||
try:
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
except Exception:
|
||||
dal.rollback()
|
||||
logger.exception(f"Failed to assign SCIM user {email} to default groups")
|
||||
return _scim_error_response(
|
||||
500, f"Failed to assign user {email} to default group"
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _scim_resource_response(
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
@@ -542,7 +564,8 @@ def replace_user(
|
||||
user = result
|
||||
|
||||
# Handle activation (need seat check) / deactivation
|
||||
if user_resource.active and not user.is_active:
|
||||
is_reactivation = user_resource.active and not user.is_active
|
||||
if is_reactivation:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
@@ -556,6 +579,12 @@ def replace_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
scim_username = user_resource.userName.strip()
|
||||
fields = _fields_from_resource(user_resource)
|
||||
@@ -621,6 +650,7 @@ def patch_user(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
# Apply changes back to the DB model
|
||||
is_reactivation = patched.active and not user.is_active
|
||||
if patched.active != user.is_active:
|
||||
if patched.active:
|
||||
seat_error = _check_seat_availability(dal)
|
||||
@@ -649,6 +679,12 @@ def patch_user(
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
# Reconcile default-group membership on reactivation
|
||||
if is_reactivation:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
|
||||
# Build updated fields by merging PATCH enterprise data with current values
|
||||
cf = current_fields or ScimMappingFields()
|
||||
fields = ScimMappingFields(
|
||||
@@ -857,6 +893,11 @@ def create_group(
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
if group_resource.displayName in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
if dal.get_group_by_name(group_resource.displayName):
|
||||
return _scim_error_response(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
@@ -879,8 +920,18 @@ def create_group(
|
||||
409, f"Group with name '{group_resource.displayName}' already exists"
|
||||
)
|
||||
|
||||
# Every group gets the "basic" permission by default.
|
||||
dal.add_permission_grant_to_group(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.BASIC_ACCESS,
|
||||
grant_source=GrantSource.SYSTEM,
|
||||
)
|
||||
|
||||
dal.upsert_group_members(db_group.id, member_uuids)
|
||||
|
||||
# Recompute permissions for initial members.
|
||||
recompute_user_permissions__no_commit(member_uuids, db_session)
|
||||
|
||||
external_id = group_resource.externalId
|
||||
if external_id:
|
||||
dal.create_group_mapping(external_id=external_id, user_group_id=db_group.id)
|
||||
@@ -911,14 +962,36 @@ def replace_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and group_resource.displayName != group.name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if (
|
||||
group_resource.displayName in _RESERVED_GROUP_NAMES
|
||||
and group_resource.displayName != group.name
|
||||
):
|
||||
return _scim_error_response(
|
||||
409, f"'{group_resource.displayName}' is a reserved group name."
|
||||
)
|
||||
|
||||
member_uuids, err = _validate_and_parse_members(group_resource.members, dal)
|
||||
if err:
|
||||
return _scim_error_response(400, err)
|
||||
|
||||
# Capture old member IDs before replacing so we can recompute their
|
||||
# permissions after they are removed from the group.
|
||||
old_member_ids = {uid for uid, _ in dal.get_group_members(group.id)}
|
||||
|
||||
dal.update_group(group, name=group_resource.displayName)
|
||||
dal.replace_group_members(group.id, member_uuids)
|
||||
dal.sync_group_external_id(group.id, group_resource.externalId)
|
||||
|
||||
# Recompute permissions for current members (batch) and removed members.
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
removed_ids = list(old_member_ids - set(member_uuids))
|
||||
recompute_user_permissions__no_commit(removed_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
@@ -961,8 +1034,19 @@ def patch_group(
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
new_name = patched.displayName if patched.displayName != group.name else None
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES and new_name:
|
||||
return _scim_error_response(
|
||||
409, f"'{group.name}' is a reserved group name and cannot be renamed."
|
||||
)
|
||||
|
||||
if new_name and new_name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{new_name}' is a reserved group name.")
|
||||
|
||||
dal.update_group(group, name=new_name)
|
||||
|
||||
affected_uuids: list[UUID] = []
|
||||
|
||||
if added_ids:
|
||||
add_uuids = [UUID(mid) for mid in added_ids if _is_valid_uuid(mid)]
|
||||
if add_uuids:
|
||||
@@ -973,10 +1057,15 @@ def patch_group(
|
||||
f"Member(s) not found: {', '.join(str(u) for u in missing)}",
|
||||
)
|
||||
dal.upsert_group_members(group.id, add_uuids)
|
||||
affected_uuids.extend(add_uuids)
|
||||
|
||||
if removed_ids:
|
||||
remove_uuids = [UUID(mid) for mid in removed_ids if _is_valid_uuid(mid)]
|
||||
dal.remove_group_members(group.id, remove_uuids)
|
||||
affected_uuids.extend(remove_uuids)
|
||||
|
||||
# Recompute permissions for all users whose group membership changed.
|
||||
recompute_user_permissions__no_commit(affected_uuids, db_session)
|
||||
|
||||
dal.sync_group_external_id(group.id, patched.externalId)
|
||||
dal.commit()
|
||||
@@ -1002,11 +1091,21 @@ def delete_group(
|
||||
return result
|
||||
group = result
|
||||
|
||||
if group.name in _RESERVED_GROUP_NAMES:
|
||||
return _scim_error_response(409, f"'{group.name}' is a reserved group name.")
|
||||
|
||||
# Capture member IDs before deletion so we can recompute their permissions.
|
||||
affected_user_ids = [uid for uid, _ in dal.get_group_members(group.id)]
|
||||
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
if mapping:
|
||||
dal.delete_group_mapping(mapping.id)
|
||||
|
||||
dal.delete_group_with_members(group)
|
||||
|
||||
# Recompute permissions for users who lost this group membership.
|
||||
recompute_user_permissions__no_commit(affected_user_ids, db_session)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return Response(status_code=204)
|
||||
|
||||
@@ -43,12 +43,16 @@ router = APIRouter(prefix="/manage", tags=PUBLIC_API_TAGS)
|
||||
|
||||
@router.get("/admin/user-group")
|
||||
def list_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_curator_or_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
@@ -56,27 +60,50 @@ def list_user_groups(
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@router.get("/user-groups/minimal")
|
||||
def list_minimal_user_groups(
|
||||
include_default: bool = False,
|
||||
user: User = Depends(current_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[MinimalUserGroupSnapshot]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session,
|
||||
only_up_to_date=False,
|
||||
include_default=include_default,
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
include_default=include_default,
|
||||
)
|
||||
return [
|
||||
MinimalUserGroupSnapshot.from_model(user_group) for user_group in user_groups
|
||||
]
|
||||
|
||||
|
||||
@router.get("/admin/user-group/{user_group_id}/permissions")
|
||||
def get_user_group_permissions(
|
||||
user_group_id: int,
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[str]:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group is None:
|
||||
raise OnyxError(OnyxErrorCode.NOT_FOUND, "User group not found")
|
||||
return [
|
||||
grant.permission.value
|
||||
for grant in group.permission_grants
|
||||
if not grant.is_deleted
|
||||
]
|
||||
|
||||
|
||||
@router.post("/admin/user-group")
|
||||
def create_user_group(
|
||||
user_group: UserGroupCreate,
|
||||
@@ -100,6 +127,9 @@ def rename_user_group_endpoint(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserGroup:
|
||||
group = fetch_user_group(db_session, rename_request.id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot rename a default system group.")
|
||||
try:
|
||||
return UserGroup.from_model(
|
||||
rename_user_group(
|
||||
@@ -185,6 +215,9 @@ def delete_user_group(
|
||||
_: User = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
group = fetch_user_group(db_session, user_group_id)
|
||||
if group and group.is_default:
|
||||
raise OnyxError(OnyxErrorCode.CONFLICT, "Cannot delete a default system group.")
|
||||
try:
|
||||
prepare_user_group_for_deletion(db_session, user_group_id)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -22,6 +22,7 @@ class UserGroup(BaseModel):
|
||||
personas: list[PersonaSnapshot]
|
||||
is_up_to_date: bool
|
||||
is_up_for_deletion: bool
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "UserGroup":
|
||||
@@ -74,18 +75,21 @@ class UserGroup(BaseModel):
|
||||
],
|
||||
is_up_to_date=user_group_model.is_up_to_date,
|
||||
is_up_for_deletion=user_group_model.is_up_for_deletion,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
class MinimalUserGroupSnapshot(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
is_default: bool
|
||||
|
||||
@classmethod
|
||||
def from_model(cls, user_group_model: UserGroupModel) -> "MinimalUserGroupSnapshot":
|
||||
return cls(
|
||||
id=user_group_model.id,
|
||||
name=user_group_model.name,
|
||||
is_default=user_group_model.is_default,
|
||||
)
|
||||
|
||||
|
||||
|
||||
110
backend/onyx/auth/permissions.py
Normal file
110
backend/onyx/auth/permissions.py
Normal file
@@ -0,0 +1,110 @@
|
||||
"""
|
||||
Permission resolution for group-based authorization.
|
||||
|
||||
Granted permissions are stored as a JSONB column on the User table and
|
||||
loaded for free with every auth query. Implied permissions are expanded
|
||||
at read time — only directly granted permissions are persisted.
|
||||
"""
|
||||
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Coroutine
|
||||
from typing import Any
|
||||
|
||||
from fastapi import Depends
|
||||
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import User
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
ALL_PERMISSIONS: frozenset[str] = frozenset(p.value for p in Permission)
|
||||
|
||||
# Implication map: granted permission -> set of permissions it implies.
|
||||
IMPLIED_PERMISSIONS: dict[str, set[str]] = {
|
||||
Permission.ADD_AGENTS.value: {Permission.READ_AGENTS.value},
|
||||
Permission.MANAGE_AGENTS.value: {
|
||||
Permission.ADD_AGENTS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
},
|
||||
Permission.MANAGE_DOCUMENT_SETS.value: {
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.ADD_CONNECTORS.value: {Permission.READ_CONNECTORS.value},
|
||||
Permission.MANAGE_CONNECTORS.value: {
|
||||
Permission.ADD_CONNECTORS.value,
|
||||
Permission.READ_CONNECTORS.value,
|
||||
},
|
||||
Permission.MANAGE_USER_GROUPS.value: {
|
||||
Permission.READ_CONNECTORS.value,
|
||||
Permission.READ_DOCUMENT_SETS.value,
|
||||
Permission.READ_AGENTS.value,
|
||||
Permission.READ_USERS.value,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def resolve_effective_permissions(granted: set[str]) -> set[str]:
|
||||
"""Expand granted permissions with their implied permissions.
|
||||
|
||||
If "admin" is present, returns all 19 permissions.
|
||||
"""
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS.value in granted:
|
||||
return set(ALL_PERMISSIONS)
|
||||
|
||||
effective = set(granted)
|
||||
changed = True
|
||||
while changed:
|
||||
changed = False
|
||||
for perm in list(effective):
|
||||
implied = IMPLIED_PERMISSIONS.get(perm)
|
||||
if implied and not implied.issubset(effective):
|
||||
effective |= implied
|
||||
changed = True
|
||||
return effective
|
||||
|
||||
|
||||
def get_effective_permissions(user: User) -> set[Permission]:
|
||||
"""Read granted permissions from the column and expand implied permissions."""
|
||||
granted: set[Permission] = set()
|
||||
for p in user.effective_permissions:
|
||||
try:
|
||||
granted.add(Permission(p))
|
||||
except ValueError:
|
||||
logger.warning(f"Skipping unknown permission '{p}' for user {user.id}")
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in granted:
|
||||
return set(Permission)
|
||||
expanded = resolve_effective_permissions({p.value for p in granted})
|
||||
return {Permission(p) for p in expanded}
|
||||
|
||||
|
||||
def require_permission(
|
||||
required: Permission,
|
||||
) -> Callable[..., Coroutine[Any, Any, User]]:
|
||||
"""FastAPI dependency factory for permission-based access control.
|
||||
|
||||
Usage:
|
||||
@router.get("/endpoint")
|
||||
def endpoint(user: User = Depends(require_permission(Permission.MANAGE_CONNECTORS))):
|
||||
...
|
||||
"""
|
||||
|
||||
async def dependency(user: User = Depends(current_user)) -> User:
|
||||
effective = get_effective_permissions(user)
|
||||
|
||||
if Permission.FULL_ADMIN_PANEL_ACCESS in effective:
|
||||
return user
|
||||
|
||||
if required not in effective:
|
||||
raise OnyxError(
|
||||
OnyxErrorCode.INSUFFICIENT_PERMISSIONS,
|
||||
"You do not have the required permissions for this action.",
|
||||
)
|
||||
|
||||
return user
|
||||
|
||||
return dependency
|
||||
@@ -5,6 +5,8 @@ from typing import Any
|
||||
from fastapi_users import schemas
|
||||
from typing_extensions import override
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
class UserRole(str, Enum):
|
||||
"""
|
||||
@@ -41,6 +43,7 @@ class UserRead(schemas.BaseUser[uuid.UUID]):
|
||||
|
||||
class UserCreate(schemas.BaseUserCreate):
|
||||
role: UserRole = UserRole.BASIC
|
||||
account_type: AccountType = AccountType.STANDARD
|
||||
tenant_id: str | None = None
|
||||
# Captcha token for cloud signup protection (optional, only used when captcha is enabled)
|
||||
# Excluded from create_update_dict so it never reaches the DB layer
|
||||
@@ -50,19 +53,19 @@ class UserCreate(schemas.BaseUserCreate):
|
||||
def create_update_dict(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict()
|
||||
d.pop("captcha_token", None)
|
||||
# Force STANDARD for self-registration; only trusted paths
|
||||
# (SCIM, API key creation) supply a different account_type directly.
|
||||
d["account_type"] = AccountType.STANDARD
|
||||
return d
|
||||
|
||||
@override
|
||||
def create_update_dict_superuser(self) -> dict[str, Any]:
|
||||
d = super().create_update_dict_superuser()
|
||||
d.pop("captcha_token", None)
|
||||
d.setdefault("account_type", self.account_type)
|
||||
return d
|
||||
|
||||
|
||||
class UserUpdateWithRole(schemas.BaseUserUpdate):
|
||||
role: UserRole
|
||||
|
||||
|
||||
class UserUpdate(schemas.BaseUserUpdate):
|
||||
"""
|
||||
Role updates are not allowed through the user update endpoint for security reasons
|
||||
|
||||
@@ -80,7 +80,6 @@ from onyx.auth.pat import get_hashed_pat_from_request
|
||||
from onyx.auth.schemas import AuthBackend
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.schemas import UserUpdateWithRole
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_COOKIE_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
@@ -120,11 +119,13 @@ from onyx.db.engine.async_sql_engine import get_async_session
|
||||
from onyx.db.engine.async_sql_engine import get_async_session_context_manager
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.db.pat import fetch_user_for_pat
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import log_onyx_error
|
||||
@@ -500,18 +501,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
user = user_by_session
|
||||
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
except exceptions.UserAlreadyExists:
|
||||
user = await self.get_by_email(user_create.email)
|
||||
|
||||
@@ -525,18 +529,21 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if (
|
||||
user.role.is_web_login()
|
||||
user.account_type.is_web_login()
|
||||
or not isinstance(user_create, UserCreate)
|
||||
or not user_create.role.is_web_login()
|
||||
or not user_create.account_type.is_web_login()
|
||||
):
|
||||
raise exceptions.UserAlreadyExists()
|
||||
|
||||
user_update = UserUpdateWithRole(
|
||||
password=user_create.password,
|
||||
is_verified=user_create.is_verified,
|
||||
role=user_create.role,
|
||||
)
|
||||
user = await self.update(user_update, user)
|
||||
# Cache id before expire — accessing attrs on an expired
|
||||
# object triggers a sync lazy-load which raises MissingGreenlet
|
||||
# in this async context.
|
||||
user_id = user.id
|
||||
self._upgrade_user_to_standard__sync(user_id, user_create)
|
||||
# Expire so the async session re-fetches the row updated by
|
||||
# the sync session above.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user_id) # type: ignore[assignment]
|
||||
if user_created:
|
||||
await self._assign_default_pinned_assistants(user, db_session)
|
||||
remove_user_from_invited_users(user_create.email)
|
||||
@@ -573,6 +580,38 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
user.pinned_assistants = default_persona_ids
|
||||
|
||||
def _upgrade_user_to_standard__sync(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
user_create: UserCreate,
|
||||
) -> None:
|
||||
"""Upgrade a non-web user to STANDARD and assign default groups atomically.
|
||||
|
||||
All writes happen in a single sync transaction so neither the field
|
||||
update nor the group assignment is visible without the other.
|
||||
"""
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user_id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.hashed_password = self.password_helper.hash(
|
||||
user_create.password
|
||||
)
|
||||
sync_user.is_verified = user_create.is_verified or False
|
||||
sync_user.role = user_create.role
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
assign_user_to_default_groups__no_commit(
|
||||
sync_db,
|
||||
sync_user,
|
||||
is_admin=(user_create.role == UserRole.ADMIN),
|
||||
)
|
||||
sync_db.commit()
|
||||
else:
|
||||
logger.warning(
|
||||
"User %s not found in sync session during upgrade to standard; "
|
||||
"skipping upgrade",
|
||||
user_id,
|
||||
)
|
||||
|
||||
async def validate_password(self, password: str, _: schemas.UC | models.UP) -> None:
|
||||
# Validate password according to configurable security policy (defined via environment variables)
|
||||
if len(password) < PASSWORD_MIN_LENGTH:
|
||||
@@ -694,6 +733,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
"email": account_email,
|
||||
"hashed_password": self.password_helper.hash(password),
|
||||
"is_verified": is_verified_by_default,
|
||||
"account_type": AccountType.STANDARD,
|
||||
}
|
||||
|
||||
user = await self.user_db.create(user_dict)
|
||||
@@ -726,7 +766,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
)
|
||||
|
||||
# Handle case where user has used product outside of web and is now creating an account through web
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
# We must use the existing user in the session if it matches
|
||||
# the user we just got by email/oauth. Note that this only applies
|
||||
# to multi-tenant, due to the overwriting of the user_db
|
||||
@@ -743,14 +783,25 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
enforce_seat_limit(sync_db)
|
||||
|
||||
await self.user_db.update(
|
||||
user,
|
||||
{
|
||||
"is_verified": is_verified_by_default,
|
||||
"role": UserRole.BASIC,
|
||||
**({"is_active": True} if not user.is_active else {}),
|
||||
},
|
||||
)
|
||||
# Upgrade the user and assign default groups in a single
|
||||
# transaction so neither change is visible without the other.
|
||||
was_inactive = not user.is_active
|
||||
with get_session_with_current_tenant() as sync_db:
|
||||
sync_user = sync_db.query(User).filter(User.id == user.id).first() # type: ignore[arg-type]
|
||||
if sync_user:
|
||||
sync_user.is_verified = is_verified_by_default
|
||||
sync_user.role = UserRole.BASIC
|
||||
sync_user.account_type = AccountType.STANDARD
|
||||
if was_inactive:
|
||||
sync_user.is_active = True
|
||||
assign_user_to_default_groups__no_commit(sync_db, sync_user)
|
||||
sync_db.commit()
|
||||
|
||||
# Refresh the async user object so downstream code
|
||||
# (e.g. oidc_expiry check) sees the updated fields.
|
||||
self.user_db.session.expire(user)
|
||||
user = await self.user_db.get(user.id)
|
||||
assert user is not None
|
||||
|
||||
# this is needed if an organization goes from `TRACK_EXTERNAL_IDP_EXPIRY=true` to `false`
|
||||
# otherwise, the oidc expiry will always be old, and the user will never be able to login
|
||||
@@ -836,6 +887,16 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
event=MilestoneRecordType.TENANT_CREATED,
|
||||
)
|
||||
|
||||
# Assign user to the appropriate default group (Admin or Basic).
|
||||
# Must happen inside the try block while tenant context is active,
|
||||
# otherwise get_session_with_current_tenant() targets the wrong schema.
|
||||
is_admin = user_count == 1 or user.email in get_default_admin_user_emails()
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=is_admin
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
@@ -975,7 +1036,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
self.password_helper.hash(credentials.password)
|
||||
return None
|
||||
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise BasicAuthenticationError(
|
||||
detail="NO_WEB_LOGIN_AND_HAS_NO_PASSWORD",
|
||||
)
|
||||
@@ -1471,7 +1532,7 @@ async def _get_or_create_user_from_jwt(
|
||||
if not user.is_active:
|
||||
logger.warning("Inactive user %s attempted JWT login; skipping", email)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
except exceptions.UserNotExists:
|
||||
logger.info("Provisioning user %s from JWT login", email)
|
||||
@@ -1492,7 +1553,7 @@ async def _get_or_create_user_from_jwt(
|
||||
email,
|
||||
)
|
||||
return None
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
logger.warning(
|
||||
"Non-web-login user %s attempted JWT login during provisioning race; skipping",
|
||||
email,
|
||||
@@ -1554,6 +1615,7 @@ def get_anonymous_user() -> User:
|
||||
is_verified=True,
|
||||
is_superuser=False,
|
||||
role=UserRole.LIMITED,
|
||||
account_type=AccountType.ANONYMOUS,
|
||||
use_memories=False,
|
||||
enable_memory_tool=False,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import uuid
|
||||
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -10,14 +11,23 @@ from onyx.auth.api_key import ApiKeyDescriptor
|
||||
from onyx.auth.api_key import build_displayable_api_key
|
||||
from onyx.auth.api_key import generate_api_key
|
||||
from onyx.auth.api_key import hash_api_key
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import DANSWER_API_KEY_PREFIX
|
||||
from onyx.configs.constants import UNNAMED_KEY_PLACEHOLDER
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import ApiKey
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.api_key.models import APIKeyArgs
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_api_key_email_pattern() -> str:
|
||||
return DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
@@ -85,6 +95,7 @@ def insert_api_key(
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=api_key_args.role,
|
||||
account_type=AccountType.SERVICE_ACCOUNT,
|
||||
)
|
||||
db_session.add(api_key_user_row)
|
||||
|
||||
@@ -97,7 +108,18 @@ def insert_api_key(
|
||||
)
|
||||
db_session.add(api_key_row)
|
||||
|
||||
# Assign the API key virtual user to the appropriate default group
|
||||
# before commit so everything is atomic.
|
||||
# LIMITED role service accounts should have no group membership.
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user_row,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
api_key_id=api_key_row.id,
|
||||
api_key_role=api_key_user_row.role,
|
||||
@@ -124,7 +146,33 @@ def update_api_key(
|
||||
|
||||
email_name = api_key_args.name or UNNAMED_KEY_PLACEHOLDER
|
||||
api_key_user.email = get_api_key_fake_email(email_name, str(api_key_user.id))
|
||||
|
||||
old_role = api_key_user.role
|
||||
api_key_user.role = api_key_args.role
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != api_key_args.role:
|
||||
# Remove from all default groups first.
|
||||
delete_stmt = delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == api_key_user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
db_session.execute(delete_stmt)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if api_key_args.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
api_key_user,
|
||||
is_admin=(api_key_args.role == UserRole.ADMIN),
|
||||
)
|
||||
else:
|
||||
# No group assigned for LIMITED, but we still need to recompute
|
||||
# since we just removed the old default-group membership above.
|
||||
recompute_user_permissions__no_commit(api_key_user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
return ApiKeyDescriptor(
|
||||
|
||||
@@ -13,19 +13,26 @@ class AccountType(str, PyEnum):
|
||||
BOT, EXT_PERM_USER, ANONYMOUS → fixed behavior
|
||||
"""
|
||||
|
||||
STANDARD = "standard"
|
||||
BOT = "bot"
|
||||
EXT_PERM_USER = "ext_perm_user"
|
||||
SERVICE_ACCOUNT = "service_account"
|
||||
ANONYMOUS = "anonymous"
|
||||
STANDARD = "STANDARD"
|
||||
BOT = "BOT"
|
||||
EXT_PERM_USER = "EXT_PERM_USER"
|
||||
SERVICE_ACCOUNT = "SERVICE_ACCOUNT"
|
||||
ANONYMOUS = "ANONYMOUS"
|
||||
|
||||
def is_web_login(self) -> bool:
|
||||
"""Whether this account type supports interactive web login."""
|
||||
return self not in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
class GrantSource(str, PyEnum):
|
||||
"""How a permission grant was created."""
|
||||
|
||||
USER = "user"
|
||||
SCIM = "scim"
|
||||
SYSTEM = "system"
|
||||
USER = "USER"
|
||||
SCIM = "SCIM"
|
||||
SYSTEM = "SYSTEM"
|
||||
|
||||
|
||||
class IndexingStatus(str, PyEnum):
|
||||
|
||||
@@ -305,8 +305,11 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
role: Mapped[UserRole] = mapped_column(
|
||||
Enum(UserRole, native_enum=False, default=UserRole.BASIC)
|
||||
)
|
||||
account_type: Mapped[AccountType | None] = mapped_column(
|
||||
Enum(AccountType, native_enum=False), nullable=True
|
||||
account_type: Mapped[AccountType] = mapped_column(
|
||||
Enum(AccountType, native_enum=False),
|
||||
nullable=False,
|
||||
default=AccountType.STANDARD,
|
||||
server_default="STANDARD",
|
||||
)
|
||||
|
||||
"""
|
||||
@@ -353,6 +356,13 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
)
|
||||
|
||||
effective_permissions: Mapped[list[str]] = mapped_column(
|
||||
postgresql.JSONB(),
|
||||
nullable=False,
|
||||
default=list,
|
||||
server_default=text("'[]'::jsonb"),
|
||||
)
|
||||
|
||||
oidc_expiry: Mapped[datetime.datetime] = mapped_column(
|
||||
TIMESTAMPAware(timezone=True), nullable=True
|
||||
)
|
||||
@@ -4016,7 +4026,12 @@ class PermissionGrant(Base):
|
||||
ForeignKey("user_group.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
permission: Mapped[Permission] = mapped_column(
|
||||
Enum(Permission, native_enum=False), nullable=False
|
||||
Enum(
|
||||
Permission,
|
||||
native_enum=False,
|
||||
values_callable=lambda x: [e.value for e in x],
|
||||
),
|
||||
nullable=False,
|
||||
)
|
||||
grant_source: Mapped[GrantSource] = mapped_column(
|
||||
Enum(GrantSource, native_enum=False), nullable=False
|
||||
|
||||
95
backend/onyx/db/permissions.py
Normal file
95
backend/onyx/db/permissions.py
Normal file
@@ -0,0 +1,95 @@
|
||||
"""
|
||||
DB operations for recomputing user effective_permissions.
|
||||
|
||||
These live in onyx/db/ (not onyx/auth/) because they are pure DB operations
|
||||
that query PermissionGrant rows and update the User.effective_permissions
|
||||
JSONB column. Keeping them here avoids circular imports when called from
|
||||
other onyx/db/ modules such as users.py.
|
||||
"""
|
||||
|
||||
from collections import defaultdict
|
||||
from uuid import UUID
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
|
||||
|
||||
def recompute_user_permissions__no_commit(
|
||||
user_ids: UUID | str | list[UUID] | list[str], db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for one or more users.
|
||||
|
||||
Accepts a single UUID or a list. Uses a single query regardless of
|
||||
how many users are passed, avoiding N+1 issues.
|
||||
|
||||
Stores only directly granted permissions — implication expansion
|
||||
happens at read time via get_effective_permissions().
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
if isinstance(user_ids, (UUID, str)):
|
||||
uid_list = [user_ids]
|
||||
else:
|
||||
uid_list = list(user_ids)
|
||||
|
||||
if not uid_list:
|
||||
return
|
||||
|
||||
# Single query to fetch ALL permissions for these users across ALL their
|
||||
# groups (a user may belong to multiple groups with different grants).
|
||||
rows = db_session.execute(
|
||||
select(User__UserGroup.user_id, PermissionGrant.permission)
|
||||
.join(
|
||||
PermissionGrant,
|
||||
PermissionGrant.group_id == User__UserGroup.user_group_id,
|
||||
)
|
||||
.where(
|
||||
User__UserGroup.user_id.in_(uid_list),
|
||||
PermissionGrant.is_deleted.is_(False),
|
||||
)
|
||||
).all()
|
||||
|
||||
# Group permissions by user; users with no grants get an empty set.
|
||||
perms_by_user: dict[UUID | str, set[str]] = defaultdict(set)
|
||||
for uid in uid_list:
|
||||
perms_by_user[uid] # ensure every user has an entry
|
||||
for uid, perm in rows:
|
||||
perms_by_user[uid].add(perm.value)
|
||||
|
||||
for uid, perms in perms_by_user.items():
|
||||
db_session.execute(
|
||||
update(User)
|
||||
.where(User.id == uid) # type: ignore[arg-type]
|
||||
.values(effective_permissions=sorted(perms))
|
||||
)
|
||||
|
||||
|
||||
def recompute_permissions_for_group__no_commit(
|
||||
group_id: int, db_session: Session
|
||||
) -> None:
|
||||
"""Recompute granted permissions for all users in a group.
|
||||
|
||||
Does NOT commit — caller must commit the session.
|
||||
"""
|
||||
user_ids: list[UUID] = [
|
||||
uid
|
||||
for uid in db_session.execute(
|
||||
select(User__UserGroup.user_id).where(
|
||||
User__UserGroup.user_group_id == group_id,
|
||||
User__UserGroup.user_id.isnot(None),
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
.all()
|
||||
if uid is not None
|
||||
]
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
recompute_user_permissions__no_commit(user_ids, db_session)
|
||||
@@ -5,11 +5,11 @@ from urllib.parse import urlencode
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NotificationType
|
||||
from onyx.configs.constants import ONYX_UTM_SOURCE
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.notification import batch_create_notifications
|
||||
from onyx.server.features.release_notes.constants import DOCS_CHANGELOG_BASE_URL
|
||||
@@ -49,7 +49,7 @@ def create_release_notifications_for_versions(
|
||||
db_session.scalars(
|
||||
select(User.id).where( # type: ignore
|
||||
User.is_active == True, # noqa: E712
|
||||
User.role.notin_([UserRole.SLACK_USER, UserRole.EXT_PERM_USER]),
|
||||
User.account_type.notin_([AccountType.BOT, AccountType.EXT_PERM_USER]),
|
||||
User.email.endswith(DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN).is_(False), # type: ignore[attr-defined]
|
||||
)
|
||||
).all()
|
||||
|
||||
@@ -9,12 +9,17 @@ from sqlalchemy import update
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import DefaultAppMode
|
||||
from onyx.db.enums import ThemePreference
|
||||
from onyx.db.models import AccessToken
|
||||
from onyx.db.models import Assistant__UserSpecificConfig
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
from onyx.server.manage.models import MemoryItem
|
||||
from onyx.server.manage.models import UserSpecificAssistantPreference
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -23,13 +28,53 @@ from onyx.utils.logger import setup_logger
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
_ROLE_TO_ACCOUNT_TYPE: dict[UserRole, AccountType] = {
|
||||
UserRole.SLACK_USER: AccountType.BOT,
|
||||
UserRole.EXT_PERM_USER: AccountType.EXT_PERM_USER,
|
||||
}
|
||||
|
||||
|
||||
def update_user_role(
|
||||
user: User,
|
||||
new_role: UserRole,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Update a user's role in the database."""
|
||||
"""Update a user's role in the database.
|
||||
Dual-writes account_type to keep it in sync with role and
|
||||
reconciles default-group membership (Admin / Basic)."""
|
||||
old_role = user.role
|
||||
user.role = new_role
|
||||
# Note: setting account_type to BOT or EXT_PERM_USER causes
|
||||
# assign_user_to_default_groups__no_commit to early-return, which is
|
||||
# intentional — these account types should not be in default groups.
|
||||
if new_role in _ROLE_TO_ACCOUNT_TYPE:
|
||||
user.account_type = _ROLE_TO_ACCOUNT_TYPE[new_role]
|
||||
elif user.account_type in (AccountType.BOT, AccountType.EXT_PERM_USER):
|
||||
# Upgrading from a non-web-login account type to a web role
|
||||
user.account_type = AccountType.STANDARD
|
||||
|
||||
# Reconcile default-group membership when the role changes.
|
||||
if old_role != new_role:
|
||||
# Remove from all default groups first.
|
||||
db_session.execute(
|
||||
delete(User__UserGroup).where(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id.in_(
|
||||
select(UserGroup.id).where(UserGroup.is_default.is_(True))
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
# Re-assign to the correct default group (skip for LIMITED).
|
||||
if new_role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session,
|
||||
user,
|
||||
is_admin=(new_role == UserRole.ADMIN),
|
||||
)
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -47,8 +92,16 @@ def activate_user(
|
||||
user: User,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Activate a user by setting is_active to True."""
|
||||
"""Activate a user by setting is_active to True.
|
||||
|
||||
Also reconciles default-group membership — the user may have been
|
||||
created while inactive or deactivated before the backfill migration.
|
||||
"""
|
||||
user.is_active = True
|
||||
if user.role != UserRole.LIMITED:
|
||||
assign_user_to_default_groups__no_commit(
|
||||
db_session, user, is_admin=(user.role == UserRole.ADMIN)
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -17,8 +17,9 @@ from sqlalchemy.sql.expression import or_
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.constants import ANONYMOUS_USER_EMAIL
|
||||
from onyx.configs.constants import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.configs.constants import NO_AUTH_PLACEHOLDER_USER_EMAIL
|
||||
from onyx.db.api_key import DANSWER_API_KEY_DUMMY_EMAIL_DOMAIN
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__User
|
||||
from onyx.db.models import Persona
|
||||
@@ -27,11 +28,17 @@ from onyx.db.models import SamlAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def validate_user_role_update(
|
||||
requested_role: UserRole, current_role: UserRole, explicit_override: bool = False
|
||||
requested_role: UserRole,
|
||||
current_role: UserRole,
|
||||
current_account_type: AccountType,
|
||||
explicit_override: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Validate that a user role update is valid.
|
||||
@@ -41,19 +48,18 @@ def validate_user_role_update(
|
||||
- requested role is a slack user
|
||||
- requested role is an external permissioned user
|
||||
- requested role is a limited user
|
||||
- current role is a slack user
|
||||
- current role is an external permissioned user
|
||||
- current account type is BOT (slack user)
|
||||
- current account type is EXT_PERM_USER
|
||||
- current role is a limited user
|
||||
"""
|
||||
|
||||
if current_role == UserRole.SLACK_USER:
|
||||
if current_account_type == AccountType.BOT:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change a Slack User's role, they must first login to Onyx via the web app.",
|
||||
)
|
||||
|
||||
if current_role == UserRole.EXT_PERM_USER:
|
||||
# This shouldn't happen, but just in case
|
||||
if current_account_type == AccountType.EXT_PERM_USER:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="To change an External Permissioned User's role, they must first login to Onyx via the web app.",
|
||||
@@ -298,6 +304,7 @@ def _generate_slack_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.SLACK_USER,
|
||||
account_type=AccountType.BOT,
|
||||
)
|
||||
|
||||
|
||||
@@ -306,8 +313,9 @@ def add_slack_user_if_not_exists(db_session: Session, email: str) -> User:
|
||||
user = get_user_by_email(email, db_session)
|
||||
if user is not None:
|
||||
# If the user is an external permissioned user, we update it to a slack user
|
||||
if user.role == UserRole.EXT_PERM_USER:
|
||||
if user.account_type == AccountType.EXT_PERM_USER:
|
||||
user.role = UserRole.SLACK_USER
|
||||
user.account_type = AccountType.BOT
|
||||
db_session.commit()
|
||||
return user
|
||||
|
||||
@@ -344,6 +352,7 @@ def _generate_ext_permissioned_user(email: str) -> User:
|
||||
email=email,
|
||||
hashed_password=hashed_pass,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
@@ -375,6 +384,81 @@ def batch_add_ext_perm_user_if_not_exists(
|
||||
return all_users
|
||||
|
||||
|
||||
def assign_user_to_default_groups__no_commit(
|
||||
db_session: Session,
|
||||
user: User,
|
||||
is_admin: bool = False,
|
||||
) -> None:
|
||||
"""Assign a newly created user to the appropriate default group.
|
||||
|
||||
Does NOT commit — callers must commit the session themselves so that
|
||||
group assignment can be part of the same transaction as user creation.
|
||||
|
||||
Args:
|
||||
is_admin: If True, assign to Admin default group; otherwise Basic.
|
||||
Callers determine this from their own context (e.g. user_count,
|
||||
admin email list, explicit choice). Defaults to False (Basic).
|
||||
"""
|
||||
if user.account_type in (
|
||||
AccountType.BOT,
|
||||
AccountType.EXT_PERM_USER,
|
||||
AccountType.ANONYMOUS,
|
||||
):
|
||||
return
|
||||
|
||||
target_group_name = "Admin" if is_admin else "Basic"
|
||||
|
||||
default_group = (
|
||||
db_session.query(UserGroup)
|
||||
.filter(
|
||||
UserGroup.name == target_group_name,
|
||||
UserGroup.is_default.is_(True),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if default_group is None:
|
||||
raise RuntimeError(
|
||||
f"Default group '{target_group_name}' not found. "
|
||||
f"Cannot assign user {user.email} to a group. "
|
||||
f"Ensure the seed_default_groups migration has run."
|
||||
)
|
||||
|
||||
# Check if the user is already in the group
|
||||
existing = (
|
||||
db_session.query(User__UserGroup)
|
||||
.filter(
|
||||
User__UserGroup.user_id == user.id,
|
||||
User__UserGroup.user_group_id == default_group.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing is not None:
|
||||
return
|
||||
|
||||
savepoint = db_session.begin_nested()
|
||||
try:
|
||||
db_session.add(
|
||||
User__UserGroup(
|
||||
user_id=user.id,
|
||||
user_group_id=default_group.id,
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
except IntegrityError:
|
||||
# Race condition: another transaction inserted this membership
|
||||
# between our SELECT and INSERT. The savepoint isolates the failure
|
||||
# so the outer transaction (user creation) stays intact.
|
||||
savepoint.rollback()
|
||||
return
|
||||
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
|
||||
recompute_user_permissions__no_commit(user.id, db_session)
|
||||
|
||||
logger.info(f"Assigned user {user.email} to default group '{default_group.name}'")
|
||||
|
||||
|
||||
def delete_user_from_db(
|
||||
user_to_delete: User,
|
||||
db_session: Session,
|
||||
@@ -421,13 +505,14 @@ def delete_user_from_db(
|
||||
def batch_get_user_groups(
|
||||
db_session: Session,
|
||||
user_ids: list[UUID],
|
||||
include_default: bool = False,
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Fetch group memberships for a batch of users in a single query.
|
||||
Returns a mapping of user_id -> list of (group_id, group_name) tuples."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rows = db_session.execute(
|
||||
stmt = (
|
||||
select(
|
||||
User__UserGroup.user_id,
|
||||
UserGroup.id,
|
||||
@@ -435,7 +520,11 @@ def batch_get_user_groups(
|
||||
)
|
||||
.join(UserGroup, UserGroup.id == User__UserGroup.user_group_id)
|
||||
.where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
)
|
||||
if not include_default:
|
||||
stmt = stmt.where(UserGroup.is_default == False) # noqa: E712
|
||||
|
||||
rows = db_session.execute(stmt).all()
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {uid: [] for uid in user_ids}
|
||||
for user_id, group_id, group_name in rows:
|
||||
|
||||
@@ -3,10 +3,10 @@ import datetime
|
||||
from slack_sdk import WebClient
|
||||
from slack_sdk.errors import SlackApiError
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_FEEDBACK_REMINDER
|
||||
from onyx.configs.onyxbot_configs import ONYX_BOT_REACT_EMOJI
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import SlackChannelConfig
|
||||
from onyx.db.user_preferences import activate_user
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
@@ -247,7 +247,7 @@ def handle_message(
|
||||
|
||||
elif (
|
||||
not existing_user.is_active
|
||||
and existing_user.role == UserRole.SLACK_USER
|
||||
and existing_user.account_type == AccountType.BOT
|
||||
):
|
||||
check_seat_fn = fetch_ee_implementation_or_noop(
|
||||
"onyx.db.license",
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.auth.email_utils import send_user_email_invite
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import remove_user_from_invited_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
@@ -50,6 +51,7 @@ from onyx.configs.constants import PUBLIC_API_TAGS
|
||||
from onyx.db.api_key import is_api_key_email_address
|
||||
from onyx.db.auth import get_live_users_count
|
||||
from onyx.db.engine.sql_engine import get_session
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import UserFileStatus
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserFile
|
||||
@@ -142,6 +144,7 @@ def set_user_role(
|
||||
validate_user_role_update(
|
||||
requested_role=requested_role,
|
||||
current_role=current_role,
|
||||
current_account_type=user_to_update.account_type,
|
||||
explicit_override=user_role_update_request.explicit_override,
|
||||
)
|
||||
|
||||
@@ -327,8 +330,8 @@ def list_all_users(
|
||||
if (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
slack_users = [user for user in users if user.role == UserRole.SLACK_USER]
|
||||
accepted_users = [user for user in users if user.role != UserRole.SLACK_USER]
|
||||
slack_users = [user for user in users if user.account_type == AccountType.BOT]
|
||||
accepted_users = [user for user in users if user.account_type != AccountType.BOT]
|
||||
|
||||
accepted_emails = {user.email for user in accepted_users}
|
||||
slack_users_emails = {user.email for user in slack_users}
|
||||
@@ -671,7 +674,7 @@ def list_all_users_basic_info(
|
||||
return [
|
||||
MinimalUserSnapshot(id=user.id, email=user.email)
|
||||
for user in users
|
||||
if user.role != UserRole.SLACK_USER
|
||||
if user.account_type != AccountType.BOT
|
||||
and (include_api_keys or not is_api_key_email_address(user.email))
|
||||
]
|
||||
|
||||
@@ -774,6 +777,13 @@ def _get_token_created_at(
|
||||
return get_current_token_creation_postgres(user, db_session)
|
||||
|
||||
|
||||
@router.get("/me/permissions", tags=PUBLIC_API_TAGS)
|
||||
def get_current_user_permissions(
|
||||
user: User = Depends(current_user),
|
||||
) -> list[str]:
|
||||
return sorted(p.value for p in get_effective_permissions(user))
|
||||
|
||||
|
||||
@router.get("/me", tags=PUBLIC_API_TAGS)
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
|
||||
@@ -7,6 +7,7 @@ from uuid import UUID
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
|
||||
|
||||
@@ -41,6 +42,7 @@ class FullUserSnapshot(BaseModel):
|
||||
id: UUID
|
||||
email: str
|
||||
role: UserRole
|
||||
account_type: AccountType
|
||||
is_active: bool
|
||||
password_configured: bool
|
||||
personal_name: str | None
|
||||
@@ -60,6 +62,7 @@ class FullUserSnapshot(BaseModel):
|
||||
id=user.id,
|
||||
email=user.email,
|
||||
role=user.role,
|
||||
account_type=user.account_type,
|
||||
is_active=user.is_active,
|
||||
password_configured=user.password_configured,
|
||||
personal_name=user.personal_name,
|
||||
|
||||
@@ -70,7 +70,7 @@ async def upsert_saml_user(email: str) -> User:
|
||||
try:
|
||||
user = await user_manager.get_by_email(email)
|
||||
# If user has a non-authenticated role, treat as non-existent
|
||||
if not user.role.is_web_login():
|
||||
if not user.account_type.is_web_login():
|
||||
raise exceptions.UserNotExists()
|
||||
return user
|
||||
except exceptions.UserNotExists:
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
@@ -52,7 +53,12 @@ def tenant_context() -> Generator[None, None, None]:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
def create_test_user(
|
||||
db_session: Session,
|
||||
email_prefix: str,
|
||||
role: UserRole = UserRole.BASIC,
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
) -> User:
|
||||
"""Helper to create a test user with a unique email"""
|
||||
# Use UUID to ensure unique email addresses
|
||||
unique_email = f"{email_prefix}_{uuid4().hex[:8]}@example.com"
|
||||
@@ -68,7 +74,8 @@ def create_test_user(db_session: Session, email_prefix: str) -> User:
|
||||
is_active=True,
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
role=role,
|
||||
account_type=account_type,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -13,16 +13,29 @@ from onyx.access.utils import build_ext_group_name_for_onyx
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import Connector
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import PublicExternalUserGroup
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__ExternalUserGroupId
|
||||
from onyx.db.models import UserRole
|
||||
from tests.external_dependency_unit.conftest import create_test_user
|
||||
from tests.external_dependency_unit.constants import TEST_TENANT_ID
|
||||
|
||||
|
||||
def _create_ext_perm_user(db_session: Session, name: str) -> User:
|
||||
"""Create an external-permission user for group sync tests."""
|
||||
return create_test_user(
|
||||
db_session,
|
||||
name,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
|
||||
|
||||
def _create_test_connector_credential_pair(
|
||||
db_session: Session, source: DocumentSource = DocumentSource.GOOGLE_DRIVE
|
||||
) -> ConnectorCredentialPair:
|
||||
@@ -100,9 +113,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_initial_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing external groups for the first time (initial sync)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Mock external groups data as a generator that yields the expected groups
|
||||
@@ -175,9 +188,9 @@ class TestPerformExternalGroupSync:
|
||||
def test_update_existing_groups(self, db_session: Session) -> None:
|
||||
"""Test updating existing groups (adding/removing users)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user3 = create_test_user(db_session, "user3")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
user3 = _create_ext_perm_user(db_session, "user3")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with original groups
|
||||
@@ -272,8 +285,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_remove_groups(self, db_session: Session) -> None:
|
||||
"""Test removing groups (groups that no longer exist in external system)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with multiple groups
|
||||
@@ -357,7 +370,7 @@ class TestPerformExternalGroupSync:
|
||||
def test_empty_group_sync(self, db_session: Session) -> None:
|
||||
"""Test syncing when no groups are returned (all groups removed)"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
# Initial sync with groups
|
||||
@@ -413,7 +426,7 @@ class TestPerformExternalGroupSync:
|
||||
# Create many test users
|
||||
users = []
|
||||
for i in range(150): # More than the batch size of 100
|
||||
users.append(create_test_user(db_session, f"user{i}"))
|
||||
users.append(_create_ext_perm_user(db_session, f"user{i}"))
|
||||
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
@@ -452,8 +465,8 @@ class TestPerformExternalGroupSync:
|
||||
def test_mixed_regular_and_public_groups(self, db_session: Session) -> None:
|
||||
"""Test syncing a mix of regular and public groups"""
|
||||
# Create test data
|
||||
user1 = create_test_user(db_session, "user1")
|
||||
user2 = create_test_user(db_session, "user2")
|
||||
user1 = _create_ext_perm_user(db_session, "user1")
|
||||
user2 = _create_ext_perm_user(db_session, "user2")
|
||||
cc_pair = _create_test_connector_credential_pair(db_session)
|
||||
|
||||
def mixed_group_sync_func(
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.enums import BuildSessionStatus
|
||||
from onyx.db.models import BuildSession
|
||||
from onyx.db.models import User
|
||||
@@ -52,6 +53,7 @@ def test_user(db_session: Session, tenant_context: None) -> User: # noqa: ARG00
|
||||
is_superuser=False,
|
||||
is_verified=True,
|
||||
role=UserRole.EXT_PERM_USER,
|
||||
account_type=AccountType.EXT_PERM_USER,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -0,0 +1,51 @@
|
||||
"""
|
||||
Tests that account_type is correctly set when creating users through
|
||||
the internal DB functions: add_slack_user_if_not_exists and
|
||||
batch_add_ext_perm_user_if_not_exists.
|
||||
|
||||
These functions are called by background workers (Slack bot, permission sync)
|
||||
and are not exposed via API endpoints, so they must be tested directly.
|
||||
"""
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.db.users import add_slack_user_if_not_exists
|
||||
from onyx.db.users import batch_add_ext_perm_user_if_not_exists
|
||||
|
||||
|
||||
def test_slack_user_creation_sets_account_type_bot(db_session: Session) -> None:
|
||||
"""add_slack_user_if_not_exists sets account_type=BOT and role=SLACK_USER."""
|
||||
user = add_slack_user_if_not_exists(db_session, "slack_acct_type@test.com")
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
|
||||
|
||||
def test_ext_perm_user_creation_sets_account_type(db_session: Session) -> None:
|
||||
"""batch_add_ext_perm_user_if_not_exists sets account_type=EXT_PERM_USER."""
|
||||
users = batch_add_ext_perm_user_if_not_exists(
|
||||
db_session, ["extperm_acct_type@test.com"]
|
||||
)
|
||||
|
||||
assert len(users) == 1
|
||||
user = users[0]
|
||||
assert user.role == UserRole.EXT_PERM_USER
|
||||
assert user.account_type == AccountType.EXT_PERM_USER
|
||||
|
||||
|
||||
def test_ext_perm_to_slack_upgrade_updates_role_and_account_type(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER is upgraded to slack, both role and account_type update."""
|
||||
email = "ext_to_slack_acct_type@test.com"
|
||||
|
||||
# Create as ext_perm user first
|
||||
batch_add_ext_perm_user_if_not_exists(db_session, [email])
|
||||
|
||||
# Now "upgrade" via slack path
|
||||
user = add_slack_user_if_not_exists(db_session, email)
|
||||
|
||||
assert user.role == UserRole.SLACK_USER
|
||||
assert user.account_type == AccountType.BOT
|
||||
@@ -8,6 +8,7 @@ import pytest
|
||||
from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.llm import fetch_existing_llm_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
@@ -46,6 +47,7 @@ def _create_admin(db_session: Session) -> User:
|
||||
is_superuser=True,
|
||||
is_verified=True,
|
||||
role=UserRole.ADMIN,
|
||||
account_type=AccountType.STANDARD,
|
||||
)
|
||||
db_session.add(user)
|
||||
db_session.commit()
|
||||
|
||||
@@ -126,6 +126,15 @@ class UserManager:
|
||||
|
||||
return test_user
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(user: DATestUser) -> list[str]:
|
||||
response = requests.get(
|
||||
url=f"{API_SERVER_URL}/me/permissions",
|
||||
headers=user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def is_role(
|
||||
user_to_verify: DATestUser,
|
||||
|
||||
@@ -104,13 +104,30 @@ class UserGroupManager:
|
||||
)
|
||||
response.raise_for_status()
|
||||
|
||||
@staticmethod
|
||||
def get_permissions(
|
||||
user_group: DATestUserGroup,
|
||||
user_performing_action: DATestUser,
|
||||
) -> list[str]:
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{user_group.id}/permissions",
|
||||
headers=user_performing_action.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
@staticmethod
|
||||
def get_all(
|
||||
user_performing_action: DATestUser,
|
||||
include_default: bool = False,
|
||||
) -> list[UserGroup]:
|
||||
params: dict[str, str] = {}
|
||||
if include_default:
|
||||
params["include_default"] = "true"
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group",
|
||||
headers=user_performing_action.headers,
|
||||
params=params,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return [UserGroup(**ug) for ug in response.json()]
|
||||
|
||||
@@ -1,9 +1,13 @@
|
||||
from uuid import UUID
|
||||
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.api_key import APIKeyManager
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestAPIKey
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
@@ -33,3 +37,120 @@ def test_limited(reset: None) -> None: # noqa: ARG001
|
||||
headers=api_key.headers,
|
||||
)
|
||||
assert response.status_code == 403
|
||||
|
||||
|
||||
def _get_service_account_account_type(
|
||||
admin_user: DATestUser,
|
||||
api_key_user_id: UUID,
|
||||
) -> AccountType:
|
||||
"""Fetch the account_type of a service account user via the user listing API."""
|
||||
response = requests.get(
|
||||
f"{API_SERVER_URL}/manage/users",
|
||||
headers=admin_user.headers,
|
||||
params={"include_api_keys": "true"},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
user_id_str = str(api_key_user_id)
|
||||
for user in data["accepted"]:
|
||||
if user["id"] == user_id_str:
|
||||
return AccountType(user["account_type"])
|
||||
raise AssertionError(
|
||||
f"Service account user {user_id_str} not found in user listing"
|
||||
)
|
||||
|
||||
|
||||
def _get_default_group_user_ids(
|
||||
admin_user: DATestUser,
|
||||
) -> tuple[set[str], set[str]]:
|
||||
"""Return (admin_group_user_ids, basic_group_user_ids) from default groups."""
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
admin_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_ids = {str(u.id) for u in basic_group.users}
|
||||
return admin_ids, basic_ids
|
||||
|
||||
|
||||
def test_api_key_limited_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""LIMITED role API key: account_type is SERVICE_ACCOUNT, no group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.LIMITED,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify no group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "LIMITED API key should NOT be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "LIMITED API key should NOT be in Basic default group"
|
||||
|
||||
|
||||
def test_api_key_basic_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""BASIC role API key: account_type is SERVICE_ACCOUNT, in Basic group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.BASIC,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Basic group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in basic_ids, "BASIC API key should be in Basic default group"
|
||||
assert (
|
||||
user_id_str not in admin_ids
|
||||
), "BASIC API key should NOT be in Admin default group"
|
||||
|
||||
|
||||
def test_api_key_admin_service_account(reset: None) -> None: # noqa: ARG001
|
||||
"""ADMIN role API key: account_type is SERVICE_ACCOUNT, in Admin group only."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
|
||||
api_key: DATestAPIKey = APIKeyManager.create(
|
||||
api_key_role=UserRole.ADMIN,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify account_type
|
||||
account_type = _get_service_account_account_type(admin_user, api_key.user_id)
|
||||
assert (
|
||||
account_type == AccountType.SERVICE_ACCOUNT
|
||||
), f"Expected account_type={AccountType.SERVICE_ACCOUNT}, got {account_type}"
|
||||
|
||||
# Verify Admin group membership
|
||||
admin_ids, basic_ids = _get_default_group_user_ids(admin_user)
|
||||
user_id_str = str(api_key.user_id)
|
||||
assert user_id_str in admin_ids, "ADMIN API key should be in Admin default group"
|
||||
assert (
|
||||
user_id_str not in basic_ids
|
||||
), "ADMIN API key should NOT be in Basic default group"
|
||||
|
||||
@@ -4,11 +4,32 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _simulate_saml_login(email: str, admin_user: DATestUser) -> dict:
|
||||
"""Simulate a SAML login by calling the test upsert endpoint."""
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
"""Get the set of emails of all members in the Basic default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_default = [g for g in all_groups if g.is_default and g.name == "Basic"]
|
||||
assert basic_default, "Basic default group not found"
|
||||
return {u.email for u in basic_default[0].users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
@@ -49,15 +70,9 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login by calling the test endpoint
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": test_user_email},
|
||||
headers=admin_user.headers, # Use admin headers for authorization
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = _simulate_saml_login(test_user_email, admin_user)
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify user role was changed in the database
|
||||
@@ -82,16 +97,237 @@ def test_saml_user_conversion(reset: None) -> None: # noqa: ARG001
|
||||
assert UserManager.is_role(slack_user, UserRole.SLACK_USER)
|
||||
|
||||
# Simulate SAML login again
|
||||
response = requests.post(
|
||||
f"{API_SERVER_URL}/manage/users/test-upsert-user",
|
||||
json={"email": slack_user_email},
|
||||
headers=admin_user.headers,
|
||||
)
|
||||
response.raise_for_status()
|
||||
user_data = _simulate_saml_login(slack_user_email, admin_user)
|
||||
|
||||
# Verify the response indicates the role changed to BASIC
|
||||
user_data = response.json()
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user's role was changed in the database
|
||||
assert UserManager.is_role(slack_user, UserRole.BASIC)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD when converting a
|
||||
non-web user (EXT_PERM_USER) and that the user receives the correct role
|
||||
(BASIC) after conversion.
|
||||
|
||||
This validates the permissions-migration-phase2 changes which ensure that:
|
||||
1. account_type is updated to 'standard' on SAML conversion
|
||||
2. The converted user is assigned to the Basic default group
|
||||
"""
|
||||
# Create an admin user (first user is automatically admin)
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# Create a user and set them as EXT_PERM_USER
|
||||
test_email = "ext_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.EXT_PERM_USER)
|
||||
|
||||
# Simulate SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type is set to standard after conversion
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected account_type='{AccountType.STANDARD.value}', got '{user_data['account_type']}'"
|
||||
|
||||
# Verify role is BASIC after conversion
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify the user was assigned to the Basic default group
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_normal_signin_assigns_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that a brand-new user signing in via SAML for the first time
|
||||
is created with the correct role, account_type, and group membership.
|
||||
|
||||
This validates that normal SAML sign-in (not an upgrade from
|
||||
SLACK_USER/EXT_PERM_USER) correctly:
|
||||
1. Creates the user with role=BASIC and account_type=STANDARD
|
||||
2. Assigns the user to the Basic default group
|
||||
"""
|
||||
# First user becomes admin
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# New user signs in via SAML (no prior account)
|
||||
new_email = "new_saml_user@example.com"
|
||||
user_data = _simulate_saml_login(new_email, admin_user)
|
||||
|
||||
# Verify role and account_type
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert user_data["account_type"] == AccountType.STANDARD.value
|
||||
|
||||
# Verify user is in the Basic default group
|
||||
assert new_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"New SAML user '{new_email}' not found in Basic default group"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_user_conversion_restores_group_membership(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login restores Basic group membership when converting
|
||||
a non-authenticated user (EXT_PERM_USER or SLACK_USER) to BASIC.
|
||||
|
||||
Group membership implies 'basic' permission (verified by
|
||||
test_new_group_gets_basic_permission).
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perm_perms@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
assert ext_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(ext_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert ext_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "EXT_PERM_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
user_data = _simulate_saml_login(slack_email, admin_user)
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
assert slack_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "SLACK_USER should be back in Basic group after SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_round_trip_group_lifecycle(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test the full round-trip: BASIC -> EXT_PERM -> SAML(BASIC) -> EXT_PERM -> SAML(BASIC).
|
||||
|
||||
Verifies group membership is correctly removed and restored at each transition.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "roundtrip@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
# Step 1: BASIC user is in Basic group
|
||||
assert test_email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 2: Downgrade to EXT_PERM_USER — loses Basic group
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 3: SAML login — converts back to BASIC, regains Basic group
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after first SAML conversion"
|
||||
|
||||
# Step 4: Downgrade again
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert test_email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# Step 5: SAML login again — should still restore correctly
|
||||
_simulate_saml_login(test_email, admin_user)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), "Should be in Basic group after second SAML conversion"
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="SAML tests are enterprise only",
|
||||
)
|
||||
def test_saml_slack_user_conversion_sets_account_type_and_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""
|
||||
Test that SAML login sets account_type to STANDARD and assigns Basic group
|
||||
when converting a SLACK_USER (BOT account_type).
|
||||
|
||||
Mirrors test_saml_user_conversion_sets_account_type_and_group but for
|
||||
SLACK_USER instead of EXT_PERM_USER, and additionally verifies permissions.
|
||||
"""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = "slack_convert@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert UserManager.is_role(test_user, UserRole.SLACK_USER)
|
||||
|
||||
# SAML login
|
||||
user_data = _simulate_saml_login(test_email, admin_user)
|
||||
|
||||
# Verify account_type and role
|
||||
assert (
|
||||
user_data["account_type"] == AccountType.STANDARD.value
|
||||
), f"Expected STANDARD, got {user_data['account_type']}"
|
||||
assert user_data["role"] == UserRole.BASIC.value
|
||||
|
||||
# Verify Basic group membership (implies 'basic' permission)
|
||||
assert test_email in _get_basic_group_member_emails(
|
||||
admin_user
|
||||
), f"Converted SLACK_USER '{test_email}' not found in Basic default group"
|
||||
|
||||
@@ -0,0 +1,82 @@
|
||||
"""Integration tests for permission propagation across auth-triggered group changes.
|
||||
|
||||
These tests verify that effective permissions (via /me/permissions) actually
|
||||
propagate when users are added/removed from default groups through role changes.
|
||||
Custom permission grant tests will be added once the permission grant API is built.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_basic_group_member_emails(admin_user: DATestUser) -> set[str]:
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.is_default and g.name == "Basic"), None
|
||||
)
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
return {u.email for u in basic_group.users}
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_basic_permission_granted_on_registration(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""New users should get 'basic' permission through default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
basic_user: DATestUser = UserManager.create(email="basic@example.com")
|
||||
|
||||
# Admin should have permissions from Admin group
|
||||
admin_perms = UserManager.get_permissions(admin_user)
|
||||
assert "basic" in admin_perms
|
||||
|
||||
# Basic user should have 'basic' from Basic default group
|
||||
basic_perms = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in basic_perms
|
||||
|
||||
# Verify group membership matches
|
||||
assert basic_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="Permission propagation tests require enterprise features",
|
||||
)
|
||||
def test_role_downgrade_removes_basic_group_and_permission(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Downgrading to EXT_PERM_USER or SLACK_USER should remove from Basic group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER ---
|
||||
ext_user: DATestUser = UserManager.create(email="ext@example.com")
|
||||
assert ext_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert ext_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
# --- SLACK_USER ---
|
||||
slack_user: DATestUser = UserManager.create(email="slack@example.com")
|
||||
assert slack_user.email in _get_basic_group_member_emails(admin_user)
|
||||
|
||||
UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
assert slack_user.email not in _get_basic_group_member_emails(admin_user)
|
||||
@@ -21,8 +21,15 @@ import pytest
|
||||
import requests
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import API_SERVER_URL
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_GROUP_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Group"
|
||||
@@ -44,13 +51,6 @@ def scim_token(idp_style: str) -> str:
|
||||
per IdP-style run and reuse. Uses UserManager directly to avoid
|
||||
fixture-scope conflicts with the function-scoped admin_user fixture.
|
||||
"""
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
try:
|
||||
admin = UserManager.create(name=ADMIN_USER_NAME)
|
||||
except Exception:
|
||||
@@ -550,3 +550,145 @@ def test_patch_add_duplicate_member_is_idempotent(
|
||||
)
|
||||
assert resp.status_code == 200
|
||||
assert len(resp.json()["members"]) == 1 # still just one member
|
||||
|
||||
|
||||
def test_create_group_reserved_name_admin(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Admin' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Admin", external_id="ext-reserved-admin")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_create_group_reserved_name_basic(scim_token: str) -> None:
|
||||
"""POST /Groups with reserved name 'Basic' returns 409."""
|
||||
resp = _create_scim_group(scim_token, "Basic", external_id="ext-reserved-basic")
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_to_reserved(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""PUT /Groups/{id} renaming a group to 'Admin' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Rename To Reserved {idp_style}",
|
||||
external_id=f"ext-rtr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="Admin", external_id=f"ext-rtr-{idp_style}"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_to_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a group to 'Basic' returns 409."""
|
||||
created = _create_scim_group(
|
||||
scim_token,
|
||||
f"Patch Rename Reserved {idp_style}",
|
||||
external_id=f"ext-prr-{idp_style}",
|
||||
).json()
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{created['id']}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "Basic"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_delete_reserved_group_rejected(scim_token: str) -> None:
|
||||
"""DELETE /Groups/{id} on a reserved group ('Admin') returns 409."""
|
||||
# Look up the reserved 'Admin' group via SCIM filter
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.delete(f"/Groups/{admin_group_id}", scim_token)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_scim_created_group_has_basic_permission(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""POST /Groups assigns the 'basic' permission to the group itself."""
|
||||
# Create a SCIM group (no members needed — we check the group's permissions)
|
||||
resp = _create_scim_group(
|
||||
scim_token,
|
||||
f"Basic Perm Group {idp_style}",
|
||||
external_id=f"ext-basic-perm-{idp_style}",
|
||||
)
|
||||
assert resp.status_code == 201
|
||||
group_id = resp.json()["id"]
|
||||
|
||||
# Log in as the admin user (created by the scim_token fixture).
|
||||
admin = DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
admin = UserManager.login_as_user(admin)
|
||||
|
||||
# Verify the group itself was granted the basic permission
|
||||
perms_resp = requests.get(
|
||||
f"{API_SERVER_URL}/manage/admin/user-group/{group_id}/permissions",
|
||||
headers=admin.headers,
|
||||
)
|
||||
perms_resp.raise_for_status()
|
||||
perms = perms_resp.json()
|
||||
assert "basic" in perms, f"SCIM group should have 'basic' permission, got: {perms}"
|
||||
|
||||
|
||||
def test_replace_group_cannot_rename_from_reserved(scim_token: str) -> None:
|
||||
"""PUT /Groups/{id} renaming a reserved group ('Admin') to a non-reserved name returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.put(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_group_resource(
|
||||
display_name="RenamedAdmin", external_id="ext-rename-from-reserved"
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
|
||||
def test_patch_rename_from_reserved_name(scim_token: str, idp_style: str) -> None:
|
||||
"""PATCH /Groups/{id} renaming a reserved group ('Admin') returns 409."""
|
||||
resp = ScimClient.get('/Groups?filter=displayName eq "Admin"', scim_token)
|
||||
assert resp.status_code == 200
|
||||
resources = resp.json()["Resources"]
|
||||
assert len(resources) >= 1, "Expected reserved 'Admin' group to exist"
|
||||
admin_group_id = resources[0]["id"]
|
||||
|
||||
resp = ScimClient.patch(
|
||||
f"/Groups/{admin_group_id}",
|
||||
scim_token,
|
||||
json=_make_patch_request(
|
||||
[{"op": "replace", "path": "displayName", "value": "RenamedAdmin"}],
|
||||
idp_style,
|
||||
),
|
||||
)
|
||||
assert resp.status_code == 409
|
||||
assert "reserved" in resp.json()["detail"].lower()
|
||||
|
||||
@@ -35,9 +35,16 @@ from onyx.auth.schemas import UserRole
|
||||
from onyx.configs.app_configs import REDIS_DB_NUMBER
|
||||
from onyx.configs.app_configs import REDIS_HOST
|
||||
from onyx.configs.app_configs import REDIS_PORT
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.settings.models import ApplicationStatus
|
||||
from tests.integration.common_utils.constants import ADMIN_USER_NAME
|
||||
from tests.integration.common_utils.constants import GENERAL_HEADERS
|
||||
from tests.integration.common_utils.managers.scim_client import ScimClient
|
||||
from tests.integration.common_utils.managers.scim_token import ScimTokenManager
|
||||
from tests.integration.common_utils.managers.user import build_email
|
||||
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
SCIM_USER_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:User"
|
||||
@@ -211,6 +218,49 @@ def test_create_user(scim_token: str, idp_style: str) -> None:
|
||||
_assert_entra_emails(body, email)
|
||||
|
||||
|
||||
def test_create_user_default_group_and_account_type(
|
||||
scim_token: str, idp_style: str
|
||||
) -> None:
|
||||
"""SCIM-provisioned users get Basic default group and STANDARD account_type."""
|
||||
email = f"scim_defaults_{idp_style}@example.com"
|
||||
ext_id = f"ext-defaults-{idp_style}"
|
||||
resp = _create_scim_user(scim_token, email, ext_id, idp_style)
|
||||
assert resp.status_code == 201
|
||||
user_id = resp.json()["id"]
|
||||
|
||||
# --- Verify group assignment via SCIM GET ---
|
||||
get_resp = ScimClient.get(f"/Users/{user_id}", scim_token)
|
||||
assert get_resp.status_code == 200
|
||||
groups = get_resp.json().get("groups", [])
|
||||
group_names = {g["display"] for g in groups}
|
||||
assert "Basic" in group_names, f"Expected 'Basic' in groups, got {group_names}"
|
||||
assert "Admin" not in group_names, "SCIM user should not be in Admin group"
|
||||
|
||||
# --- Verify account_type via admin API ---
|
||||
admin = UserManager.login_as_user(
|
||||
DATestUser(
|
||||
id="",
|
||||
email=build_email(ADMIN_USER_NAME),
|
||||
password=DEFAULT_PASSWORD,
|
||||
headers=GENERAL_HEADERS,
|
||||
role=UserRole.ADMIN,
|
||||
is_active=True,
|
||||
)
|
||||
)
|
||||
page = UserManager.get_user_page(
|
||||
user_performing_action=admin,
|
||||
search_query=email,
|
||||
)
|
||||
assert page.total_items >= 1
|
||||
scim_user_snapshot = next((u for u in page.items if u.email == email), None)
|
||||
assert (
|
||||
scim_user_snapshot is not None
|
||||
), f"SCIM user {email} not found in user listing"
|
||||
assert (
|
||||
scim_user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {scim_user_snapshot.account_type}"
|
||||
|
||||
|
||||
def test_get_user(scim_token: str, idp_style: str) -> None:
|
||||
"""GET /Users/{id} returns the user resource with all stored fields."""
|
||||
email = f"scim_get_{idp_style}@example.com"
|
||||
|
||||
@@ -0,0 +1,118 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_current_tenant
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.db.models import PermissionGrant
|
||||
from onyx.db.models import UserGroup as UserGroupModel
|
||||
from onyx.db.permissions import recompute_permissions_for_group__no_commit
|
||||
from onyx.db.permissions import recompute_user_permissions__no_commit
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_user_gets_permissions_when_added_to_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_perm_test")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user_for_perm_test")
|
||||
|
||||
# basic_user starts with only "basic" from the default group
|
||||
initial_permissions = UserManager.get_permissions(basic_user)
|
||||
assert "basic" in initial_permissions
|
||||
assert "add:agents" not in initial_permissions
|
||||
|
||||
# Create a new group and add basic_user
|
||||
group = UserGroupManager.create(
|
||||
name="perm-test-group",
|
||||
user_ids=[admin_user.id, basic_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Grant a non-basic permission to the group and recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_group = db_session.get(UserGroupModel, group.id)
|
||||
assert db_group is not None
|
||||
db_session.add(
|
||||
PermissionGrant(
|
||||
group_id=db_group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
)
|
||||
db_session.flush()
|
||||
recompute_user_permissions__no_commit(basic_user.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Verify the user gained the new permission (expanded includes read:agents)
|
||||
updated_permissions = UserManager.get_permissions(basic_user)
|
||||
assert (
|
||||
"add:agents" in updated_permissions
|
||||
), f"User should have 'add:agents' after group grant, got: {updated_permissions}"
|
||||
assert (
|
||||
"read:agents" in updated_permissions
|
||||
), f"User should have implied 'read:agents', got: {updated_permissions}"
|
||||
assert "basic" in updated_permissions
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_group_permission_change_propagates_to_all_members(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
admin_user: DATestUser = UserManager.create(name="admin_propagate")
|
||||
user_a: DATestUser = UserManager.create(name="user_a_propagate")
|
||||
user_b: DATestUser = UserManager.create(name="user_b_propagate")
|
||||
|
||||
group = UserGroupManager.create(
|
||||
name="propagate-test-group",
|
||||
user_ids=[admin_user.id, user_a.id, user_b.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Neither user should have add:agents yet
|
||||
for u in (user_a, user_b):
|
||||
assert "add:agents" not in UserManager.get_permissions(u)
|
||||
|
||||
# Grant add:agents to the group, then batch-recompute
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
grant = PermissionGrant(
|
||||
group_id=group.id,
|
||||
permission=Permission.ADD_AGENTS,
|
||||
grant_source="SYSTEM",
|
||||
)
|
||||
db_session.add(grant)
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
# Both users should now have the permission (plus implied read:agents)
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" in perms, f"{u.id} missing add:agents: {perms}"
|
||||
assert "read:agents" in perms, f"{u.id} missing implied read:agents: {perms}"
|
||||
|
||||
# Soft-delete the grant and recompute — permission should be removed
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
db_grant = (
|
||||
db_session.query(PermissionGrant)
|
||||
.filter_by(group_id=group.id, permission=Permission.ADD_AGENTS)
|
||||
.first()
|
||||
)
|
||||
assert db_grant is not None
|
||||
db_grant.is_deleted = True
|
||||
db_session.flush()
|
||||
recompute_permissions_for_group__no_commit(group.id, db_session)
|
||||
db_session.commit()
|
||||
|
||||
for u in (user_a, user_b):
|
||||
perms = UserManager.get_permissions(u)
|
||||
assert "add:agents" not in perms, f"{u.id} still has add:agents: {perms}"
|
||||
@@ -0,0 +1,30 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.environ.get("ENABLE_PAID_ENTERPRISE_EDITION_FEATURES", "").lower() != "true",
|
||||
reason="User group tests are enterprise only",
|
||||
)
|
||||
def test_new_group_gets_basic_permission(reset: None) -> None: # noqa: ARG001
|
||||
admin_user: DATestUser = UserManager.create(name="admin_for_basic_perm")
|
||||
|
||||
user_group = UserGroupManager.create(
|
||||
name="basic-perm-test-group",
|
||||
user_ids=[admin_user.id],
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
permissions = UserGroupManager.get_permissions(
|
||||
user_group=user_group,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
assert (
|
||||
"basic" in permissions
|
||||
), f"New group should have 'basic' permission, got: {permissions}"
|
||||
@@ -0,0 +1,78 @@
|
||||
"""Integration tests for default group assignment on user registration.
|
||||
|
||||
Verifies that:
|
||||
- The first registered user is assigned to the Admin default group
|
||||
- Subsequent registered users are assigned to the Basic default group
|
||||
- account_type is set to STANDARD for email/password registrations
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def test_default_group_assignment_on_registration(reset: None) -> None: # noqa: ARG001
|
||||
# Register first user — should become admin
|
||||
admin_user: DATestUser = UserManager.create(name="first_user")
|
||||
assert admin_user.role == UserRole.ADMIN
|
||||
|
||||
# Register second user — should become basic
|
||||
basic_user: DATestUser = UserManager.create(name="second_user")
|
||||
assert basic_user.role == UserRole.BASIC
|
||||
|
||||
# Fetch all groups including default ones
|
||||
all_groups = UserGroupManager.get_all(
|
||||
user_performing_action=admin_user,
|
||||
include_default=True,
|
||||
)
|
||||
|
||||
# Find the default Admin and Basic groups
|
||||
admin_group = next(
|
||||
(g for g in all_groups if g.name == "Admin" and g.is_default), None
|
||||
)
|
||||
basic_group = next(
|
||||
(g for g in all_groups if g.name == "Basic" and g.is_default), None
|
||||
)
|
||||
assert admin_group is not None, "Admin default group not found"
|
||||
assert basic_group is not None, "Basic default group not found"
|
||||
|
||||
# Verify admin user is in Admin group and NOT in Basic group
|
||||
admin_group_user_ids = {str(u.id) for u in admin_group.users}
|
||||
basic_group_user_ids = {str(u.id) for u in basic_group.users}
|
||||
|
||||
assert (
|
||||
admin_user.id in admin_group_user_ids
|
||||
), "First user should be in Admin default group"
|
||||
assert (
|
||||
admin_user.id not in basic_group_user_ids
|
||||
), "First user should NOT be in Basic default group"
|
||||
|
||||
# Verify basic user is in Basic group and NOT in Admin group
|
||||
assert (
|
||||
basic_user.id in basic_group_user_ids
|
||||
), "Second user should be in Basic default group"
|
||||
assert (
|
||||
basic_user.id not in admin_group_user_ids
|
||||
), "Second user should NOT be in Admin default group"
|
||||
|
||||
# Verify account_type is STANDARD for both users via user listing API
|
||||
paginated_result = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
users_by_id = {str(u.id): u for u in paginated_result.items}
|
||||
|
||||
admin_snapshot = users_by_id.get(admin_user.id)
|
||||
basic_snapshot = users_by_id.get(basic_user.id)
|
||||
assert admin_snapshot is not None, "Admin user not found in user listing"
|
||||
assert basic_snapshot is not None, "Basic user not found in user listing"
|
||||
|
||||
assert (
|
||||
admin_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Admin user account_type should be STANDARD, got {admin_snapshot.account_type}"
|
||||
assert (
|
||||
basic_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Basic user account_type should be STANDARD, got {basic_snapshot.account_type}"
|
||||
@@ -0,0 +1,135 @@
|
||||
"""Integration tests for password signup upgrade paths.
|
||||
|
||||
Verifies that when a BOT or EXT_PERM_USER user signs up via email/password:
|
||||
- Their account_type is upgraded to STANDARD
|
||||
- They are assigned to the Basic default group
|
||||
- They gain the correct effective permissions
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"target_role",
|
||||
[UserRole.EXT_PERM_USER, UserRole.SLACK_USER],
|
||||
ids=["ext_perm_user", "slack_user"],
|
||||
)
|
||||
def test_password_signup_upgrade(
|
||||
reset: None, # noqa: ARG001
|
||||
target_role: UserRole,
|
||||
) -> None:
|
||||
"""When a non-web user signs up via email/password, they should be
|
||||
upgraded to STANDARD account_type and assigned to the Basic default group."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
test_email = f"{target_role.value}_upgrade@example.com"
|
||||
test_user = UserManager.create(email=test_email)
|
||||
|
||||
test_user = UserManager.set_role(
|
||||
user_to_set=test_user,
|
||||
target_role=target_role,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
# Verify user was removed from Basic group after downgrade
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email not in basic_emails
|
||||
), f"{target_role.value} should not be in Basic default group"
|
||||
|
||||
# Re-register with the same email — triggers the password signup upgrade
|
||||
upgraded_user = UserManager.create(email=test_email)
|
||||
|
||||
assert upgraded_user.role == UserRole.BASIC
|
||||
|
||||
paginated = UserManager.get_user_page(
|
||||
user_performing_action=admin_user,
|
||||
page_num=0,
|
||||
page_size=10,
|
||||
)
|
||||
user_snapshot = next(
|
||||
(u for u in paginated.items if str(u.id) == upgraded_user.id), None
|
||||
)
|
||||
assert user_snapshot is not None
|
||||
assert (
|
||||
user_snapshot.account_type == AccountType.STANDARD
|
||||
), f"Expected STANDARD, got {user_snapshot.account_type}"
|
||||
|
||||
# Verify user is now in the Basic default group
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
test_email in basic_emails
|
||||
), f"Upgraded user '{test_email}' not found in Basic default group"
|
||||
|
||||
|
||||
def test_password_signup_upgrade_propagates_permissions(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""When an EXT_PERM_USER or SLACK_USER signs up via password, they should
|
||||
gain the 'basic' permission through the Basic default group assignment."""
|
||||
admin_user: DATestUser = UserManager.create(email="admin@example.com")
|
||||
|
||||
# --- EXT_PERM_USER path ---
|
||||
ext_email = "ext_perms_check@example.com"
|
||||
ext_user = UserManager.create(email=ext_email)
|
||||
|
||||
initial_perms = UserManager.get_permissions(ext_user)
|
||||
assert "basic" in initial_perms
|
||||
|
||||
ext_user = UserManager.set_role(
|
||||
user_to_set=ext_user,
|
||||
target_role=UserRole.EXT_PERM_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert ext_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=ext_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded EXT_PERM_USER should have 'basic' permission, got: {perms}"
|
||||
|
||||
# --- SLACK_USER path ---
|
||||
slack_email = "slack_perms_check@example.com"
|
||||
slack_user = UserManager.create(email=slack_email)
|
||||
|
||||
slack_user = UserManager.set_role(
|
||||
user_to_set=slack_user,
|
||||
target_role=UserRole.SLACK_USER,
|
||||
user_performing_action=admin_user,
|
||||
explicit_override=True,
|
||||
)
|
||||
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert slack_email not in basic_emails
|
||||
|
||||
upgraded = UserManager.create(email=slack_email)
|
||||
assert upgraded.role == UserRole.BASIC
|
||||
|
||||
perms = UserManager.get_permissions(upgraded)
|
||||
assert (
|
||||
"basic" in perms
|
||||
), f"Upgraded SLACK_USER should have 'basic' permission, got: {perms}"
|
||||
@@ -0,0 +1,54 @@
|
||||
"""Integration tests for default group reconciliation on user reactivation.
|
||||
|
||||
Verifies that:
|
||||
- A deactivated user retains default group membership after reactivation
|
||||
- Reactivation via the admin API reconciles missing group membership
|
||||
"""
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.managers.user_group import UserGroupManager
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
def _get_default_group_member_emails(
|
||||
admin_user: DATestUser,
|
||||
group_name: str,
|
||||
) -> set[str]:
|
||||
"""Get the set of emails of all members in a named default group."""
|
||||
all_groups = UserGroupManager.get_all(admin_user, include_default=True)
|
||||
matched = [g for g in all_groups if g.is_default and g.name == group_name]
|
||||
assert matched, f"Default group '{group_name}' not found"
|
||||
return {u.email for u in matched[0].users}
|
||||
|
||||
|
||||
def test_reactivated_user_retains_default_group(
|
||||
reset: None, # noqa: ARG001
|
||||
) -> None:
|
||||
"""Deactivating and reactivating a user should preserve their
|
||||
default group membership."""
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
basic_user: DATestUser = UserManager.create(name="basic_user")
|
||||
|
||||
# Verify user is in Basic group initially
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert basic_user.email in basic_emails
|
||||
|
||||
# Deactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=False,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Reactivate the user
|
||||
UserManager.set_status(
|
||||
user_to_set=basic_user,
|
||||
target_status=True,
|
||||
user_performing_action=admin_user,
|
||||
)
|
||||
|
||||
# Verify user is still in Basic group after reactivation
|
||||
basic_emails = _get_default_group_member_emails(admin_user, "Basic")
|
||||
assert (
|
||||
basic_user.email in basic_emails
|
||||
), "Reactivated user should still be in Basic default group"
|
||||
176
backend/tests/unit/onyx/auth/test_permissions.py
Normal file
176
backend/tests/unit/onyx/auth/test_permissions.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Unit tests for onyx.auth.permissions — pure logic and FastAPI dependency.
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.auth.permissions import ALL_PERMISSIONS
|
||||
from onyx.auth.permissions import get_effective_permissions
|
||||
from onyx.auth.permissions import require_permission
|
||||
from onyx.auth.permissions import resolve_effective_permissions
|
||||
from onyx.db.enums import Permission
|
||||
from onyx.error_handling.error_codes import OnyxErrorCode
|
||||
from onyx.error_handling.exceptions import OnyxError
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# resolve_effective_permissions
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestResolveEffectivePermissions:
|
||||
def test_empty_set(self) -> None:
|
||||
assert resolve_effective_permissions(set()) == set()
|
||||
|
||||
def test_basic_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"basic"})
|
||||
assert result == {"basic"}
|
||||
|
||||
def test_single_implication(self) -> None:
|
||||
result = resolve_effective_permissions({"add:agents"})
|
||||
assert result == {"add:agents", "read:agents"}
|
||||
|
||||
def test_manage_agents_implies_add_and_read(self) -> None:
|
||||
"""manage:agents directly maps to {add:agents, read:agents}."""
|
||||
result = resolve_effective_permissions({"manage:agents"})
|
||||
assert result == {"manage:agents", "add:agents", "read:agents"}
|
||||
|
||||
def test_manage_connectors_chain(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:connectors"})
|
||||
assert result == {"manage:connectors", "add:connectors", "read:connectors"}
|
||||
|
||||
def test_manage_document_sets(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:document_sets"})
|
||||
assert result == {
|
||||
"manage:document_sets",
|
||||
"read:document_sets",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_manage_user_groups_implies_all_reads(self) -> None:
|
||||
result = resolve_effective_permissions({"manage:user_groups"})
|
||||
assert result == {
|
||||
"manage:user_groups",
|
||||
"read:connectors",
|
||||
"read:document_sets",
|
||||
"read:agents",
|
||||
"read:users",
|
||||
}
|
||||
|
||||
def test_admin_override(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_admin_with_others(self) -> None:
|
||||
result = resolve_effective_permissions({"admin", "basic"})
|
||||
assert result == set(ALL_PERMISSIONS)
|
||||
|
||||
def test_multi_group_union(self) -> None:
|
||||
result = resolve_effective_permissions(
|
||||
{"add:agents", "manage:connectors", "basic"}
|
||||
)
|
||||
assert result == {
|
||||
"basic",
|
||||
"add:agents",
|
||||
"read:agents",
|
||||
"manage:connectors",
|
||||
"add:connectors",
|
||||
"read:connectors",
|
||||
}
|
||||
|
||||
def test_toggle_permission_no_implications(self) -> None:
|
||||
result = resolve_effective_permissions({"read:agent_analytics"})
|
||||
assert result == {"read:agent_analytics"}
|
||||
|
||||
def test_all_permissions_for_admin(self) -> None:
|
||||
result = resolve_effective_permissions({"admin"})
|
||||
assert len(result) == len(ALL_PERMISSIONS)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_effective_permissions (expands implied at read time)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestGetEffectivePermissions:
|
||||
def test_expands_implied_permissions(self) -> None:
|
||||
"""Column stores only granted; get_effective_permissions expands implied."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["add:agents"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.ADD_AGENTS, Permission.READ_AGENTS}
|
||||
|
||||
def test_admin_expands_to_all(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set(Permission)
|
||||
|
||||
def test_basic_stays_basic(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
result = get_effective_permissions(user)
|
||||
assert result == {Permission.BASIC_ACCESS}
|
||||
|
||||
def test_empty_column(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
result = get_effective_permissions(user)
|
||||
assert result == set()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# require_permission (FastAPI dependency)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestRequirePermission:
|
||||
@pytest.mark.asyncio
|
||||
async def test_admin_bypass(self) -> None:
|
||||
"""Admin stored in column should pass any permission check."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["admin"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_has_required_permission(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_implied_permission_passes(self) -> None:
|
||||
"""manage:connectors implies read:connectors at read time."""
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["manage:connectors"]
|
||||
|
||||
dep = require_permission(Permission.READ_CONNECTORS)
|
||||
result = await dep(user=user)
|
||||
assert result is user
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_permission_raises(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = ["basic"]
|
||||
|
||||
dep = require_permission(Permission.MANAGE_CONNECTORS)
|
||||
with pytest.raises(OnyxError) as exc_info:
|
||||
await dep(user=user)
|
||||
assert exc_info.value.error_code == OnyxErrorCode.INSUFFICIENT_PERMISSIONS
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_permissions_fails(self) -> None:
|
||||
user = MagicMock()
|
||||
user.effective_permissions = []
|
||||
|
||||
dep = require_permission(Permission.BASIC_ACCESS)
|
||||
with pytest.raises(OnyxError):
|
||||
await dep(user=user)
|
||||
29
backend/tests/unit/onyx/auth/test_user_create_schema.py
Normal file
29
backend/tests/unit/onyx/auth/test_user_create_schema.py
Normal file
@@ -0,0 +1,29 @@
|
||||
"""
|
||||
Unit tests for UserCreate schema dict methods.
|
||||
|
||||
Verifies that account_type is always included in create_update_dict
|
||||
and create_update_dict_superuser.
|
||||
"""
|
||||
|
||||
from onyx.auth.schemas import UserCreate
|
||||
from onyx.db.enums import AccountType
|
||||
|
||||
|
||||
def test_create_update_dict_includes_default_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_includes_explicit_account_type() -> None:
|
||||
uc = UserCreate(
|
||||
email="a@b.com", password="secret123", account_type=AccountType.SERVICE_ACCOUNT
|
||||
)
|
||||
d = uc.create_update_dict()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
|
||||
|
||||
def test_create_update_dict_superuser_includes_account_type() -> None:
|
||||
uc = UserCreate(email="a@b.com", password="secret123")
|
||||
d = uc.create_update_dict_superuser()
|
||||
assert d["account_type"] == AccountType.STANDARD
|
||||
176
backend/tests/unit/onyx/db/test_assign_default_groups.py
Normal file
176
backend/tests/unit/onyx/db/test_assign_default_groups.py
Normal file
@@ -0,0 +1,176 @@
|
||||
"""
|
||||
Unit tests for assign_user_to_default_groups__no_commit in onyx.db.users.
|
||||
|
||||
Covers:
|
||||
1. Standard/service-account users get assigned to the correct default group
|
||||
2. BOT, EXT_PERM_USER, ANONYMOUS account types are skipped
|
||||
3. Missing default group raises RuntimeError
|
||||
4. Already-in-group is a no-op
|
||||
5. IntegrityError race condition is handled gracefully
|
||||
6. The function never commits the session
|
||||
"""
|
||||
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.users import assign_user_to_default_groups__no_commit
|
||||
|
||||
|
||||
def _mock_user(
|
||||
account_type: AccountType = AccountType.STANDARD,
|
||||
email: str = "test@example.com",
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = uuid4()
|
||||
user.email = email
|
||||
user.account_type = account_type
|
||||
return user
|
||||
|
||||
|
||||
def _mock_group(name: str = "Basic", group_id: int = 1) -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
group.is_default = True
|
||||
return group
|
||||
|
||||
|
||||
def _make_query_chain(first_return: object = None) -> MagicMock:
|
||||
"""Returns a mock that supports .filter(...).filter(...).first() chaining."""
|
||||
chain = MagicMock()
|
||||
chain.filter.return_value = chain
|
||||
chain.first.return_value = first_return
|
||||
return chain
|
||||
|
||||
|
||||
def _setup_db_session(
|
||||
group_result: object = None,
|
||||
membership_result: object = None,
|
||||
) -> MagicMock:
|
||||
"""Create a db_session mock that routes query(UserGroup) and query(User__UserGroup)."""
|
||||
db_session = MagicMock()
|
||||
|
||||
group_chain = _make_query_chain(group_result)
|
||||
membership_chain = _make_query_chain(membership_result)
|
||||
|
||||
def query_side_effect(model: type) -> MagicMock:
|
||||
if model is UserGroup:
|
||||
return group_chain
|
||||
if model is User__UserGroup:
|
||||
return membership_chain
|
||||
return MagicMock()
|
||||
|
||||
db_session.query.side_effect = query_side_effect
|
||||
return db_session
|
||||
|
||||
|
||||
def test_standard_user_assigned_to_basic_group() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_id == user.id
|
||||
assert added.user_group_id == group.id
|
||||
db_session.flush.assert_called_once()
|
||||
|
||||
|
||||
def test_admin_user_assigned_to_admin_group() -> None:
|
||||
group = _mock_group("Admin", group_id=2)
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.STANDARD)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=True)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
added = db_session.add.call_args[0][0]
|
||||
assert isinstance(added, User__UserGroup)
|
||||
assert added.user_group_id == group.id
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"account_type",
|
||||
[AccountType.BOT, AccountType.EXT_PERM_USER, AccountType.ANONYMOUS],
|
||||
)
|
||||
def test_excluded_account_types_skipped(account_type: AccountType) -> None:
|
||||
db_session = MagicMock()
|
||||
user = _mock_user(account_type)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.query.assert_not_called()
|
||||
db_session.add.assert_not_called()
|
||||
|
||||
|
||||
def test_service_account_not_skipped() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user(AccountType.SERVICE_ACCOUNT)
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user, is_admin=False)
|
||||
|
||||
db_session.add.assert_called_once()
|
||||
|
||||
|
||||
def test_missing_default_group_raises_error() -> None:
|
||||
db_session = _setup_db_session(group_result=None)
|
||||
user = _mock_user()
|
||||
|
||||
with pytest.raises(RuntimeError, match="Default group .* not found"):
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
|
||||
def test_already_in_group_is_noop() -> None:
|
||||
group = _mock_group("Basic")
|
||||
existing_membership = MagicMock()
|
||||
db_session = _setup_db_session(
|
||||
group_result=group, membership_result=existing_membership
|
||||
)
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.add.assert_not_called()
|
||||
db_session.begin_nested.assert_not_called()
|
||||
|
||||
|
||||
def test_integrity_error_race_condition_handled() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
db_session.flush.side_effect = IntegrityError(None, None, Exception("duplicate"))
|
||||
user = _mock_user()
|
||||
|
||||
# Should not raise
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
savepoint.rollback.assert_called_once()
|
||||
|
||||
|
||||
def test_no_commit_called_on_successful_assignment() -> None:
|
||||
group = _mock_group("Basic")
|
||||
db_session = _setup_db_session(group_result=group, membership_result=None)
|
||||
savepoint = MagicMock()
|
||||
db_session.begin_nested.return_value = savepoint
|
||||
user = _mock_user()
|
||||
|
||||
assign_user_to_default_groups__no_commit(db_session, user)
|
||||
|
||||
db_session.commit.assert_not_called()
|
||||
@@ -113,6 +113,7 @@ def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
group.name = kwargs.get("name", "Engineering")
|
||||
group.is_up_for_deletion = kwargs.get("is_up_for_deletion", False)
|
||||
group.is_up_to_date = kwargs.get("is_up_to_date", True)
|
||||
group.is_default = kwargs.get("is_default", False)
|
||||
return group
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ from unittest.mock import MagicMock
|
||||
from uuid import uuid4
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
from onyx.db.enums import AccountType
|
||||
from onyx.server.models import FullUserSnapshot
|
||||
from onyx.server.models import UserGroupInfo
|
||||
|
||||
@@ -25,6 +26,7 @@ def _mock_user(
|
||||
user.updated_at = updated_at or datetime.datetime(
|
||||
2025, 6, 15, tzinfo=datetime.timezone.utc
|
||||
)
|
||||
user.account_type = AccountType.STANDARD
|
||||
return user
|
||||
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
@@ -16,16 +17,23 @@ func newAgentsCmd() *cobra.Command {
|
||||
cmd := &cobra.Command{
|
||||
Use: "agents",
|
||||
Short: "List available agents",
|
||||
Long: `List all visible agents configured on the Onyx server.
|
||||
|
||||
By default, output is a human-readable table with ID, name, and description.
|
||||
Use --json for machine-readable output.`,
|
||||
Example: ` onyx-cli agents
|
||||
onyx-cli agents --json
|
||||
onyx-cli agents --json | jq '.[].name'`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
client := api.NewClient(cfg)
|
||||
agents, err := client.ListAgents(cmd.Context())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to list agents: %w", err)
|
||||
return fmt.Errorf("failed to list agents: %w\n Check your connection with: onyx-cli validate-config", err)
|
||||
}
|
||||
|
||||
if agentsJSON {
|
||||
|
||||
140
cli/cmd/ask.go
140
cli/cmd/ask.go
@@ -4,33 +4,65 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/models"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/overflow"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const defaultMaxOutputBytes = 4096
|
||||
|
||||
func newAskCmd() *cobra.Command {
|
||||
var (
|
||||
askAgentID int
|
||||
askJSON bool
|
||||
askQuiet bool
|
||||
askPrompt string
|
||||
maxOutput int
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "ask [question]",
|
||||
Short: "Ask a one-shot question (non-interactive)",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Long: `Send a one-shot question to an Onyx agent and print the response.
|
||||
|
||||
The question can be provided as a positional argument, via --prompt, or piped
|
||||
through stdin. When stdin contains piped data, it is sent as context along
|
||||
with the question from --prompt (or used as the question itself).
|
||||
|
||||
When stdout is not a TTY (e.g., called by a script or AI agent), output is
|
||||
automatically truncated to --max-output bytes and the full response is saved
|
||||
to a temp file. Set --max-output 0 to disable truncation.`,
|
||||
Args: cobra.MaximumNArgs(1),
|
||||
Example: ` onyx-cli ask "What connectors are available?"
|
||||
onyx-cli ask --agent-id 3 "Summarize our Q4 revenue"
|
||||
onyx-cli ask --json "List all users" | jq '.event.content'
|
||||
cat error.log | onyx-cli ask --prompt "Find the root cause"
|
||||
echo "what is onyx?" | onyx-cli ask`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("onyx CLI is not configured — run 'onyx-cli configure' first")
|
||||
return exitcodes.New(exitcodes.NotConfigured, "onyx CLI is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
if askJSON && askQuiet {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--json and --quiet cannot be used together")
|
||||
}
|
||||
|
||||
question, err := resolveQuestion(args, askPrompt)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
question := args[0]
|
||||
agentID := cfg.DefaultAgentID
|
||||
if cmd.Flags().Changed("agent-id") {
|
||||
agentID = askAgentID
|
||||
@@ -50,9 +82,23 @@ func newAskCmd() *cobra.Command {
|
||||
nil,
|
||||
)
|
||||
|
||||
// Determine truncation threshold.
|
||||
isTTY := term.IsTerminal(int(os.Stdout.Fd()))
|
||||
truncateAt := 0 // 0 means no truncation
|
||||
if cmd.Flags().Changed("max-output") {
|
||||
truncateAt = maxOutput
|
||||
} else if !isTTY {
|
||||
truncateAt = defaultMaxOutputBytes
|
||||
}
|
||||
|
||||
var sessionID string
|
||||
var lastErr error
|
||||
gotStop := false
|
||||
|
||||
// Overflow writer: tees to stdout and optionally to a temp file.
|
||||
// In quiet mode, buffer everything and print once at the end.
|
||||
ow := &overflow.Writer{Limit: truncateAt, Quiet: askQuiet}
|
||||
|
||||
for event := range ch {
|
||||
if e, ok := event.(models.SessionCreatedEvent); ok {
|
||||
sessionID = e.ChatSessionID
|
||||
@@ -82,22 +128,50 @@ func newAskCmd() *cobra.Command {
|
||||
|
||||
switch e := event.(type) {
|
||||
case models.MessageDeltaEvent:
|
||||
fmt.Print(e.Content)
|
||||
ow.Write(e.Content)
|
||||
case models.SearchStartEvent:
|
||||
if isTTY && !askQuiet {
|
||||
if e.IsInternetSearch {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mSearching the web...\033[0m\n")
|
||||
} else {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mSearching documents...\033[0m\n")
|
||||
}
|
||||
}
|
||||
case models.SearchQueriesEvent:
|
||||
if isTTY && !askQuiet {
|
||||
for _, q := range e.Queries {
|
||||
fmt.Fprintf(os.Stderr, "\033[2m → %s\033[0m\n", q)
|
||||
}
|
||||
}
|
||||
case models.SearchDocumentsEvent:
|
||||
if isTTY && !askQuiet && len(e.Documents) > 0 {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mFound %d documents\033[0m\n", len(e.Documents))
|
||||
}
|
||||
case models.ReasoningStartEvent:
|
||||
if isTTY && !askQuiet {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mThinking...\033[0m\n")
|
||||
}
|
||||
case models.ToolStartEvent:
|
||||
if isTTY && !askQuiet && e.ToolName != "" {
|
||||
fmt.Fprintf(os.Stderr, "\033[2mUsing %s...\033[0m\n", e.ToolName)
|
||||
}
|
||||
case models.ErrorEvent:
|
||||
ow.Finish()
|
||||
return fmt.Errorf("%s", e.Error)
|
||||
case models.StopEvent:
|
||||
fmt.Println()
|
||||
ow.Finish()
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
if !askJSON {
|
||||
ow.Finish()
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
if sessionID != "" {
|
||||
client.StopChatSession(context.Background(), sessionID)
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -105,20 +179,56 @@ func newAskCmd() *cobra.Command {
|
||||
return lastErr
|
||||
}
|
||||
if !gotStop {
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return fmt.Errorf("stream ended unexpectedly")
|
||||
}
|
||||
if !askJSON {
|
||||
fmt.Println()
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().IntVar(&askAgentID, "agent-id", 0, "Agent ID to use")
|
||||
cmd.Flags().BoolVar(&askJSON, "json", false, "Output raw JSON events")
|
||||
// Suppress cobra's default error/usage on RunE errors
|
||||
cmd.Flags().BoolVarP(&askQuiet, "quiet", "q", false, "Buffer output and print once at end (no streaming)")
|
||||
cmd.Flags().StringVar(&askPrompt, "prompt", "", "Question text (use with piped stdin context)")
|
||||
cmd.Flags().IntVar(&maxOutput, "max-output", defaultMaxOutputBytes,
|
||||
"Max bytes to print before truncating (0 to disable, auto-enabled for non-TTY)")
|
||||
return cmd
|
||||
}
|
||||
|
||||
// resolveQuestion builds the final question string from args, --prompt, and stdin.
|
||||
func resolveQuestion(args []string, prompt string) (string, error) {
|
||||
hasArg := len(args) > 0
|
||||
hasPrompt := prompt != ""
|
||||
hasStdin := !term.IsTerminal(int(os.Stdin.Fd()))
|
||||
|
||||
if hasArg && hasPrompt {
|
||||
return "", exitcodes.New(exitcodes.BadRequest, "specify the question as an argument or --prompt, not both")
|
||||
}
|
||||
|
||||
var stdinContent string
|
||||
if hasStdin {
|
||||
const maxStdinBytes = 10 * 1024 * 1024 // 10MB
|
||||
data, err := io.ReadAll(io.LimitReader(os.Stdin, maxStdinBytes))
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to read stdin: %w", err)
|
||||
}
|
||||
stdinContent = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
switch {
|
||||
case hasArg && stdinContent != "":
|
||||
// arg is the question, stdin is context
|
||||
return args[0] + "\n\n" + stdinContent, nil
|
||||
case hasArg:
|
||||
return args[0], nil
|
||||
case hasPrompt && stdinContent != "":
|
||||
// --prompt is the question, stdin is context
|
||||
return prompt + "\n\n" + stdinContent, nil
|
||||
case hasPrompt:
|
||||
return prompt, nil
|
||||
case stdinContent != "":
|
||||
return stdinContent, nil
|
||||
default:
|
||||
return "", exitcodes.New(exitcodes.BadRequest, "no question provided\n Usage: onyx-cli ask \"your question\"\n Or: echo \"context\" | onyx-cli ask --prompt \"your question\"")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import (
|
||||
tea "github.com/charmbracelet/bubbletea"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/starprompt"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
@@ -12,6 +13,11 @@ func newChatCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "chat",
|
||||
Short: "Launch the interactive chat TUI (default)",
|
||||
Long: `Launch the interactive terminal UI for chatting with your Onyx agent.
|
||||
This is the default command when no subcommand is specified. On first run,
|
||||
an interactive setup wizard will guide you through configuration.`,
|
||||
Example: ` onyx-cli chat
|
||||
onyx-cli`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
cfg := config.Load()
|
||||
|
||||
@@ -24,6 +30,8 @@ func newChatCmd() *cobra.Command {
|
||||
cfg = *result
|
||||
}
|
||||
|
||||
starprompt.MaybePrompt()
|
||||
|
||||
m := tui.NewModel(cfg)
|
||||
p := tea.NewProgram(m, tea.WithAltScreen(), tea.WithMouseCellMotion())
|
||||
_, err := p.Run()
|
||||
|
||||
@@ -1,19 +1,126 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/onboarding"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
func newConfigureCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
var (
|
||||
serverURL string
|
||||
apiKey string
|
||||
apiKeyStdin bool
|
||||
dryRun bool
|
||||
)
|
||||
|
||||
cmd := &cobra.Command{
|
||||
Use: "configure",
|
||||
Short: "Configure server URL and API key",
|
||||
Long: `Set up the Onyx CLI with your server URL and API key.
|
||||
|
||||
When --server-url and --api-key are both provided, the configuration is saved
|
||||
non-interactively (useful for scripts and AI agents). Otherwise, an interactive
|
||||
setup wizard is launched.
|
||||
|
||||
If --api-key is omitted but stdin has piped data, the API key is read from
|
||||
stdin automatically. You can also use --api-key-stdin to make this explicit.
|
||||
This avoids leaking the key in shell history.
|
||||
|
||||
Use --dry-run to test the connection without saving the configuration.`,
|
||||
Example: ` onyx-cli configure
|
||||
onyx-cli configure --server-url https://my-onyx.com --api-key sk-...
|
||||
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com
|
||||
echo "$ONYX_API_KEY" | onyx-cli configure --server-url https://my-onyx.com --api-key-stdin
|
||||
onyx-cli configure --server-url https://my-onyx.com --api-key sk-... --dry-run`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Read API key from stdin if piped (implicit) or --api-key-stdin (explicit)
|
||||
if apiKeyStdin && apiKey != "" {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--api-key and --api-key-stdin cannot be used together")
|
||||
}
|
||||
if (apiKey == "" && !term.IsTerminal(int(os.Stdin.Fd()))) || apiKeyStdin {
|
||||
data, err := io.ReadAll(os.Stdin)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to read API key from stdin: %w", err)
|
||||
}
|
||||
apiKey = strings.TrimSpace(string(data))
|
||||
}
|
||||
|
||||
if serverURL != "" && apiKey != "" {
|
||||
return configureNonInteractive(serverURL, apiKey, dryRun)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
return exitcodes.New(exitcodes.BadRequest, "--dry-run requires --server-url and --api-key")
|
||||
}
|
||||
|
||||
if serverURL != "" || apiKey != "" {
|
||||
return exitcodes.New(exitcodes.BadRequest, "both --server-url and --api-key are required for non-interactive setup\n Run 'onyx-cli configure' without flags for interactive setup")
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
onboarding.Run(&cfg)
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
cmd.Flags().StringVar(&serverURL, "server-url", "", "Onyx server URL (e.g., https://cloud.onyx.app)")
|
||||
cmd.Flags().StringVar(&apiKey, "api-key", "", "API key for authentication (or pipe via stdin)")
|
||||
cmd.Flags().BoolVar(&apiKeyStdin, "api-key-stdin", false, "Read API key from stdin (explicit; also happens automatically when stdin is piped)")
|
||||
cmd.Flags().BoolVar(&dryRun, "dry-run", false, "Test connection without saving config (requires --server-url and --api-key)")
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
func configureNonInteractive(serverURL, apiKey string, dryRun bool) error {
|
||||
cfg := config.OnyxCliConfig{
|
||||
ServerURL: serverURL,
|
||||
APIKey: apiKey,
|
||||
DefaultAgentID: 0,
|
||||
}
|
||||
|
||||
// Preserve existing default agent ID from disk (not env overrides)
|
||||
if existing := config.LoadFromDisk(); existing.DefaultAgentID != 0 {
|
||||
cfg.DefaultAgentID = existing.DefaultAgentID
|
||||
}
|
||||
|
||||
// Test connection
|
||||
client := api.NewClient(cfg)
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
if err := client.TestConnection(ctx); err != nil {
|
||||
var authErr *api.AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Check your API key", err)
|
||||
}
|
||||
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Check your server URL", err)
|
||||
}
|
||||
|
||||
if dryRun {
|
||||
fmt.Printf("Server: %s\n", serverURL)
|
||||
fmt.Println("Status: connected and authenticated")
|
||||
fmt.Println("Dry run: config was NOT saved")
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := config.Save(cfg); err != nil {
|
||||
return fmt.Errorf("could not save config: %w", err)
|
||||
}
|
||||
|
||||
fmt.Printf("Config: %s\n", config.ConfigFilePath())
|
||||
fmt.Printf("Server: %s\n", serverURL)
|
||||
fmt.Println("Status: connected and authenticated")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -23,6 +23,7 @@ import (
|
||||
"github.com/charmbracelet/wish/ratelimiter"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/tui"
|
||||
"github.com/spf13/cobra"
|
||||
"golang.org/x/time/rate"
|
||||
@@ -295,15 +296,15 @@ provided via the ONYX_API_KEY environment variable to skip the prompt:
|
||||
The server URL is taken from the server operator's config. The server
|
||||
auto-generates an Ed25519 host key on first run if the key file does not
|
||||
already exist. The host key path can also be set via the ONYX_SSH_HOST_KEY
|
||||
environment variable (the --host-key flag takes precedence).
|
||||
|
||||
Example:
|
||||
onyx-cli serve --port 2222
|
||||
ssh localhost -p 2222`,
|
||||
environment variable (the --host-key flag takes precedence).`,
|
||||
Example: ` onyx-cli serve --port 2222
|
||||
ssh localhost -p 2222
|
||||
onyx-cli serve --host 0.0.0.0 --port 2222
|
||||
onyx-cli serve --idle-timeout 30m --max-session-timeout 2h`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
serverCfg := config.Load()
|
||||
if serverCfg.ServerURL == "" {
|
||||
return fmt.Errorf("server URL is not configured; run 'onyx-cli configure' first")
|
||||
return exitcodes.New(exitcodes.NotConfigured, "server URL is not configured\n Run: onyx-cli configure")
|
||||
}
|
||||
if !cmd.Flags().Changed("host-key") {
|
||||
if v := os.Getenv(config.EnvSSHHostKey); v != "" {
|
||||
|
||||
@@ -2,11 +2,13 @@ package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/api"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
@@ -16,17 +18,21 @@ func newValidateConfigCmd() *cobra.Command {
|
||||
return &cobra.Command{
|
||||
Use: "validate-config",
|
||||
Short: "Validate configuration and test server connection",
|
||||
Long: `Check that the CLI is configured, the server is reachable, and the API key
|
||||
is valid. Also reports the server version and warns if it is below the
|
||||
minimum required.`,
|
||||
Example: ` onyx-cli validate-config`,
|
||||
RunE: func(cmd *cobra.Command, args []string) error {
|
||||
// Check config file
|
||||
if !config.ConfigExists() {
|
||||
return fmt.Errorf("config file not found at %s\n Run 'onyx-cli configure' to set up", config.ConfigFilePath())
|
||||
return exitcodes.Newf(exitcodes.NotConfigured, "config file not found at %s\n Run: onyx-cli configure", config.ConfigFilePath())
|
||||
}
|
||||
|
||||
cfg := config.Load()
|
||||
|
||||
// Check API key
|
||||
if !cfg.IsConfigured() {
|
||||
return fmt.Errorf("API key is missing\n Run 'onyx-cli configure' to set up")
|
||||
return exitcodes.New(exitcodes.NotConfigured, "API key is missing\n Run: onyx-cli configure")
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintf(cmd.OutOrStdout(), "Config: %s\n", config.ConfigFilePath())
|
||||
@@ -35,7 +41,11 @@ func newValidateConfigCmd() *cobra.Command {
|
||||
// Test connection
|
||||
client := api.NewClient(cfg)
|
||||
if err := client.TestConnection(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("connection failed: %w", err)
|
||||
var authErr *api.AuthError
|
||||
if errors.As(err, &authErr) {
|
||||
return exitcodes.Newf(exitcodes.AuthFailure, "authentication failed: %v\n Reconfigure with: onyx-cli configure", err)
|
||||
}
|
||||
return exitcodes.Newf(exitcodes.Unreachable, "connection failed: %v\n Reconfigure with: onyx-cli configure", err)
|
||||
}
|
||||
|
||||
_, _ = fmt.Fprintln(cmd.OutOrStdout(), "Status: connected and authenticated")
|
||||
|
||||
@@ -149,12 +149,12 @@ func (c *Client) TestConnection(ctx context.Context) error {
|
||||
|
||||
if resp2.StatusCode == 401 || resp2.StatusCode == 403 {
|
||||
if isHTML || strings.Contains(respServer, "awselb") {
|
||||
return fmt.Errorf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)
|
||||
return &AuthError{Message: fmt.Sprintf("HTTP %d from a reverse proxy (not the Onyx backend).\n Check your deployment's ingress / proxy configuration", resp2.StatusCode)}
|
||||
}
|
||||
if resp2.StatusCode == 401 {
|
||||
return fmt.Errorf("invalid API key or token.\n %s", body)
|
||||
return &AuthError{Message: fmt.Sprintf("invalid API key or token.\n %s", body)}
|
||||
}
|
||||
return fmt.Errorf("access denied — check that the API key is valid.\n %s", body)
|
||||
return &AuthError{Message: fmt.Sprintf("access denied — check that the API key is valid.\n %s", body)}
|
||||
}
|
||||
|
||||
detail := fmt.Sprintf("HTTP %d", resp2.StatusCode)
|
||||
|
||||
@@ -11,3 +11,12 @@ type OnyxAPIError struct {
|
||||
func (e *OnyxAPIError) Error() string {
|
||||
return fmt.Sprintf("HTTP %d: %s", e.StatusCode, e.Detail)
|
||||
}
|
||||
|
||||
// AuthError is returned when authentication or authorization fails.
|
||||
type AuthError struct {
|
||||
Message string
|
||||
}
|
||||
|
||||
func (e *AuthError) Error() string {
|
||||
return e.Message
|
||||
}
|
||||
|
||||
@@ -59,8 +59,10 @@ func ConfigExists() bool {
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
// LoadFromDisk reads config from the file only, without applying environment
|
||||
// variable overrides. Use this when you need the persisted config values
|
||||
// (e.g., to preserve them during a save operation).
|
||||
func LoadFromDisk() OnyxCliConfig {
|
||||
cfg := DefaultConfig()
|
||||
|
||||
data, err := os.ReadFile(ConfigFilePath())
|
||||
@@ -70,6 +72,13 @@ func Load() OnyxCliConfig {
|
||||
}
|
||||
}
|
||||
|
||||
return cfg
|
||||
}
|
||||
|
||||
// Load reads config from file and applies environment variable overrides.
|
||||
func Load() OnyxCliConfig {
|
||||
cfg := LoadFromDisk()
|
||||
|
||||
// Environment overrides
|
||||
if v := os.Getenv(EnvServerURL); v != "" {
|
||||
cfg.ServerURL = v
|
||||
|
||||
33
cli/internal/exitcodes/codes.go
Normal file
33
cli/internal/exitcodes/codes.go
Normal file
@@ -0,0 +1,33 @@
|
||||
// Package exitcodes defines semantic exit codes for the Onyx CLI.
|
||||
package exitcodes
|
||||
|
||||
import "fmt"
|
||||
|
||||
const (
|
||||
Success = 0
|
||||
General = 1
|
||||
BadRequest = 2 // invalid args / command-line errors (convention)
|
||||
NotConfigured = 3
|
||||
AuthFailure = 4
|
||||
Unreachable = 5
|
||||
)
|
||||
|
||||
// ExitError wraps an error with a specific exit code.
|
||||
type ExitError struct {
|
||||
Code int
|
||||
Err error
|
||||
}
|
||||
|
||||
func (e *ExitError) Error() string {
|
||||
return e.Err.Error()
|
||||
}
|
||||
|
||||
// New creates an ExitError with the given code and message.
|
||||
func New(code int, msg string) *ExitError {
|
||||
return &ExitError{Code: code, Err: fmt.Errorf("%s", msg)}
|
||||
}
|
||||
|
||||
// Newf creates an ExitError with a formatted message.
|
||||
func Newf(code int, format string, args ...any) *ExitError {
|
||||
return &ExitError{Code: code, Err: fmt.Errorf(format, args...)}
|
||||
}
|
||||
40
cli/internal/exitcodes/codes_test.go
Normal file
40
cli/internal/exitcodes/codes_test.go
Normal file
@@ -0,0 +1,40 @@
|
||||
package exitcodes
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestExitError_Error(t *testing.T) {
|
||||
e := New(NotConfigured, "not configured")
|
||||
if e.Error() != "not configured" {
|
||||
t.Fatalf("expected 'not configured', got %q", e.Error())
|
||||
}
|
||||
if e.Code != NotConfigured {
|
||||
t.Fatalf("expected code %d, got %d", NotConfigured, e.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitError_Newf(t *testing.T) {
|
||||
e := Newf(Unreachable, "cannot reach %s", "server")
|
||||
if e.Error() != "cannot reach server" {
|
||||
t.Fatalf("expected 'cannot reach server', got %q", e.Error())
|
||||
}
|
||||
if e.Code != Unreachable {
|
||||
t.Fatalf("expected code %d, got %d", Unreachable, e.Code)
|
||||
}
|
||||
}
|
||||
|
||||
func TestExitError_ErrorsAs(t *testing.T) {
|
||||
e := New(BadRequest, "bad input")
|
||||
wrapped := fmt.Errorf("wrapper: %w", e)
|
||||
|
||||
var exitErr *ExitError
|
||||
if !errors.As(wrapped, &exitErr) {
|
||||
t.Fatal("errors.As should find ExitError")
|
||||
}
|
||||
if exitErr.Code != BadRequest {
|
||||
t.Fatalf("expected code %d, got %d", BadRequest, exitErr.Code)
|
||||
}
|
||||
}
|
||||
121
cli/internal/overflow/writer.go
Normal file
121
cli/internal/overflow/writer.go
Normal file
@@ -0,0 +1,121 @@
|
||||
// Package overflow provides a streaming writer that auto-truncates output
|
||||
// for non-TTY callers (e.g., AI agents, scripts). Full content is saved to
|
||||
// a temp file on disk; only the first N bytes are printed to stdout.
|
||||
package overflow
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Writer handles streaming output with optional truncation.
|
||||
// When Limit > 0, it streams to a temp file on disk (not memory) and stops
|
||||
// writing to stdout after Limit bytes. When Limit == 0, it writes directly
|
||||
// to stdout. In Quiet mode, it buffers in memory and prints once at the end.
|
||||
type Writer struct {
|
||||
Limit int
|
||||
Quiet bool
|
||||
written int
|
||||
totalBytes int
|
||||
truncated bool
|
||||
buf strings.Builder // used only in quiet mode
|
||||
tmpFile *os.File // used only in truncation mode (Limit > 0)
|
||||
}
|
||||
|
||||
// Write sends a chunk of content through the writer.
|
||||
func (w *Writer) Write(s string) {
|
||||
w.totalBytes += len(s)
|
||||
|
||||
// Quiet mode: buffer in memory, print nothing
|
||||
if w.Quiet {
|
||||
w.buf.WriteString(s)
|
||||
return
|
||||
}
|
||||
|
||||
if w.Limit <= 0 {
|
||||
fmt.Print(s)
|
||||
return
|
||||
}
|
||||
|
||||
// Truncation mode: stream all content to temp file on disk
|
||||
if w.tmpFile == nil {
|
||||
f, err := os.CreateTemp("", "onyx-ask-*.txt")
|
||||
if err != nil {
|
||||
// Fall back to no-truncation if we can't create the file
|
||||
fmt.Fprintf(os.Stderr, "warning: could not create temp file: %v\n", err)
|
||||
w.Limit = 0
|
||||
fmt.Print(s)
|
||||
return
|
||||
}
|
||||
w.tmpFile = f
|
||||
}
|
||||
if _, err := w.tmpFile.WriteString(s); err != nil {
|
||||
// Disk write failed — abandon truncation, stream directly to stdout
|
||||
fmt.Fprintf(os.Stderr, "warning: temp file write failed: %v\n", err)
|
||||
w.closeTmpFile(true)
|
||||
w.Limit = 0
|
||||
w.truncated = false
|
||||
fmt.Print(s)
|
||||
return
|
||||
}
|
||||
|
||||
if w.truncated {
|
||||
return
|
||||
}
|
||||
|
||||
remaining := w.Limit - w.written
|
||||
if len(s) <= remaining {
|
||||
fmt.Print(s)
|
||||
w.written += len(s)
|
||||
} else {
|
||||
if remaining > 0 {
|
||||
fmt.Print(s[:remaining])
|
||||
w.written += remaining
|
||||
}
|
||||
w.truncated = true
|
||||
}
|
||||
}
|
||||
|
||||
// Finish flushes remaining output. Call once after all Write calls are done.
|
||||
func (w *Writer) Finish() {
|
||||
// Quiet mode: print buffered content at once
|
||||
if w.Quiet {
|
||||
fmt.Println(w.buf.String())
|
||||
return
|
||||
}
|
||||
|
||||
if !w.truncated {
|
||||
w.closeTmpFile(true) // clean up unused temp file
|
||||
fmt.Println()
|
||||
return
|
||||
}
|
||||
|
||||
// Close the temp file so it's readable
|
||||
tmpPath := w.tmpFile.Name()
|
||||
w.closeTmpFile(false) // close but keep the file
|
||||
|
||||
fmt.Printf("\n\n--- response truncated (%d bytes total) ---\n", w.totalBytes)
|
||||
fmt.Printf("Full response: %s\n", tmpPath)
|
||||
fmt.Printf("Explore:\n")
|
||||
fmt.Printf(" cat %s | grep \"<pattern>\"\n", tmpPath)
|
||||
fmt.Printf(" cat %s | tail -50\n", tmpPath)
|
||||
}
|
||||
|
||||
// closeTmpFile closes and optionally removes the temp file.
|
||||
func (w *Writer) closeTmpFile(remove bool) {
|
||||
if w.tmpFile == nil {
|
||||
return
|
||||
}
|
||||
if err := w.tmpFile.Close(); err != nil {
|
||||
log.Debugf("warning: failed to close temp file: %v", err)
|
||||
}
|
||||
if remove {
|
||||
if err := os.Remove(w.tmpFile.Name()); err != nil {
|
||||
log.Debugf("warning: failed to remove temp file: %v", err)
|
||||
}
|
||||
}
|
||||
w.tmpFile = nil
|
||||
}
|
||||
95
cli/internal/overflow/writer_test.go
Normal file
95
cli/internal/overflow/writer_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package overflow
|
||||
|
||||
import (
|
||||
"os"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestWriter_NoLimit(t *testing.T) {
|
||||
w := &Writer{Limit: 0}
|
||||
w.Write("hello world")
|
||||
if w.truncated {
|
||||
t.Fatal("should not be truncated with limit 0")
|
||||
}
|
||||
if w.totalBytes != 11 {
|
||||
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_UnderLimit(t *testing.T) {
|
||||
w := &Writer{Limit: 100}
|
||||
w.Write("hello")
|
||||
w.Write(" world")
|
||||
if w.truncated {
|
||||
t.Fatal("should not be truncated when under limit")
|
||||
}
|
||||
if w.written != 11 {
|
||||
t.Fatalf("expected 11 written bytes, got %d", w.written)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_OverLimit(t *testing.T) {
|
||||
w := &Writer{Limit: 5}
|
||||
w.Write("hello world") // 11 bytes, limit 5
|
||||
if !w.truncated {
|
||||
t.Fatal("should be truncated")
|
||||
}
|
||||
if w.written != 5 {
|
||||
t.Fatalf("expected 5 written bytes, got %d", w.written)
|
||||
}
|
||||
if w.totalBytes != 11 {
|
||||
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
|
||||
}
|
||||
if w.tmpFile == nil {
|
||||
t.Fatal("temp file should have been created")
|
||||
}
|
||||
_ = w.tmpFile.Close()
|
||||
data, _ := os.ReadFile(w.tmpFile.Name())
|
||||
_ = os.Remove(w.tmpFile.Name())
|
||||
if string(data) != "hello world" {
|
||||
t.Fatalf("temp file should contain full content, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_MultipleChunks(t *testing.T) {
|
||||
w := &Writer{Limit: 10}
|
||||
w.Write("hello") // 5 bytes
|
||||
w.Write(" ") // 6 bytes
|
||||
w.Write("world") // 11 bytes, crosses limit
|
||||
w.Write("!") // 12 bytes, already truncated
|
||||
|
||||
if !w.truncated {
|
||||
t.Fatal("should be truncated")
|
||||
}
|
||||
if w.written != 10 {
|
||||
t.Fatalf("expected 10 written bytes, got %d", w.written)
|
||||
}
|
||||
if w.totalBytes != 12 {
|
||||
t.Fatalf("expected 12 total bytes, got %d", w.totalBytes)
|
||||
}
|
||||
if w.tmpFile == nil {
|
||||
t.Fatal("temp file should have been created")
|
||||
}
|
||||
_ = w.tmpFile.Close()
|
||||
data, _ := os.ReadFile(w.tmpFile.Name())
|
||||
_ = os.Remove(w.tmpFile.Name())
|
||||
if string(data) != "hello world!" {
|
||||
t.Fatalf("temp file should contain full content, got %q", string(data))
|
||||
}
|
||||
}
|
||||
|
||||
func TestWriter_QuietMode(t *testing.T) {
|
||||
w := &Writer{Limit: 0, Quiet: true}
|
||||
w.Write("hello")
|
||||
w.Write(" world")
|
||||
|
||||
if w.written != 0 {
|
||||
t.Fatalf("quiet mode should not write to stdout, got %d written", w.written)
|
||||
}
|
||||
if w.totalBytes != 11 {
|
||||
t.Fatalf("expected 11 total bytes, got %d", w.totalBytes)
|
||||
}
|
||||
if w.buf.String() != "hello world" {
|
||||
t.Fatalf("buffer should contain full content, got %q", w.buf.String())
|
||||
}
|
||||
}
|
||||
83
cli/internal/starprompt/starprompt.go
Normal file
83
cli/internal/starprompt/starprompt.go
Normal file
@@ -0,0 +1,83 @@
|
||||
// Package starprompt implements a one-time GitHub star prompt shown before the TUI.
|
||||
// Skipped when stdin/stdout is not a TTY, when gh CLI is not installed,
|
||||
// or when the user has already been prompted. State is stored in the
|
||||
// config directory so it shows at most once per user.
|
||||
package starprompt
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/config"
|
||||
"golang.org/x/term"
|
||||
)
|
||||
|
||||
const repo = "onyx-dot-app/onyx"
|
||||
|
||||
func statePath() string {
|
||||
return filepath.Join(config.ConfigDir(), ".star-prompted")
|
||||
}
|
||||
|
||||
func hasBeenPrompted() bool {
|
||||
_, err := os.Stat(statePath())
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func markPrompted() {
|
||||
_ = os.MkdirAll(config.ConfigDir(), 0o755)
|
||||
f, err := os.Create(statePath())
|
||||
if err == nil {
|
||||
_ = f.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func isGHInstalled() bool {
|
||||
_, err := exec.LookPath("gh")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// MaybePrompt shows a one-time star prompt if conditions are met.
|
||||
// It is safe to call unconditionally — it no-ops when not appropriate.
|
||||
func MaybePrompt() {
|
||||
if !term.IsTerminal(int(os.Stdin.Fd())) || !term.IsTerminal(int(os.Stdout.Fd())) {
|
||||
return
|
||||
}
|
||||
if hasBeenPrompted() {
|
||||
return
|
||||
}
|
||||
if !isGHInstalled() {
|
||||
return
|
||||
}
|
||||
|
||||
// Mark before asking so Ctrl+C won't cause a re-prompt.
|
||||
markPrompted()
|
||||
|
||||
fmt.Print("Enjoying Onyx? Star the repo on GitHub? [Y/n] ")
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
answer, _ := reader.ReadString('\n')
|
||||
answer = strings.TrimSpace(strings.ToLower(answer))
|
||||
|
||||
if answer == "n" || answer == "no" {
|
||||
return
|
||||
}
|
||||
|
||||
cmd := exec.Command("gh", "api", "-X", "PUT", "/user/starred/"+repo)
|
||||
cmd.Env = append(os.Environ(), "GH_PAGER=")
|
||||
if devnull, err := os.Open(os.DevNull); err == nil {
|
||||
defer func() { _ = devnull.Close() }()
|
||||
cmd.Stdin = devnull
|
||||
cmd.Stdout = devnull
|
||||
cmd.Stderr = devnull
|
||||
}
|
||||
if err := cmd.Run(); err != nil {
|
||||
fmt.Println("Star us at: https://github.com/" + repo)
|
||||
} else {
|
||||
fmt.Println("Thanks for the star!")
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
}
|
||||
}
|
||||
@@ -1,10 +1,12 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/onyx-dot-app/onyx/cli/cmd"
|
||||
"github.com/onyx-dot-app/onyx/cli/internal/exitcodes"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -18,6 +20,10 @@ func main() {
|
||||
|
||||
if err := cmd.Execute(); err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error: %v\n", err)
|
||||
var exitErr *exitcodes.ExitError
|
||||
if errors.As(err, &exitErr) {
|
||||
os.Exit(exitErr.Code)
|
||||
}
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1302,4 +1302,18 @@ echo ""
|
||||
print_info "Refer to the README in the ${INSTALL_ROOT} directory for more information."
|
||||
echo ""
|
||||
print_info "For help or issues, contact: founders@onyx.app"
|
||||
echo ""
|
||||
echo ""
|
||||
|
||||
# --- GitHub star prompt (inspired by oh-my-codex) ---
|
||||
# Only prompt in interactive mode and only if gh CLI is available.
|
||||
# Uses the GitHub API directly (PUT /user/starred) like oh-my-codex.
|
||||
if is_interactive && command -v gh &>/dev/null; then
|
||||
prompt_yn_or_default "Enjoying Onyx? Star the repo on GitHub? [Y/n] " "Y"
|
||||
if [[ ! "$REPLY" =~ ^[Nn] ]]; then
|
||||
if GH_PAGER= gh api -X PUT /user/starred/onyx-dot-app/onyx < /dev/null >/dev/null 2>&1; then
|
||||
print_success "Thanks for the star!"
|
||||
else
|
||||
print_info "Star us at: https://github.com/onyx-dot-app/onyx"
|
||||
fi
|
||||
fi
|
||||
fi
|
||||
|
||||
@@ -193,9 +193,9 @@ hover, active, and disabled states.
|
||||
|
||||
### Disabled (`core/disabled/`)
|
||||
|
||||
Propagates disabled state via React context. `Interactive.Stateless` and `Interactive.Stateful`
|
||||
consume this automatically, so wrapping a subtree in `<Disabled disabled={true}>` disables all
|
||||
interactive descendants.
|
||||
A pure CSS wrapper that applies disabled visuals (`opacity-50`, `cursor-not-allowed`,
|
||||
`pointer-events: none`) to a single child element via Radix `Slot`. Has no React context —
|
||||
Interactive primitives and buttons manage their own disabled state via a `disabled` prop.
|
||||
|
||||
### Hoverable (`core/animations/`)
|
||||
|
||||
|
||||
@@ -1,9 +1,5 @@
|
||||
import "@opal/components/tooltip.css";
|
||||
import {
|
||||
Disabled,
|
||||
Interactive,
|
||||
type InteractiveStatelessProps,
|
||||
} from "@opal/core";
|
||||
import { Interactive, type InteractiveStatelessProps } from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
@@ -49,7 +45,7 @@ type ButtonProps = InteractiveStatelessProps &
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Wraps the button in a Disabled context. `false` overrides parent contexts. */
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
@@ -94,7 +90,11 @@ function Button({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateless type={type} {...interactiveProps}>
|
||||
<Interactive.Stateless
|
||||
type={type}
|
||||
disabled={disabled}
|
||||
{...interactiveProps}
|
||||
>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
border={interactiveProps.prominence === "secondary"}
|
||||
@@ -118,28 +118,24 @@ function Button({
|
||||
</Interactive.Stateless>
|
||||
);
|
||||
|
||||
const result = tooltip ? (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
) : (
|
||||
button
|
||||
);
|
||||
|
||||
if (disabled != null) {
|
||||
return <Disabled disabled={disabled}>{result}</Disabled>;
|
||||
if (tooltip) {
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{button}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side={tooltipSide}
|
||||
sideOffset={4}
|
||||
>
|
||||
{tooltip}
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
|
||||
return result;
|
||||
return button;
|
||||
}
|
||||
|
||||
export { Button, type ButtonProps };
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
type InteractiveStatefulInteraction,
|
||||
} from "@opal/core";
|
||||
@@ -74,6 +73,9 @@ type OpenButtonProps = Omit<InteractiveStatefulProps, "variant"> & {
|
||||
|
||||
/** Override the default rounding derived from `size`. */
|
||||
roundingVariant?: InteractiveContainerRoundingVariant;
|
||||
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -92,10 +94,9 @@ function OpenButton({
|
||||
roundingVariant: roundingVariantOverride,
|
||||
interaction,
|
||||
variant = "select-heavy",
|
||||
disabled,
|
||||
...statefulProps
|
||||
}: OpenButtonProps) {
|
||||
const { isDisabled } = useDisabled();
|
||||
|
||||
// Derive open state: explicit prop → Radix data-state (injected via Slot chain)
|
||||
const dataState = (statefulProps as Record<string, unknown>)["data-state"] as
|
||||
| string
|
||||
@@ -119,6 +120,7 @@ function OpenButton({
|
||||
<Interactive.Stateful
|
||||
variant={variant}
|
||||
interaction={resolvedInteraction}
|
||||
disabled={disabled}
|
||||
{...statefulProps}
|
||||
>
|
||||
<Interactive.Container
|
||||
@@ -168,7 +170,7 @@ function OpenButton({
|
||||
);
|
||||
|
||||
const resolvedTooltip =
|
||||
tooltip ?? (foldable && isDisabled && children ? children : undefined);
|
||||
tooltip ?? (foldable && disabled && children ? children : undefined);
|
||||
|
||||
if (!resolvedTooltip) return button;
|
||||
|
||||
|
||||
@@ -1,11 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/components/buttons/select-button/styles.css";
|
||||
import {
|
||||
Interactive,
|
||||
useDisabled,
|
||||
type InteractiveStatefulProps,
|
||||
} from "@opal/core";
|
||||
import { Interactive, type InteractiveStatefulProps } from "@opal/core";
|
||||
import type {
|
||||
ContainerSizeVariants,
|
||||
ExtremaSizeVariants,
|
||||
@@ -64,6 +60,9 @@ type SelectButtonProps = InteractiveStatefulProps &
|
||||
|
||||
/** Which side the tooltip appears on. */
|
||||
tooltipSide?: TooltipSide;
|
||||
|
||||
/** Applies disabled styling and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
};
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -80,9 +79,9 @@ function SelectButton({
|
||||
width,
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
disabled,
|
||||
...statefulProps
|
||||
}: SelectButtonProps) {
|
||||
const { isDisabled } = useDisabled();
|
||||
const isLarge = size === "lg";
|
||||
|
||||
const labelEl = children ? (
|
||||
@@ -96,7 +95,7 @@ function SelectButton({
|
||||
) : null;
|
||||
|
||||
const button = (
|
||||
<Interactive.Stateful {...statefulProps}>
|
||||
<Interactive.Stateful disabled={disabled} {...statefulProps}>
|
||||
<Interactive.Container
|
||||
type={type}
|
||||
heightVariant={size}
|
||||
@@ -128,7 +127,7 @@ function SelectButton({
|
||||
);
|
||||
|
||||
const resolvedTooltip =
|
||||
tooltip ?? (foldable && isDisabled && children ? children : undefined);
|
||||
tooltip ?? (foldable && disabled && children ? children : undefined);
|
||||
|
||||
if (!resolvedTooltip) return button;
|
||||
|
||||
|
||||
64
web/lib/opal/src/components/buttons/sidebar-tab/README.md
Normal file
64
web/lib/opal/src/components/buttons/sidebar-tab/README.md
Normal file
@@ -0,0 +1,64 @@
|
||||
# SidebarTab
|
||||
|
||||
**Import:** `import { SidebarTab, type SidebarTabProps } from "@opal/components";`
|
||||
|
||||
A sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`. Designed for admin and app sidebars.
|
||||
|
||||
## Architecture
|
||||
|
||||
```
|
||||
div.relative
|
||||
└─ Interactive.Stateful <- variant (sidebar-heavy | sidebar-light), state, disabled
|
||||
└─ Interactive.Container <- rounding, height, width
|
||||
├─ Link? (absolute overlay for client-side navigation)
|
||||
├─ rightChildren? (absolute, above Link for inline actions)
|
||||
└─ ContentAction (icon + title + truncation spacer)
|
||||
```
|
||||
|
||||
- **`sidebar-heavy`** (default) — muted when unselected (text-03/text-02), bold when selected (text-04/text-03)
|
||||
- **`sidebar-light`** — uniformly muted across all states (text-02/text-02)
|
||||
- **Disabled** — both variants use text-02 foreground, transparent background, no hover/active states
|
||||
- **Navigation** uses an absolutely positioned `<Link>` overlay rather than `href` on the Interactive element, so `rightChildren` can sit above it with `pointer-events-auto`.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `variant` | `"sidebar-heavy" \| "sidebar-light"` | `"sidebar-heavy"` | Sidebar color variant |
|
||||
| `selected` | `boolean` | `false` | Active/selected state |
|
||||
| `icon` | `IconFunctionComponent` | — | Left icon |
|
||||
| `children` | `ReactNode` | — | Label text or custom content |
|
||||
| `disabled` | `boolean` | `false` | Disables the tab |
|
||||
| `folded` | `boolean` | `false` | Collapses label, shows tooltip on hover |
|
||||
| `nested` | `boolean` | `false` | Renders spacer instead of icon for indented items |
|
||||
| `href` | `string` | — | Client-side navigation URL |
|
||||
| `onClick` | `MouseEventHandler` | — | Click handler |
|
||||
| `type` | `ButtonType` | — | HTML button type |
|
||||
| `rightChildren` | `ReactNode` | — | Actions rendered on the right side |
|
||||
|
||||
## Usage
|
||||
|
||||
```tsx
|
||||
import { SidebarTab } from "@opal/components";
|
||||
import { SvgSettings, SvgLock } from "@opal/icons";
|
||||
|
||||
// Active tab
|
||||
<SidebarTab icon={SvgSettings} href="/admin/settings" selected>
|
||||
Settings
|
||||
</SidebarTab>
|
||||
|
||||
// Muted variant
|
||||
<SidebarTab icon={SvgSettings} variant="sidebar-light">
|
||||
Exit Admin Panel
|
||||
</SidebarTab>
|
||||
|
||||
// Disabled enterprise-only tab
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
Groups
|
||||
</SidebarTab>
|
||||
|
||||
// Folded sidebar (icon only, tooltip on hover)
|
||||
<SidebarTab icon={SvgSettings} href="/admin/settings" folded>
|
||||
Settings
|
||||
</SidebarTab>
|
||||
```
|
||||
@@ -0,0 +1,94 @@
|
||||
import type { Meta, StoryObj } from "@storybook/react";
|
||||
import { SidebarTab } from "@opal/components/buttons/sidebar-tab/components";
|
||||
import {
|
||||
SvgSettings,
|
||||
SvgUsers,
|
||||
SvgLock,
|
||||
SvgArrowUpCircle,
|
||||
SvgTrash,
|
||||
} from "@opal/icons";
|
||||
import { Button } from "@opal/components";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
|
||||
const meta: Meta<typeof SidebarTab> = {
|
||||
title: "opal/components/SidebarTab",
|
||||
component: SidebarTab,
|
||||
tags: ["autodocs"],
|
||||
decorators: [
|
||||
(Story) => (
|
||||
<TooltipPrimitive.Provider>
|
||||
<div style={{ width: 260, background: "var(--background-neutral-01)" }}>
|
||||
<Story />
|
||||
</div>
|
||||
</TooltipPrimitive.Provider>
|
||||
),
|
||||
],
|
||||
};
|
||||
|
||||
export default meta;
|
||||
type Story = StoryObj<typeof SidebarTab>;
|
||||
|
||||
export const Default: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
},
|
||||
};
|
||||
|
||||
export const Selected: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
selected: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const Light: Story = {
|
||||
args: {
|
||||
icon: SvgSettings,
|
||||
children: "Settings",
|
||||
variant: "sidebar-light",
|
||||
},
|
||||
};
|
||||
|
||||
export const Disabled: Story = {
|
||||
args: {
|
||||
icon: SvgLock,
|
||||
children: "Enterprise Only",
|
||||
disabled: true,
|
||||
},
|
||||
};
|
||||
|
||||
export const WithRightChildren: Story = {
|
||||
args: {
|
||||
icon: SvgUsers,
|
||||
children: "Users",
|
||||
rightChildren: (
|
||||
<Button
|
||||
icon={SvgTrash}
|
||||
size="xs"
|
||||
prominence="tertiary"
|
||||
variant="danger"
|
||||
/>
|
||||
),
|
||||
},
|
||||
};
|
||||
|
||||
export const SidebarExample: Story = {
|
||||
render: () => (
|
||||
<div className="flex flex-col">
|
||||
<SidebarTab icon={SvgSettings} selected>
|
||||
LLM Models
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgSettings}>Web Search</SidebarTab>
|
||||
<SidebarTab icon={SvgUsers}>Users</SidebarTab>
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
Groups
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgLock} disabled>
|
||||
SCIM
|
||||
</SidebarTab>
|
||||
<SidebarTab icon={SvgArrowUpCircle}>Upgrade Plan</SidebarTab>
|
||||
</div>
|
||||
),
|
||||
};
|
||||
@@ -3,32 +3,66 @@
|
||||
import React from "react";
|
||||
import type { ButtonType, IconFunctionComponent } from "@opal/types";
|
||||
import type { Route } from "next";
|
||||
import { Interactive } from "@opal/core";
|
||||
import { Interactive, type InteractiveStatefulVariant } from "@opal/core";
|
||||
import { ContentAction } from "@opal/layouts";
|
||||
import { Text } from "@opal/components";
|
||||
import Link from "next/link";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import "@opal/components/tooltip.css";
|
||||
|
||||
export interface SidebarTabProps {
|
||||
// Button states:
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface SidebarTabProps {
|
||||
/** Collapses the label, showing only the icon. */
|
||||
folded?: boolean;
|
||||
|
||||
/** Marks this tab as the currently active/selected item. */
|
||||
selected?: boolean;
|
||||
lowlight?: boolean;
|
||||
|
||||
/**
|
||||
* Sidebar color variant.
|
||||
* @default "sidebar-heavy"
|
||||
*/
|
||||
variant?: Extract<
|
||||
InteractiveStatefulVariant,
|
||||
"sidebar-light" | "sidebar-heavy"
|
||||
>;
|
||||
|
||||
/** Renders an empty spacer in place of the icon for nested items. */
|
||||
nested?: boolean;
|
||||
|
||||
// Button properties:
|
||||
/** Disables the tab — applies muted colors and suppresses clicks. */
|
||||
disabled?: boolean;
|
||||
|
||||
onClick?: React.MouseEventHandler<HTMLElement>;
|
||||
href?: string;
|
||||
type?: ButtonType;
|
||||
icon?: IconFunctionComponent;
|
||||
children?: React.ReactNode;
|
||||
|
||||
/** Content rendered on the right side (e.g. action buttons). */
|
||||
rightChildren?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function SidebarTab({
|
||||
// ---------------------------------------------------------------------------
|
||||
// SidebarTab
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Sidebar navigation tab built on `Interactive.Stateful` > `Interactive.Container`.
|
||||
*
|
||||
* Uses `sidebar-heavy` (default) or `sidebar-light` (via `variant`) variants
|
||||
* for color styling. Supports an overlay `Link` for client-side navigation,
|
||||
* `rightChildren` for inline actions, and folded mode with an auto-tooltip.
|
||||
*/
|
||||
function SidebarTab({
|
||||
folded,
|
||||
selected,
|
||||
lowlight,
|
||||
variant = "sidebar-heavy",
|
||||
nested,
|
||||
disabled,
|
||||
|
||||
onClick,
|
||||
href,
|
||||
@@ -45,11 +79,8 @@ export default function SidebarTab({
|
||||
)) as IconFunctionComponent)
|
||||
: null);
|
||||
|
||||
// NOTE (@raunakab)
|
||||
//
|
||||
// The `rightChildren` node NEEDS to be absolutely positioned since it needs to live on top of the absolutely positioned `Link`.
|
||||
// However, having the `rightChildren` be absolutely positioned means that it cannot appropriately truncate the title.
|
||||
// Therefore, we add a dummy node solely for the truncation effects that we obtain.
|
||||
// The `rightChildren` node is absolutely positioned to sit on top of the
|
||||
// overlay Link. A zero-width spacer reserves truncation space for the title.
|
||||
const truncationSpacer = rightChildren && (
|
||||
<div className="w-0 group-hover/SidebarTab:w-6" />
|
||||
);
|
||||
@@ -57,8 +88,9 @@ export default function SidebarTab({
|
||||
const content = (
|
||||
<div className="relative">
|
||||
<Interactive.Stateful
|
||||
variant={lowlight ? "sidebar-light" : "sidebar-heavy"}
|
||||
variant={variant}
|
||||
state={selected ? "selected" : "empty"}
|
||||
disabled={disabled}
|
||||
onClick={onClick}
|
||||
type="button"
|
||||
group="group/SidebarTab"
|
||||
@@ -69,7 +101,7 @@ export default function SidebarTab({
|
||||
widthVariant="full"
|
||||
type={type}
|
||||
>
|
||||
{href && (
|
||||
{href && !disabled && (
|
||||
<Link
|
||||
href={href as Route}
|
||||
scroll={false}
|
||||
@@ -102,12 +134,7 @@ export default function SidebarTab({
|
||||
</div>
|
||||
)}
|
||||
{children}
|
||||
{
|
||||
// NOTE (@raunakab)
|
||||
//
|
||||
// Adding the `truncationSpacer` here for the same reason as above.
|
||||
truncationSpacer
|
||||
}
|
||||
{truncationSpacer}
|
||||
</div>
|
||||
)}
|
||||
</Interactive.Container>
|
||||
@@ -116,7 +143,23 @@ export default function SidebarTab({
|
||||
);
|
||||
|
||||
if (typeof children !== "string") return content;
|
||||
if (folded)
|
||||
return <SimpleTooltip tooltip={children}>{content}</SimpleTooltip>;
|
||||
if (folded) {
|
||||
return (
|
||||
<TooltipPrimitive.Root>
|
||||
<TooltipPrimitive.Trigger asChild>{content}</TooltipPrimitive.Trigger>
|
||||
<TooltipPrimitive.Portal>
|
||||
<TooltipPrimitive.Content
|
||||
className="opal-tooltip"
|
||||
side="right"
|
||||
sideOffset={4}
|
||||
>
|
||||
<Text>{children}</Text>
|
||||
</TooltipPrimitive.Content>
|
||||
</TooltipPrimitive.Portal>
|
||||
</TooltipPrimitive.Root>
|
||||
);
|
||||
}
|
||||
return content;
|
||||
}
|
||||
|
||||
export { SidebarTab, type SidebarTabProps };
|
||||
@@ -33,6 +33,12 @@ export {
|
||||
type LineItemButtonProps,
|
||||
} from "@opal/components/buttons/line-item-button/components";
|
||||
|
||||
/* SidebarTab */
|
||||
export {
|
||||
SidebarTab,
|
||||
type SidebarTabProps,
|
||||
} from "@opal/components/buttons/sidebar-tab/components";
|
||||
|
||||
/* Text */
|
||||
export {
|
||||
Text,
|
||||
|
||||
@@ -586,7 +586,10 @@ export function Table<TData>(props: DataTableProps<TData>) {
|
||||
|
||||
// Data / Display cell
|
||||
return (
|
||||
<TableCell key={cell.id}>
|
||||
<TableCell
|
||||
key={cell.id}
|
||||
data-column-id={cell.column.id}
|
||||
>
|
||||
{flexRender(
|
||||
cell.column.columnDef.cell,
|
||||
cell.getContext()
|
||||
|
||||
@@ -1,33 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import "@opal/core/disabled/styles.css";
|
||||
import React, { createContext, useContext } from "react";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Context
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
interface DisabledContextValue {
|
||||
isDisabled: boolean;
|
||||
allowClick: boolean;
|
||||
}
|
||||
|
||||
const DisabledContext = createContext<DisabledContextValue>({
|
||||
isDisabled: false,
|
||||
allowClick: false,
|
||||
});
|
||||
|
||||
/**
|
||||
* Returns the current disabled state from the nearest `<Disabled>` ancestor.
|
||||
*
|
||||
* Used internally by `Interactive.Stateless` and `Interactive.Stateful` to
|
||||
* derive `data-disabled` and `aria-disabled` attributes automatically.
|
||||
*/
|
||||
function useDisabled(): DisabledContextValue {
|
||||
return useContext(DisabledContext);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -56,8 +30,8 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Wrapper component that propagates disabled state via context and applies
|
||||
* baseline disabled CSS (opacity, cursor, pointer-events) to its child.
|
||||
* Wrapper component that applies baseline disabled CSS (opacity, cursor,
|
||||
* pointer-events) to its child element.
|
||||
*
|
||||
* Uses Radix `Slot` — merges props onto the single child element without
|
||||
* adding any DOM node. Works correctly inside Radix `asChild` chains.
|
||||
@@ -65,7 +39,7 @@ interface DisabledProps extends React.HTMLAttributes<HTMLElement> {
|
||||
* @example
|
||||
* ```tsx
|
||||
* <Disabled disabled={!canSubmit}>
|
||||
* <Button onClick={handleSubmit}>Save</Button>
|
||||
* <div>...</div>
|
||||
* </Disabled>
|
||||
* ```
|
||||
*/
|
||||
@@ -77,20 +51,16 @@ function Disabled({
|
||||
...rest
|
||||
}: DisabledProps) {
|
||||
return (
|
||||
<DisabledContext.Provider
|
||||
value={{ isDisabled: !!disabled, allowClick: !!allowClick }}
|
||||
<Slot
|
||||
ref={ref}
|
||||
{...rest}
|
||||
aria-disabled={disabled || undefined}
|
||||
data-opal-disabled={disabled || undefined}
|
||||
data-allow-click={disabled && allowClick ? "" : undefined}
|
||||
>
|
||||
<Slot
|
||||
ref={ref}
|
||||
{...rest}
|
||||
aria-disabled={disabled || undefined}
|
||||
data-opal-disabled={disabled || undefined}
|
||||
data-allow-click={disabled && allowClick ? "" : undefined}
|
||||
>
|
||||
{children}
|
||||
</Slot>
|
||||
</DisabledContext.Provider>
|
||||
{children}
|
||||
</Slot>
|
||||
);
|
||||
}
|
||||
|
||||
export { Disabled, useDisabled, type DisabledProps, type DisabledContextValue };
|
||||
export { Disabled, type DisabledProps };
|
||||
|
||||
@@ -1,10 +1,5 @@
|
||||
/* Disabled */
|
||||
export {
|
||||
Disabled,
|
||||
useDisabled,
|
||||
type DisabledProps,
|
||||
type DisabledContextValue,
|
||||
} from "@opal/core/disabled/components";
|
||||
export { Disabled, type DisabledProps } from "@opal/core/disabled/components";
|
||||
|
||||
/* Animations (formerly Hoverable) */
|
||||
export {
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
widthVariants,
|
||||
type ExtremaSizeVariants,
|
||||
} from "@opal/shared";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -102,7 +101,6 @@ function InteractiveContainer({
|
||||
widthVariant = "fit",
|
||||
...props
|
||||
}: InteractiveContainerProps) {
|
||||
const { allowClick } = useDisabled();
|
||||
const {
|
||||
className: slotClassName,
|
||||
style: slotStyle,
|
||||
@@ -148,8 +146,7 @@ function InteractiveContainer({
|
||||
if (type) {
|
||||
const ariaDisabled = (rest as Record<string, unknown>)["aria-disabled"];
|
||||
const nativeDisabled =
|
||||
(type === "submit" || !allowClick) &&
|
||||
(ariaDisabled === true || ariaDisabled === "true" || undefined);
|
||||
ariaDisabled === true || ariaDisabled === "true" || undefined;
|
||||
return (
|
||||
<button
|
||||
ref={ref as React.Ref<HTMLButtonElement>}
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
import { guardPortalClick } from "@opal/core/interactive/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -29,6 +28,11 @@ interface InteractiveSimpleProps
|
||||
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
|
||||
*/
|
||||
target?: string;
|
||||
|
||||
/**
|
||||
* Applies disabled cursor and suppresses clicks.
|
||||
*/
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -38,8 +42,8 @@ interface InteractiveSimpleProps
|
||||
/**
|
||||
* Minimal interactive surface primitive.
|
||||
*
|
||||
* Provides cursor styling, click handling, disabled integration, and
|
||||
* optional link/group support — but **no color or background styling**.
|
||||
* Provides cursor styling, click handling, and optional link/group
|
||||
* support — but **no color or background styling**.
|
||||
*
|
||||
* Use this for elements that need interactivity (click, cursor, disabled)
|
||||
* without participating in the Interactive color system.
|
||||
@@ -59,9 +63,10 @@ function InteractiveSimple({
|
||||
group,
|
||||
href,
|
||||
target,
|
||||
disabled,
|
||||
...props
|
||||
}: InteractiveSimpleProps) {
|
||||
const { isDisabled, allowClick } = useDisabled();
|
||||
const isDisabled = !!disabled;
|
||||
|
||||
const classes = cn(
|
||||
"cursor-pointer select-none",
|
||||
@@ -88,7 +93,7 @@ function InteractiveSimple({
|
||||
{...linkAttrs}
|
||||
{...slotProps}
|
||||
onClick={
|
||||
isDisabled && !allowClick
|
||||
isDisabled
|
||||
? href
|
||||
? (e: React.MouseEvent) => e.preventDefault()
|
||||
: undefined
|
||||
|
||||
@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateful/styles.css";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
import { guardPortalClick } from "@opal/core/interactive/utils";
|
||||
import type { ButtonType, WithoutStyles } from "@opal/types";
|
||||
|
||||
@@ -87,6 +86,11 @@ interface InteractiveStatefulProps
|
||||
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
|
||||
*/
|
||||
target?: string;
|
||||
|
||||
/**
|
||||
* Applies variant-specific disabled colors and suppresses clicks.
|
||||
*/
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -100,8 +104,7 @@ interface InteractiveStatefulProps
|
||||
* (empty/filled/selected). Applies variant/state color styling via CSS
|
||||
* data-attributes and merges onto a single child element via Radix `Slot`.
|
||||
*
|
||||
* Disabled state is consumed from the nearest `<Disabled>` ancestor via
|
||||
* context — there is no `disabled` prop on this component.
|
||||
* Disabled state is controlled via the `disabled` prop.
|
||||
*/
|
||||
function InteractiveStateful({
|
||||
ref,
|
||||
@@ -112,9 +115,10 @@ function InteractiveStateful({
|
||||
type,
|
||||
href,
|
||||
target,
|
||||
disabled,
|
||||
...props
|
||||
}: InteractiveStatefulProps) {
|
||||
const { isDisabled, allowClick } = useDisabled();
|
||||
const isDisabled = !!disabled;
|
||||
|
||||
// onClick/href are always passed directly — Stateful is the outermost Slot,
|
||||
// so Radix Slot-injected handlers don't bypass this guard.
|
||||
@@ -150,7 +154,7 @@ function InteractiveStateful({
|
||||
{...linkAttrs}
|
||||
{...slotProps}
|
||||
onClick={
|
||||
isDisabled && !allowClick
|
||||
isDisabled
|
||||
? href
|
||||
? (e: React.MouseEvent) => e.preventDefault()
|
||||
: undefined
|
||||
|
||||
@@ -550,6 +550,14 @@
|
||||
) {
|
||||
@apply bg-background-tint-03;
|
||||
}
|
||||
/* ---------------------------------------------------------------------------
|
||||
Sidebar-Heavy — Disabled (all states)
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="sidebar-heavy"][data-disabled] {
|
||||
@apply bg-transparent opacity-50;
|
||||
--interactive-foreground: var(--text-03);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
|
||||
/* ===========================================================================
|
||||
Sidebar-Light
|
||||
@@ -607,3 +615,11 @@
|
||||
) {
|
||||
@apply bg-background-tint-03;
|
||||
}
|
||||
/* ---------------------------------------------------------------------------
|
||||
Sidebar-Light — Disabled (all states)
|
||||
--------------------------------------------------------------------------- */
|
||||
.interactive[data-interactive-variant="sidebar-light"][data-disabled] {
|
||||
@apply bg-transparent opacity-50;
|
||||
--interactive-foreground: var(--text-03);
|
||||
--interactive-foreground-icon: var(--text-03);
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ import "@opal/core/interactive/stateless/styles.css";
|
||||
import React from "react";
|
||||
import { Slot } from "@radix-ui/react-slot";
|
||||
import { cn } from "@opal/utils";
|
||||
import { useDisabled } from "@opal/core/disabled/components";
|
||||
import { guardPortalClick } from "@opal/core/interactive/utils";
|
||||
import type { ButtonType, WithoutStyles } from "@opal/types";
|
||||
|
||||
@@ -70,6 +69,11 @@ interface InteractiveStatelessProps
|
||||
* Link target (e.g. `"_blank"`). Only used when `href` is provided.
|
||||
*/
|
||||
target?: string;
|
||||
|
||||
/**
|
||||
* Applies variant-specific disabled colors and suppresses clicks.
|
||||
*/
|
||||
disabled?: boolean;
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -84,8 +88,7 @@ interface InteractiveStatelessProps
|
||||
* color styling via CSS data-attributes and merges onto a single child
|
||||
* element via Radix `Slot`.
|
||||
*
|
||||
* Disabled state is consumed from the nearest `<Disabled>` ancestor via
|
||||
* context — there is no `disabled` prop on this component.
|
||||
* Disabled state is controlled via the `disabled` prop.
|
||||
*/
|
||||
function InteractiveStateless({
|
||||
ref,
|
||||
@@ -96,9 +99,10 @@ function InteractiveStateless({
|
||||
type,
|
||||
href,
|
||||
target,
|
||||
disabled,
|
||||
...props
|
||||
}: InteractiveStatelessProps) {
|
||||
const { isDisabled, allowClick } = useDisabled();
|
||||
const isDisabled = !!disabled;
|
||||
|
||||
// onClick/href are always passed directly — Stateless is the outermost Slot,
|
||||
// so Radix Slot-injected handlers don't bypass this guard.
|
||||
@@ -134,7 +138,7 @@ function InteractiveStateless({
|
||||
{...linkAttrs}
|
||||
{...slotProps}
|
||||
onClick={
|
||||
isDisabled && !allowClick
|
||||
isDisabled
|
||||
? href
|
||||
? (e: React.MouseEvent) => e.preventDefault()
|
||||
: undefined
|
||||
|
||||
@@ -8,7 +8,6 @@ import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import InfoBlock from "@/refresh-components/messages/InfoBlock";
|
||||
@@ -246,15 +245,14 @@ function SubscriptionCard({
|
||||
to make changes.
|
||||
</Text>
|
||||
) : disabled ? (
|
||||
<Disabled disabled={isReconnecting}>
|
||||
<OpalButton
|
||||
prominence="secondary"
|
||||
onClick={handleReconnect}
|
||||
rightIcon={SvgArrowRight}
|
||||
>
|
||||
{isReconnecting ? "Connecting..." : "Connect to Stripe"}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
<OpalButton
|
||||
disabled={isReconnecting}
|
||||
prominence="secondary"
|
||||
onClick={handleReconnect}
|
||||
rightIcon={SvgArrowRight}
|
||||
>
|
||||
{isReconnecting ? "Connecting..." : "Connect to Stripe"}
|
||||
</OpalButton>
|
||||
) : (
|
||||
<OpalButton onClick={handleManagePlan} rightIcon={SvgExternalLink}>
|
||||
Manage Plan
|
||||
@@ -377,11 +375,13 @@ function SeatsCard({
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<OpalButton prominence="secondary" onClick={handleCancel}>
|
||||
Cancel
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
<OpalButton
|
||||
disabled={isSubmitting}
|
||||
prominence="secondary"
|
||||
onClick={handleCancel}
|
||||
>
|
||||
Cancel
|
||||
</OpalButton>
|
||||
</Section>
|
||||
|
||||
<div className="billing-content-area">
|
||||
@@ -463,15 +463,14 @@ function SeatsCard({
|
||||
No changes to your billing.
|
||||
</Text>
|
||||
)}
|
||||
<Disabled
|
||||
<OpalButton
|
||||
disabled={
|
||||
isSubmitting || newSeatCount === totalSeats || isBelowMinimum
|
||||
}
|
||||
onClick={handleConfirm}
|
||||
>
|
||||
<OpalButton onClick={handleConfirm}>
|
||||
{isSubmitting ? "Saving..." : "Confirm Change"}
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
{isSubmitting ? "Saving..." : "Confirm Change"}
|
||||
</OpalButton>
|
||||
</Section>
|
||||
</Card>
|
||||
);
|
||||
@@ -509,15 +508,14 @@ function SeatsCard({
|
||||
View Users
|
||||
</OpalButton>
|
||||
{!hideUpdateSeats && (
|
||||
<Disabled disabled={isLoadingUsers || disabled || !billing}>
|
||||
<OpalButton
|
||||
prominence="secondary"
|
||||
onClick={handleStartEdit}
|
||||
icon={SvgPlus}
|
||||
>
|
||||
Update Seats
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
<OpalButton
|
||||
disabled={isLoadingUsers || disabled || !billing}
|
||||
prominence="secondary"
|
||||
onClick={handleStartEdit}
|
||||
icon={SvgPlus}
|
||||
>
|
||||
Update Seats
|
||||
</OpalButton>
|
||||
)}
|
||||
</Section>
|
||||
</Section>
|
||||
|
||||
@@ -4,7 +4,6 @@ import { useState, useMemo, useEffect } from "react";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import * as InputLayouts from "@/layouts/input-layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
@@ -263,11 +262,9 @@ export default function CheckoutView({ onAdjustPlan }: CheckoutViewProps) {
|
||||
// Empty div to maintain space-between alignment
|
||||
<div></div>
|
||||
)}
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button onClick={handleSubmit}>
|
||||
{isSubmitting ? "Loading..." : "Continue to Payment"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isSubmitting} onClick={handleSubmit}>
|
||||
{isSubmitting ? "Loading..." : "Continue to Payment"}
|
||||
</Button>
|
||||
</Section>
|
||||
</Card>
|
||||
);
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
import { useState } from "react";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputFile from "@/refresh-components/inputs/InputFile";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
@@ -147,11 +146,13 @@ export default function LicenseActivationCard({
|
||||
<Text headingH3>
|
||||
{hasLicense ? "Update License Key" : "Activate License Key"}
|
||||
</Text>
|
||||
<Disabled disabled={isActivating}>
|
||||
<Button prominence="secondary" onClick={handleClose}>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isActivating}
|
||||
prominence="secondary"
|
||||
onClick={handleClose}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Section>
|
||||
<Text secondaryBody text03>
|
||||
Manually add and activate a license for this Onyx instance.
|
||||
@@ -221,15 +222,16 @@ export default function LicenseActivationCard({
|
||||
|
||||
{/* Footer */}
|
||||
<Section flexDirection="row" justifyContent="end" padding={1}>
|
||||
<Disabled disabled={isActivating || !licenseKey.trim() || success}>
|
||||
<Button onClick={handleActivate}>
|
||||
{isActivating
|
||||
? "Activating..."
|
||||
: hasLicense
|
||||
? "Update License"
|
||||
: "Activate License"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isActivating || !licenseKey.trim() || success}
|
||||
onClick={handleActivate}
|
||||
>
|
||||
{isActivating
|
||||
? "Activating..."
|
||||
: hasLicense
|
||||
? "Update License"
|
||||
: "Activate License"}
|
||||
</Button>
|
||||
</Section>
|
||||
</Card>
|
||||
);
|
||||
|
||||
@@ -64,7 +64,6 @@ const BUSINESS_FEATURES: PlanFeature[] = [
|
||||
{ icon: SvgKey, text: "Service Account API Keys" },
|
||||
{ icon: SvgHardDrive, text: "Self-hosting (Optional)" },
|
||||
{ icon: SvgPaintBrush, text: "Custom Theming" },
|
||||
{ icon: SvgShareWebhook, text: "Hook Extensions" },
|
||||
];
|
||||
|
||||
const ENTERPRISE_FEATURES: PlanFeature[] = [
|
||||
@@ -72,6 +71,7 @@ const ENTERPRISE_FEATURES: PlanFeature[] = [
|
||||
{ icon: SvgDashboard, text: "Full White-labeling" },
|
||||
{ icon: SvgUserManage, text: "Custom Roles and Permissions" },
|
||||
{ icon: SvgSliders, text: "Configurable Usage Limits" },
|
||||
{ icon: SvgShareWebhook, text: "Hook Extensions" },
|
||||
{ icon: SvgServer, text: "Custom Deployments" },
|
||||
{ icon: SvgGlobe, text: "Region-Specific Data Processing" },
|
||||
{ icon: SvgHeadsetMic, text: "Enterprise SLAs and Priority Support" },
|
||||
|
||||
@@ -5,7 +5,6 @@ import { Form, Formik } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { createSlackBot, updateSlackBot } from "./new/lib";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { useEffect } from "react";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
@@ -127,16 +126,17 @@ export const SlackTokensForm = ({
|
||||
subtext="Optional: User OAuth token for enhanced private channel access"
|
||||
/>
|
||||
<div className="flex justify-end w-full mt-4">
|
||||
<Disabled
|
||||
<Button
|
||||
disabled={
|
||||
isSubmitting ||
|
||||
!values.bot_token ||
|
||||
!values.app_token ||
|
||||
!values.name
|
||||
}
|
||||
type="submit"
|
||||
>
|
||||
<Button type="submit">{isUpdate ? "Update" : "Create"}</Button>
|
||||
</Disabled>
|
||||
{isUpdate ? "Update" : "Create"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
|
||||
@@ -6,7 +6,6 @@ import { ManualErrorMessage, TextFormField } from "@/components/Field";
|
||||
import { useEffect, useState } from "react";
|
||||
import CreateButton from "@/refresh-components/buttons/CreateButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
@@ -56,20 +55,17 @@ function ModelConfigurationRow({
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col justify-center">
|
||||
<Disabled
|
||||
<Button
|
||||
disabled={formikProps.values.model_configurations.length <= 1}
|
||||
>
|
||||
<Button
|
||||
onClick={() => {
|
||||
if (formikProps.values.model_configurations.length > 1) {
|
||||
setError(null);
|
||||
arrayHelpers.remove(index);
|
||||
}
|
||||
}}
|
||||
icon={SvgX}
|
||||
prominence="secondary"
|
||||
/>
|
||||
</Disabled>
|
||||
onClick={() => {
|
||||
if (formikProps.values.model_configurations.length > 1) {
|
||||
setError(null);
|
||||
arrayHelpers.remove(index);
|
||||
}
|
||||
}}
|
||||
icon={SvgX}
|
||||
prominence="secondary"
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import { useState, useRef } from "react";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
@@ -184,25 +183,24 @@ export default function InlineFileManagement({
|
||||
</Button>
|
||||
) : (
|
||||
<>
|
||||
<Disabled disabled={isSaving}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={handleCancel}
|
||||
icon={SvgX}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Disabled
|
||||
<Button
|
||||
disabled={isSaving}
|
||||
prominence="secondary"
|
||||
onClick={handleCancel}
|
||||
icon={SvgX}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button
|
||||
disabled={
|
||||
isSaving ||
|
||||
(selectedFilesToRemove.size === 0 && filesToAdd.length === 0)
|
||||
}
|
||||
onClick={handleSaveClick}
|
||||
icon={SvgCheck}
|
||||
>
|
||||
<Button onClick={handleSaveClick} icon={SvgCheck}>
|
||||
{isSaving ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
{isSaving ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
@@ -334,15 +332,14 @@ export default function InlineFileManagement({
|
||||
className="hidden"
|
||||
id={`file-upload-${connectorId}`}
|
||||
/>
|
||||
<Disabled disabled={isSaving}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => fileInputRef.current?.click()}
|
||||
icon={SvgPlusCircle}
|
||||
>
|
||||
Add Files
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSaving}
|
||||
prominence="secondary"
|
||||
onClick={() => fileInputRef.current?.click()}
|
||||
icon={SvgPlusCircle}
|
||||
>
|
||||
Add Files
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -398,19 +395,16 @@ export default function InlineFileManagement({
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<Disabled disabled={isSaving}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
onClick={() => setShowSaveConfirm(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Disabled disabled={isSaving}>
|
||||
<Button onClick={handleConfirmSave}>
|
||||
{isSaving ? "Saving..." : "Confirm & Save"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSaving}
|
||||
prominence="secondary"
|
||||
onClick={() => setShowSaveConfirm(false)}
|
||||
>
|
||||
Cancel
|
||||
</Button>
|
||||
<Button disabled={isSaving} onClick={handleConfirmSave}>
|
||||
{isSaving ? "Saving..." : "Confirm & Save"}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
"use client";
|
||||
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { useState } from "react";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { triggerIndexing } from "@/app/admin/connector/[ccPairId]/lib";
|
||||
@@ -116,9 +115,9 @@ export default function ReIndexModal({ hide, onRunIndex }: ReIndexModalProps) {
|
||||
This will pull in and index all documents that have changed and/or
|
||||
have been added since the last successful indexing run.
|
||||
</Text>
|
||||
<Disabled disabled={isProcessing}>
|
||||
<Button onClick={() => handleRunIndex(false)}>Run Update</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isProcessing} onClick={() => handleRunIndex(false)}>
|
||||
Run Update
|
||||
</Button>
|
||||
|
||||
<Separator />
|
||||
|
||||
@@ -131,11 +130,9 @@ export default function ReIndexModal({ hide, onRunIndex }: ReIndexModalProps) {
|
||||
in the source, this may take a long time.
|
||||
</Text>
|
||||
|
||||
<Disabled disabled={isProcessing}>
|
||||
<Button onClick={() => handleRunIndex(true)}>
|
||||
Run Complete Re-Indexing
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isProcessing} onClick={() => handleRunIndex(true)}>
|
||||
Run Complete Re-Indexing
|
||||
</Button>
|
||||
</Modal.Body>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
|
||||
@@ -56,7 +56,6 @@ import {
|
||||
import { CreateStdOAuthCredential } from "@/components/credentials/actions/CreateStdOAuthCredential";
|
||||
import { Spinner } from "@/components/Spinner";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { deleteConnector } from "@/lib/connector";
|
||||
import ConnectorDocsLink from "@/components/admin/connectors/ConnectorDocsLink";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
@@ -580,19 +579,18 @@ export default function AddConnector({
|
||||
{/* Button to sign in via OAuth */}
|
||||
{oauthSupportedSources.includes(connector) &&
|
||||
(NEXT_PUBLIC_CLOUD_ENABLED || NEXT_PUBLIC_TEST_ENV) && (
|
||||
<Disabled disabled={isAuthorizing}>
|
||||
<Button
|
||||
variant="action"
|
||||
onClick={handleAuthorize}
|
||||
hidden={!isAuthorizeVisible}
|
||||
>
|
||||
{isAuthorizing
|
||||
? "Authorizing..."
|
||||
: `Authorize with ${getSourceDisplayName(
|
||||
connector
|
||||
)}`}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isAuthorizing}
|
||||
variant="action"
|
||||
onClick={handleAuthorize}
|
||||
hidden={!isAuthorizeVisible}
|
||||
>
|
||||
{isAuthorizing
|
||||
? "Authorizing..."
|
||||
: `Authorize with ${getSourceDisplayName(
|
||||
connector
|
||||
)}`}
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
import { useFormContext } from "@/components/context/FormContext";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgArrowLeft, SvgArrowRight, SvgPlusCircle } from "@opal/icons";
|
||||
|
||||
const NavigationRow = ({
|
||||
@@ -34,35 +33,35 @@ const NavigationRow = ({
|
||||
</div>
|
||||
<div className="flex justify-center">
|
||||
{(formStep > 0 || noCredentials) && (
|
||||
<Disabled disabled={!isValid}>
|
||||
<Button rightIcon={SvgPlusCircle} onClick={onSubmit}>
|
||||
Create Connector
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={!isValid}
|
||||
rightIcon={SvgPlusCircle}
|
||||
onClick={onSubmit}
|
||||
>
|
||||
Create Connector
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
<div className="flex justify-end">
|
||||
{formStep === 0 && (
|
||||
<Disabled disabled={!activatedCredential}>
|
||||
<Button
|
||||
variant="action"
|
||||
rightIcon={SvgArrowRight}
|
||||
onClick={() => nextFormStep()}
|
||||
>
|
||||
Continue
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={!activatedCredential}
|
||||
variant="action"
|
||||
rightIcon={SvgArrowRight}
|
||||
onClick={() => nextFormStep()}
|
||||
>
|
||||
Continue
|
||||
</Button>
|
||||
)}
|
||||
{!noAdvanced && formStep === 1 && (
|
||||
<Disabled disabled={!isValid}>
|
||||
<Button
|
||||
prominence="secondary"
|
||||
rightIcon={SvgArrowRight}
|
||||
onClick={() => nextFormStep()}
|
||||
>
|
||||
Advanced
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={!isValid}
|
||||
prominence="secondary"
|
||||
rightIcon={SvgArrowRight}
|
||||
onClick={() => nextFormStep()}
|
||||
>
|
||||
Advanced
|
||||
</Button>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -4,7 +4,6 @@ import { useEffect, useState } from "react";
|
||||
import { usePathname, useRouter, useSearchParams } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { getSourceMetadata, isValidSource } from "@/lib/sources";
|
||||
import { ConfluenceAccessibleResource, ValidSources } from "@/lib/types";
|
||||
import CardSection from "@/components/admin/CardSection";
|
||||
@@ -260,11 +259,9 @@ export default function OAuthFinalizePage() {
|
||||
)}
|
||||
<br />
|
||||
{!redirectUrl && (
|
||||
<Disabled disabled={!isValid || isSubmitting}>
|
||||
<Button type="submit">
|
||||
{isSubmitting ? "Submitting..." : "Submit"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button disabled={!isValid || isSubmitting} type="submit">
|
||||
{isSubmitting ? "Submitting..." : "Submit"}
|
||||
</Button>
|
||||
)}
|
||||
</Form>
|
||||
)}
|
||||
|
||||
@@ -11,7 +11,6 @@ import { TextFormField, SectionHeader } from "@/components/Field";
|
||||
import { Form, Formik } from "formik";
|
||||
import { User } from "@/lib/types";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import {
|
||||
Credential,
|
||||
GoogleDriveCredentialJson,
|
||||
@@ -563,11 +562,9 @@ export const DriveAuthSection = ({
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Google Drive(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button type="submit">
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isSubmitting} type="submit">
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
@@ -587,35 +584,34 @@ export const DriveAuthSection = ({
|
||||
Google Drive account.
|
||||
</p>
|
||||
</div>
|
||||
<Disabled disabled={isAuthenticating}>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
});
|
||||
<Button
|
||||
disabled={isAuthenticating}
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
const [authUrl, errorMsg] = await setupGoogleDriveOAuth({
|
||||
isAdmin: true,
|
||||
name: "OAuth (uploaded)",
|
||||
});
|
||||
|
||||
if (authUrl) {
|
||||
router.push(authUrl as Route);
|
||||
} else {
|
||||
toast.error(errorMsg);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
} catch (error) {
|
||||
toast.error(
|
||||
`Failed to authenticate with Google Drive - ${error}`
|
||||
);
|
||||
if (authUrl) {
|
||||
router.push(authUrl as Route);
|
||||
} else {
|
||||
toast.error(errorMsg);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{isAuthenticating
|
||||
? "Authenticating..."
|
||||
: "Authenticate with Google Drive"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
} catch (error) {
|
||||
toast.error(
|
||||
`Failed to authenticate with Google Drive - ${error}`
|
||||
);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{isAuthenticating
|
||||
? "Authenticating..."
|
||||
: "Authenticate with Google Drive"}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { useSWRConfig } from "swr";
|
||||
@@ -570,11 +569,9 @@ export const GmailAuthSection = ({
|
||||
subtext="Enter the email of an admin/owner of the Google Organization that owns the Gmail account(s) you want to index."
|
||||
/>
|
||||
<div className="flex">
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button type="submit">
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isSubmitting} type="submit">
|
||||
{isSubmitting ? "Creating..." : "Create Credential"}
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
@@ -593,36 +590,35 @@ export const GmailAuthSection = ({
|
||||
read access to the emails you have access to in your Gmail account.
|
||||
</p>
|
||||
</div>
|
||||
<Disabled disabled={isAuthenticating}>
|
||||
<Button
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
if (buildMode) {
|
||||
Cookies.set(CRAFT_OAUTH_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
}
|
||||
const [authUrl, errorMsg] = await setupGmailOAuth({
|
||||
isAdmin: true,
|
||||
<Button
|
||||
disabled={isAuthenticating}
|
||||
onClick={async () => {
|
||||
setIsAuthenticating(true);
|
||||
try {
|
||||
if (buildMode) {
|
||||
Cookies.set(CRAFT_OAUTH_COOKIE_NAME, "true", {
|
||||
path: "/",
|
||||
});
|
||||
}
|
||||
const [authUrl, errorMsg] = await setupGmailOAuth({
|
||||
isAdmin: true,
|
||||
});
|
||||
|
||||
if (authUrl) {
|
||||
onOAuthRedirect?.();
|
||||
router.push(authUrl as Route);
|
||||
} else {
|
||||
toast.error(errorMsg);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
} catch (error) {
|
||||
toast.error(`Failed to authenticate with Gmail - ${error}`);
|
||||
if (authUrl) {
|
||||
onOAuthRedirect?.();
|
||||
router.push(authUrl as Route);
|
||||
} else {
|
||||
toast.error(errorMsg);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
} catch (error) {
|
||||
toast.error(`Failed to authenticate with Gmail - ${error}`);
|
||||
setIsAuthenticating(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{isAuthenticating ? "Authenticating..." : "Authenticate with Gmail"}
|
||||
</Button>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import { Section } from "@/layouts/general-layouts";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import PasswordInputTypeIn from "@/refresh-components/inputs/PasswordInputTypeIn";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
@@ -126,14 +125,13 @@ export function BotConfigCard() {
|
||||
}
|
||||
disabled={!hasServerConfigs}
|
||||
>
|
||||
<Disabled disabled={isSubmitting || hasServerConfigs}>
|
||||
<Button
|
||||
variant="danger"
|
||||
onClick={() => setShowDeleteConfirm(true)}
|
||||
>
|
||||
Delete Discord Token
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSubmitting || hasServerConfigs}
|
||||
variant="danger"
|
||||
onClick={() => setShowDeleteConfirm(true)}
|
||||
>
|
||||
Delete Discord Token
|
||||
</Button>
|
||||
</SimpleTooltip>
|
||||
)}
|
||||
</Section>
|
||||
@@ -167,11 +165,12 @@ export function BotConfigCard() {
|
||||
disabled={isSubmitting}
|
||||
className="flex-1"
|
||||
/>
|
||||
<Disabled disabled={isSubmitting || !botToken.trim()}>
|
||||
<Button onClick={handleSaveToken}>
|
||||
{isSubmitting ? "Saving..." : "Save Token"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSubmitting || !botToken.trim()}
|
||||
onClick={handleSaveToken}
|
||||
>
|
||||
{isSubmitting ? "Saving..." : "Save Token"}
|
||||
</Button>
|
||||
</Section>
|
||||
</Section>
|
||||
)}
|
||||
|
||||
@@ -13,7 +13,6 @@ import {
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import { DeleteButton } from "@/components/DeleteButton";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Switch from "@/refresh-components/inputs/Switch";
|
||||
import { SvgEdit, SvgServer } from "@opal/icons";
|
||||
import EmptyMessage from "@/refresh-components/EmptyMessage";
|
||||
@@ -116,17 +115,14 @@ export function DiscordGuildsTable({ guilds, onRefresh }: Props) {
|
||||
{guilds.map((guild) => (
|
||||
<TableRow key={guild.id}>
|
||||
<TableCell>
|
||||
<Disabled disabled={!guild.guild_id}>
|
||||
<Button
|
||||
prominence="internal"
|
||||
onClick={() =>
|
||||
router.push(`/admin/discord-bot/${guild.id}`)
|
||||
}
|
||||
icon={SvgEdit}
|
||||
>
|
||||
{guild.guild_name || `Server #${guild.id}`}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={!guild.guild_id}
|
||||
prominence="internal"
|
||||
onClick={() => router.push(`/admin/discord-bot/${guild.id}`)}
|
||||
icon={SvgEdit}
|
||||
>
|
||||
{guild.guild_name || `Server #${guild.id}`}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
{guild.guild_id ? (
|
||||
|
||||
@@ -13,7 +13,6 @@ import Card from "@/refresh-components/cards/Card";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Message from "@/refresh-components/messages/Message";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { SvgServer } from "@opal/icons";
|
||||
import InputSelect from "@/refresh-components/inputs/InputSelect";
|
||||
import {
|
||||
@@ -105,16 +104,20 @@ function GuildDetailContent({
|
||||
width="fit"
|
||||
gap={0.5}
|
||||
>
|
||||
<Disabled disabled={disabled}>
|
||||
<Button prominence="secondary" onClick={handleEnableAll}>
|
||||
Enable All
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Disabled disabled={disabled}>
|
||||
<Button prominence="secondary" onClick={handleDisableAll}>
|
||||
Disable All
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="secondary"
|
||||
onClick={handleEnableAll}
|
||||
>
|
||||
Enable All
|
||||
</Button>
|
||||
<Button
|
||||
disabled={disabled}
|
||||
prominence="secondary"
|
||||
onClick={handleDisableAll}
|
||||
>
|
||||
Disable All
|
||||
</Button>
|
||||
</Section>
|
||||
) : undefined
|
||||
}
|
||||
@@ -335,9 +338,9 @@ export default function Page({ params }: Props) {
|
||||
description={registeredText}
|
||||
backButton
|
||||
rightChildren={
|
||||
<Disabled disabled={isUpdateDisabled}>
|
||||
<Button onClick={handleSaveChanges}>Update Configuration</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isUpdateDisabled} onClick={handleSaveChanges}>
|
||||
Update Configuration
|
||||
</Button>
|
||||
}
|
||||
/>
|
||||
<SettingsLayouts.Body>
|
||||
|
||||
@@ -2,7 +2,6 @@ import React, { useRef, useState } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { Label, TextFormField } from "@/components/Field";
|
||||
@@ -297,19 +296,18 @@ export default function ProviderCreationModal({
|
||||
</Callout>
|
||||
)}
|
||||
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button
|
||||
type="submit"
|
||||
width="full"
|
||||
icon={isSubmitting ? SimpleLoader : undefined}
|
||||
>
|
||||
{isSubmitting
|
||||
? "Submitting"
|
||||
: existingProvider
|
||||
? "Update"
|
||||
: "Create"}
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSubmitting}
|
||||
type="submit"
|
||||
width="full"
|
||||
icon={isSubmitting ? SimpleLoader : undefined}
|
||||
>
|
||||
{isSubmitting
|
||||
? "Submitting"
|
||||
: existingProvider
|
||||
? "Update"
|
||||
: "Create"}
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
@@ -7,7 +7,6 @@ import { useCallback, useEffect, useMemo, useState, useRef } from "react";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { WarningCircle, Warning, CaretDownIcon } from "@phosphor-icons/react";
|
||||
import {
|
||||
CloudEmbeddingModel,
|
||||
@@ -378,16 +377,15 @@ export default function EmbeddingForm() {
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex mx-auto gap-x-1 ml-auto items-center">
|
||||
<Disabled disabled={!isOverallFormValid}>
|
||||
<OpalButton
|
||||
onClick={() => {
|
||||
updateSearch();
|
||||
navigateToEmbeddingPage("search settings");
|
||||
}}
|
||||
>
|
||||
Update Search
|
||||
</OpalButton>
|
||||
</Disabled>
|
||||
<OpalButton
|
||||
disabled={!isOverallFormValid}
|
||||
onClick={() => {
|
||||
updateSearch();
|
||||
navigateToEmbeddingPage("search settings");
|
||||
}}
|
||||
>
|
||||
Update Search
|
||||
</OpalButton>
|
||||
{!isOverallFormValid &&
|
||||
Object.keys(combinedFormErrors).length > 0 && (
|
||||
<div className="relative group">
|
||||
|
||||
@@ -10,7 +10,6 @@ import {
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import SwitchField from "@/refresh-components/form/SwitchField";
|
||||
import { Form, Formik, FormikState, useFormikContext } from "formik";
|
||||
import { useState } from "react";
|
||||
@@ -201,9 +200,9 @@ function KGConfiguration({
|
||||
disabled={!props.values.enabled}
|
||||
/>
|
||||
</div>
|
||||
<Disabled disabled={!props.dirty}>
|
||||
<Button type="submit">Submit</Button>
|
||||
</Disabled>
|
||||
<Button disabled={!props.dirty} type="submit">
|
||||
Submit
|
||||
</Button>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
|
||||
@@ -2,7 +2,6 @@ import { SvgDownload, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { Interactive, Hoverable } from "@opal/core";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
|
||||
@@ -55,11 +54,13 @@ export default function ScimModal({
|
||||
title="Regenerate SCIM Token"
|
||||
onClose={onClose}
|
||||
submit={
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button variant="danger" onClick={onRegenerate}>
|
||||
Regenerate Token
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSubmitting}
|
||||
variant="danger"
|
||||
onClick={onRegenerate}
|
||||
>
|
||||
Regenerate Token
|
||||
</Button>
|
||||
}
|
||||
>
|
||||
<Section alignItems="start" gap={0.5}>
|
||||
|
||||
@@ -3,7 +3,6 @@ import { ContentAction } from "@opal/layouts";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Card from "@/refresh-components/cards/Card";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { timeAgo } from "@/lib/time";
|
||||
@@ -54,11 +53,13 @@ export default function ScimSyncCard({
|
||||
Regenerate Token
|
||||
</Button>
|
||||
) : (
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button rightIcon={SvgKey} onClick={onGenerate}>
|
||||
Generate SCIM Token
|
||||
</Button>
|
||||
</Disabled>
|
||||
<Button
|
||||
disabled={isSubmitting}
|
||||
rightIcon={SvgKey}
|
||||
onClick={onGenerate}
|
||||
>
|
||||
Generate SCIM Token
|
||||
</Button>
|
||||
)
|
||||
}
|
||||
/>
|
||||
|
||||
@@ -2,7 +2,6 @@
|
||||
|
||||
import * as Yup from "yup";
|
||||
import { Button } from "@opal/components";
|
||||
import { Disabled } from "@opal/core";
|
||||
import { useEffect, useState } from "react";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { Form, Formik } from "formik";
|
||||
@@ -148,9 +147,9 @@ export default function CreateRateLimitModal({
|
||||
type="number"
|
||||
placeholder=""
|
||||
/>
|
||||
<Disabled disabled={isSubmitting}>
|
||||
<Button type="submit">Create</Button>
|
||||
</Disabled>
|
||||
<Button disabled={isSubmitting} type="submit">
|
||||
Create
|
||||
</Button>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user