mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-26 20:25:46 +00:00
Compare commits
24 Commits
v3.0.0-bet
...
csv_render
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
80cf389774 | ||
|
|
e775aaacb7 | ||
|
|
e5b08b3d92 | ||
|
|
7c91304ba2 | ||
|
|
68a292b500 | ||
|
|
e553b80030 | ||
|
|
f3949f8e09 | ||
|
|
c7c064e296 | ||
|
|
68b91a8862 | ||
|
|
c23e5a196d | ||
|
|
093223c6c4 | ||
|
|
89517111d4 | ||
|
|
883d4b4ceb | ||
|
|
f3672b6819 | ||
|
|
921f5d9e96 | ||
|
|
15fe47adc5 | ||
|
|
29958f1a52 | ||
|
|
ac7f9838bc | ||
|
|
d0fa4b3319 | ||
|
|
3fb4fb422e | ||
|
|
ba5da22ea1 | ||
|
|
9909049047 | ||
|
|
c516aa3e3c | ||
|
|
5cc6220417 |
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -8,5 +8,5 @@
|
||||
|
||||
## Additional Options
|
||||
|
||||
- [ ] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.
|
||||
- [ ] [Optional] Please cherry-pick this PR to the latest release version.
|
||||
- [ ] [Optional] Override Linear Check
|
||||
|
||||
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
79
.github/workflows/post-merge-beta-cherry-pick.yml
vendored
Normal file
@@ -0,0 +1,79 @@
|
||||
name: Post-Merge Beta Cherry-Pick
|
||||
|
||||
on:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
|
||||
permissions:
|
||||
contents: write
|
||||
pull-requests: write
|
||||
|
||||
jobs:
|
||||
cherry-pick-to-latest-release:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Resolve merged PR and checkbox state
|
||||
id: gate
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
# For the commit that triggered this workflow (HEAD on main), fetch all
|
||||
# associated PRs and keep only the PR that was actually merged into main
|
||||
# with this exact merge commit SHA.
|
||||
pr_numbers="$(gh api "repos/${GITHUB_REPOSITORY}/commits/${GITHUB_SHA}/pulls" | jq -r --arg sha "${GITHUB_SHA}" '.[] | select(.merged_at != null and .base.ref == "main" and .merge_commit_sha == $sha) | .number')"
|
||||
match_count="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | wc -l | tr -d ' ')"
|
||||
pr_number="$(printf '%s\n' "$pr_numbers" | sed '/^[[:space:]]*$/d' | head -n 1)"
|
||||
|
||||
if [ "${match_count}" -gt 1 ]; then
|
||||
echo "::warning::Multiple merged PRs matched commit ${GITHUB_SHA}. Using PR #${pr_number}."
|
||||
fi
|
||||
|
||||
if [ -z "$pr_number" ]; then
|
||||
echo "No merged PR associated with commit ${GITHUB_SHA}; skipping."
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
exit 0
|
||||
fi
|
||||
|
||||
# Read the PR body and check whether the helper checkbox is checked.
|
||||
pr_body="$(gh api "repos/${GITHUB_REPOSITORY}/pulls/${pr_number}" --jq '.body // ""')"
|
||||
echo "pr_number=$pr_number" >> "$GITHUB_OUTPUT"
|
||||
|
||||
if echo "$pr_body" | grep -qiE "\\[x\\][[:space:]]*(\\[[^]]+\\][[:space:]]*)?Please cherry-pick this PR to the latest release version"; then
|
||||
echo "should_cherrypick=true" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox checked for PR #${pr_number}."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "should_cherrypick=false" >> "$GITHUB_OUTPUT"
|
||||
echo "Cherry-pick checkbox not checked for PR #${pr_number}. Skipping."
|
||||
|
||||
- name: Checkout repository
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # ratchet:actions/checkout@v6
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: true
|
||||
ref: main
|
||||
|
||||
- name: Install the latest version of uv
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
uses: astral-sh/setup-uv@61cb8a9741eeb8a550a1b8544337180c0fc8476b # ratchet:astral-sh/setup-uv@v7
|
||||
with:
|
||||
enable-cache: false
|
||||
version: "0.9.9"
|
||||
|
||||
- name: Configure git identity
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
run: |
|
||||
git config user.name "github-actions[bot]"
|
||||
git config user.email "github-actions[bot]@users.noreply.github.com"
|
||||
|
||||
- name: Create cherry-pick PR to latest release
|
||||
if: steps.gate.outputs.should_cherrypick == 'true'
|
||||
env:
|
||||
GH_TOKEN: ${{ github.token }}
|
||||
GITHUB_TOKEN: ${{ github.token }}
|
||||
run: |
|
||||
uv run --no-sync --with onyx-devtools ods cherry-pick "${GITHUB_SHA}" --yes --no-verify
|
||||
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
28
.github/workflows/pr-beta-cherrypick-check.yml
vendored
@@ -1,28 +0,0 @@
|
||||
name: Require beta cherry-pick consideration
|
||||
concurrency:
|
||||
group: Require-Beta-Cherrypick-Consideration-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, reopened, synchronize]
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
beta-cherrypick-check:
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
steps:
|
||||
- name: Check PR body for beta cherry-pick consideration
|
||||
env:
|
||||
PR_BODY: ${{ github.event.pull_request.body }}
|
||||
run: |
|
||||
if echo "$PR_BODY" | grep -qiE "\\[x\\][[:space:]]*\\[Required\\][[:space:]]*I have considered whether this PR needs to be cherry[- ]picked to the latest beta branch"; then
|
||||
echo "Cherry-pick consideration box is checked. Check passed."
|
||||
exit 0
|
||||
fi
|
||||
|
||||
echo "::error::Please check the 'I have considered whether this PR needs to be cherry-picked to the latest beta branch' box in the PR description."
|
||||
exit 1
|
||||
@@ -21,15 +21,14 @@ import sys
|
||||
import threading
|
||||
import time
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import List, NamedTuple
|
||||
from typing import NamedTuple
|
||||
|
||||
from alembic.config import Config
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import is_valid_schema_name
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_all_tenant_ids
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
@@ -105,56 +104,6 @@ def get_head_revision() -> str | None:
|
||||
return script.get_current_head()
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
|
||||
tenant_schemas: List[str], head_rev: str
|
||||
) -> List[str]:
|
||||
"""Return only schemas whose current alembic version is not at head."""
|
||||
if not tenant_schemas:
|
||||
return []
|
||||
|
||||
engine = SqlEngine.get_engine()
|
||||
|
||||
with engine.connect() as conn:
|
||||
# Find which schemas actually have an alembic_version table
|
||||
rows = conn.execute(
|
||||
text(
|
||||
"SELECT table_schema FROM information_schema.tables "
|
||||
"WHERE table_name = 'alembic_version' "
|
||||
"AND table_schema = ANY(:schemas)"
|
||||
),
|
||||
{"schemas": tenant_schemas},
|
||||
)
|
||||
schemas_with_table = set(row[0] for row in rows)
|
||||
|
||||
# Schemas without the table definitely need migration
|
||||
needs_migration = [s for s in tenant_schemas if s not in schemas_with_table]
|
||||
|
||||
if not schemas_with_table:
|
||||
return needs_migration
|
||||
|
||||
# Validate schema names before interpolating into SQL
|
||||
for schema in schemas_with_table:
|
||||
if not is_valid_schema_name(schema):
|
||||
raise ValueError(f"Invalid schema name: {schema}")
|
||||
|
||||
# Single query to get every schema's current revision at once.
|
||||
# Use integer tags instead of interpolating schema names into
|
||||
# string literals to avoid quoting issues.
|
||||
schema_list = list(schemas_with_table)
|
||||
union_parts = [
|
||||
f'SELECT {i} AS idx, version_num FROM "{schema}".alembic_version'
|
||||
for i, schema in enumerate(schema_list)
|
||||
]
|
||||
rows = conn.execute(text(" UNION ALL ".join(union_parts)))
|
||||
version_by_schema = {schema_list[row[0]]: row[1] for row in rows}
|
||||
|
||||
needs_migration.extend(
|
||||
s for s in schemas_with_table if version_by_schema.get(s) != head_rev
|
||||
)
|
||||
|
||||
return needs_migration
|
||||
|
||||
|
||||
def run_migrations_parallel(
|
||||
schemas: list[str],
|
||||
max_workers: int,
|
||||
|
||||
@@ -0,0 +1,31 @@
|
||||
"""code interpreter server model
|
||||
|
||||
Revision ID: 7cb492013621
|
||||
Revises: 0bb4558f35df
|
||||
Create Date: 2026-02-22 18:54:54.007265
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "7cb492013621"
|
||||
down_revision = "0bb4558f35df"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"code_interpreter_server",
|
||||
sa.Column("id", sa.Integer, primary_key=True),
|
||||
sa.Column(
|
||||
"server_enabled", sa.Boolean, nullable=False, server_default=sa.true()
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_table("code_interpreter_server")
|
||||
@@ -127,9 +127,14 @@ class ScimDAL(DAL):
|
||||
self,
|
||||
external_id: str,
|
||||
user_id: UUID,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserMapping:
|
||||
"""Create a mapping between a SCIM externalId and an Onyx user."""
|
||||
mapping = ScimUserMapping(external_id=external_id, user_id=user_id)
|
||||
mapping = ScimUserMapping(
|
||||
external_id=external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
self._session.add(mapping)
|
||||
self._session.flush()
|
||||
return mapping
|
||||
@@ -248,11 +253,11 @@ class ScimDAL(DAL):
|
||||
scim_filter: ScimFilter | None,
|
||||
start_index: int = 1,
|
||||
count: int = 100,
|
||||
) -> tuple[list[tuple[User, str | None]], int]:
|
||||
) -> tuple[list[tuple[User, ScimUserMapping | None]], int]:
|
||||
"""Query users with optional SCIM filter and pagination.
|
||||
|
||||
Returns:
|
||||
A tuple of (list of (user, external_id) pairs, total_count).
|
||||
A tuple of (list of (user, mapping) pairs, total_count).
|
||||
|
||||
Raises:
|
||||
ValueError: If the filter uses an unsupported attribute.
|
||||
@@ -292,33 +297,104 @@ class ScimDAL(DAL):
|
||||
users = list(
|
||||
self._session.scalars(
|
||||
query.order_by(User.id).offset(offset).limit(count) # type: ignore[arg-type]
|
||||
).all()
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
|
||||
# Batch-fetch external IDs to avoid N+1 queries
|
||||
ext_id_map = self._get_user_external_ids([u.id for u in users])
|
||||
return [(u, ext_id_map.get(u.id)) for u in users], total
|
||||
# Batch-fetch SCIM mappings to avoid N+1 queries
|
||||
mapping_map = self._get_user_mappings_batch([u.id for u in users])
|
||||
return [(u, mapping_map.get(u.id)) for u in users], total
|
||||
|
||||
def sync_user_external_id(self, user_id: UUID, new_external_id: str | None) -> None:
|
||||
def sync_user_external_id(
|
||||
self,
|
||||
user_id: UUID,
|
||||
new_external_id: str | None,
|
||||
scim_username: str | None = None,
|
||||
) -> None:
|
||||
"""Create, update, or delete the external ID mapping for a user."""
|
||||
mapping = self.get_user_mapping_by_user_id(user_id)
|
||||
if new_external_id:
|
||||
if mapping:
|
||||
if mapping.external_id != new_external_id:
|
||||
mapping.external_id = new_external_id
|
||||
if scim_username is not None:
|
||||
mapping.scim_username = scim_username
|
||||
else:
|
||||
self.create_user_mapping(external_id=new_external_id, user_id=user_id)
|
||||
self.create_user_mapping(
|
||||
external_id=new_external_id,
|
||||
user_id=user_id,
|
||||
scim_username=scim_username,
|
||||
)
|
||||
elif mapping:
|
||||
self.delete_user_mapping(mapping.id)
|
||||
|
||||
def _get_user_external_ids(self, user_ids: list[UUID]) -> dict[UUID, str]:
|
||||
"""Batch-fetch external IDs for a list of user IDs."""
|
||||
def _get_user_mappings_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, ScimUserMapping]:
|
||||
"""Batch-fetch SCIM user mappings keyed by user ID."""
|
||||
if not user_ids:
|
||||
return {}
|
||||
mappings = self._session.scalars(
|
||||
select(ScimUserMapping).where(ScimUserMapping.user_id.in_(user_ids))
|
||||
).all()
|
||||
return {m.user_id: m.external_id for m in mappings}
|
||||
return {m.user_id: m for m in mappings}
|
||||
|
||||
def get_user_groups(self, user_id: UUID) -> list[tuple[int, str]]:
|
||||
"""Get groups a user belongs to as ``(group_id, group_name)`` pairs.
|
||||
|
||||
Excludes groups marked for deletion.
|
||||
"""
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id == user_id)
|
||||
).all()
|
||||
|
||||
group_ids = [r.user_group_id for r in rels]
|
||||
if not group_ids:
|
||||
return []
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
return [(g.id, g.name) for g in groups]
|
||||
|
||||
def get_users_groups_batch(
|
||||
self, user_ids: list[UUID]
|
||||
) -> dict[UUID, list[tuple[int, str]]]:
|
||||
"""Batch-fetch group memberships for multiple users.
|
||||
|
||||
Returns a mapping of ``user_id → [(group_id, group_name), ...]``.
|
||||
Avoids N+1 queries when building user list responses.
|
||||
"""
|
||||
if not user_ids:
|
||||
return {}
|
||||
|
||||
rels = self._session.scalars(
|
||||
select(User__UserGroup).where(User__UserGroup.user_id.in_(user_ids))
|
||||
).all()
|
||||
|
||||
group_ids = list({r.user_group_id for r in rels})
|
||||
if not group_ids:
|
||||
return {}
|
||||
|
||||
groups = self._session.scalars(
|
||||
select(UserGroup).where(
|
||||
UserGroup.id.in_(group_ids),
|
||||
UserGroup.is_up_for_deletion.is_(False),
|
||||
)
|
||||
).all()
|
||||
groups_by_id = {g.id: g.name for g in groups}
|
||||
|
||||
result: dict[UUID, list[tuple[int, str]]] = {}
|
||||
for r in rels:
|
||||
if r.user_id and r.user_group_id in groups_by_id:
|
||||
result.setdefault(r.user_id, []).append(
|
||||
(r.user_group_id, groups_by_id[r.user_group_id])
|
||||
)
|
||||
return result
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Group mapping operations
|
||||
@@ -483,9 +559,13 @@ class ScimDAL(DAL):
|
||||
if not user_ids:
|
||||
return []
|
||||
|
||||
users = self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(user_ids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
users_by_id = {u.id: u for u in users}
|
||||
|
||||
return [
|
||||
@@ -504,9 +584,13 @@ class ScimDAL(DAL):
|
||||
"""
|
||||
if not uuids:
|
||||
return []
|
||||
existing_users = self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
).all()
|
||||
existing_users = (
|
||||
self._session.scalars(
|
||||
select(User).where(User.id.in_(uuids)) # type: ignore[attr-defined]
|
||||
)
|
||||
.unique()
|
||||
.all()
|
||||
)
|
||||
existing_ids = {u.id for u in existing_users}
|
||||
return [uid for uid in uuids if uid not in existing_ids]
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy import Select
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.dialects.postgresql import insert
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.user_group.models import SetCuratorRequest
|
||||
@@ -18,11 +19,15 @@ from onyx.db.connector_credential_pair import get_connector_credential_pair_from
|
||||
from onyx.db.enums import AccessType
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Credential
|
||||
from onyx.db.models import Credential__UserGroup
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentByConnectorCredentialPair
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import DocumentSet__UserGroup
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import LLMProvider__UserGroup
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__UserGroup
|
||||
from onyx.db.models import TokenRateLimit__UserGroup
|
||||
from onyx.db.models import User
|
||||
@@ -195,8 +200,60 @@ def fetch_user_group(db_session: Session, user_group_id: int) -> UserGroup | Non
|
||||
return db_session.scalar(stmt)
|
||||
|
||||
|
||||
def _add_user_group_snapshot_eager_loads(
|
||||
stmt: Select,
|
||||
) -> Select:
|
||||
"""Add eager loading options needed by UserGroup.from_model snapshot creation."""
|
||||
return stmt.options(
|
||||
selectinload(UserGroup.users),
|
||||
selectinload(UserGroup.user_group_relationships),
|
||||
selectinload(UserGroup.cc_pair_relationships)
|
||||
.selectinload(UserGroup__ConnectorCredentialPair.cc_pair)
|
||||
.options(
|
||||
selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(ConnectorCredentialPair.credential).selectinload(
|
||||
Credential.user
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(UserGroup.personas).options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
selectinload(Persona.groups),
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def fetch_user_groups(
|
||||
db_session: Session, only_up_to_date: bool = True
|
||||
db_session: Session,
|
||||
only_up_to_date: bool = True,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
"""
|
||||
Fetches user groups from the database.
|
||||
@@ -209,6 +266,8 @@ def fetch_user_groups(
|
||||
db_session (Session): The SQLAlchemy session used to query the database.
|
||||
only_up_to_date (bool, optional): Flag to determine whether to filter the results
|
||||
to include only up to date user groups. Defaults to `True`.
|
||||
eager_load_for_snapshot: If True, adds eager loading for all relationships
|
||||
needed by UserGroup.from_model snapshot creation.
|
||||
|
||||
Returns:
|
||||
Sequence[UserGroup]: A sequence of `UserGroup` objects matching the query criteria.
|
||||
@@ -216,11 +275,16 @@ def fetch_user_groups(
|
||||
stmt = select(UserGroup)
|
||||
if only_up_to_date:
|
||||
stmt = stmt.where(UserGroup.is_up_to_date == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_user_groups_for_user(
|
||||
db_session: Session, user_id: UUID, only_curator_groups: bool = False
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
only_curator_groups: bool = False,
|
||||
eager_load_for_snapshot: bool = False,
|
||||
) -> Sequence[UserGroup]:
|
||||
stmt = (
|
||||
select(UserGroup)
|
||||
@@ -230,7 +294,9 @@ def fetch_user_groups_for_user(
|
||||
)
|
||||
if only_curator_groups:
|
||||
stmt = stmt.where(User__UserGroup.is_curator == True) # noqa: E712
|
||||
return db_session.scalars(stmt).all()
|
||||
if eager_load_for_snapshot:
|
||||
stmt = _add_user_group_snapshot_eager_loads(stmt)
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def construct_document_id_select_by_usergroup(
|
||||
|
||||
@@ -34,7 +34,7 @@ class SendSearchQueryRequest(BaseModel):
|
||||
filters: BaseFilters | None = None
|
||||
num_docs_fed_to_llm_selection: int | None = None
|
||||
run_query_expansion: bool = False
|
||||
num_hits: int = 50
|
||||
num_hits: int = 30
|
||||
|
||||
include_content: bool = False
|
||||
stream: bool = False
|
||||
|
||||
@@ -26,12 +26,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.db.scim import ScimDAL
|
||||
from ee.onyx.server.scim.auth import verify_scim_token
|
||||
from ee.onyx.server.scim.filtering import parse_scim_filter
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimError
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimListResponse
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
@@ -41,6 +39,8 @@ from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
|
||||
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
|
||||
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
|
||||
@@ -53,7 +53,6 @@ from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
|
||||
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
|
||||
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
|
||||
@@ -63,6 +62,18 @@ scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])
|
||||
_pw_helper = PasswordHelper()
|
||||
|
||||
|
||||
def _get_provider(
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
) -> ScimProvider:
|
||||
"""Resolve the SCIM provider for the current request.
|
||||
|
||||
Currently returns OktaProvider for all requests. When multi-provider
|
||||
support is added (ENG-3652), this will resolve based on token metadata
|
||||
or tenant configuration — no endpoint changes required.
|
||||
"""
|
||||
return get_default_provider()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Service Discovery Endpoints (unauthenticated)
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -100,28 +111,6 @@ def _scim_error_response(status: int, detail: str) -> JSONResponse:
|
||||
)
|
||||
|
||||
|
||||
def _user_to_scim(user: User, external_id: str | None = None) -> ScimUserResource:
|
||||
"""Convert an Onyx User to a SCIM User resource representation."""
|
||||
name = None
|
||||
if user.personal_name:
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
name = ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=user.email,
|
||||
name=name,
|
||||
emails=[ScimEmail(value=user.email, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
|
||||
def _check_seat_availability(dal: ScimDAL) -> str | None:
|
||||
"""Return an error message if seat limit is reached, else None."""
|
||||
check_fn = fetch_ee_implementation_or_noop(
|
||||
@@ -155,9 +144,10 @@ def _scim_name_to_str(name: ScimName | None) -> str | None:
|
||||
"""
|
||||
if not name:
|
||||
return None
|
||||
return name.formatted or " ".join(
|
||||
part for part in [name.givenName, name.familyName] if part
|
||||
)
|
||||
# Build from givenName/familyName first — IdPs like Okta may send a stale
|
||||
# ``formatted`` value while updating the individual name components.
|
||||
parts = " ".join(part for part in [name.givenName, name.familyName] if part)
|
||||
return parts or name.formatted
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -171,6 +161,7 @@ def list_users(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List users with optional SCIM filter and pagination."""
|
||||
@@ -183,12 +174,19 @@ def list_users(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
try:
|
||||
users_with_ext_ids, total = dal.list_users(scim_filter, startIndex, count)
|
||||
users_with_mappings, total = dal.list_users(scim_filter, startIndex, count)
|
||||
except ValueError as e:
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
user_groups_map = dal.get_users_groups_batch([u.id for u, _ in users_with_mappings])
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_user_to_scim(user, ext_id) for user, ext_id in users_with_ext_ids
|
||||
provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=user_groups_map.get(user.id, []),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
for user, mapping in users_with_mappings
|
||||
]
|
||||
|
||||
return ScimListResponse(
|
||||
@@ -203,6 +201,7 @@ def list_users(
|
||||
def get_user(
|
||||
user_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Get a single user by ID."""
|
||||
@@ -215,20 +214,26 @@ def get_user(
|
||||
user = result
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
return _user_to_scim(user, mapping.external_id if mapping else None)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
mapping.external_id if mapping else None,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=mapping.scim_username if mapping else None,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.post("/Users", status_code=201, response_model=None)
|
||||
def create_user(
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Create a new user from a SCIM provisioning request."""
|
||||
dal = ScimDAL(db_session)
|
||||
dal.update_token_last_used(_token.id)
|
||||
|
||||
email = user_resource.userName.strip().lower()
|
||||
email = user_resource.userName.strip()
|
||||
|
||||
# externalId is how the IdP correlates this user on subsequent requests.
|
||||
# Without it, the IdP can't find the user and will try to re-create,
|
||||
@@ -264,11 +269,14 @@ def create_user(
|
||||
|
||||
# Create SCIM mapping (externalId is validated above, always present)
|
||||
external_id = user_resource.externalId
|
||||
dal.create_user_mapping(external_id=external_id, user_id=user.id)
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.create_user_mapping(
|
||||
external_id=external_id, user_id=user.id, scim_username=scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, external_id)
|
||||
return provider.build_user_resource(user, external_id, scim_username=scim_username)
|
||||
|
||||
|
||||
@scim_router.put("/Users/{user_id}", response_model=None)
|
||||
@@ -276,6 +284,7 @@ def replace_user(
|
||||
user_id: str,
|
||||
user_resource: ScimUserResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Replace a user entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -293,19 +302,27 @@ def replace_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
personal_name = _scim_name_to_str(user_resource.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=user_resource.userName.strip().lower(),
|
||||
email=user_resource.userName.strip(),
|
||||
is_active=user_resource.active,
|
||||
personal_name=_scim_name_to_str(user_resource.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
new_external_id = user_resource.externalId
|
||||
dal.sync_user_external_id(user.id, new_external_id)
|
||||
scim_username = user_resource.userName.strip()
|
||||
dal.sync_user_external_id(user.id, new_external_id, scim_username=scim_username)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, new_external_id)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
new_external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=scim_username,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.patch("/Users/{user_id}", response_model=None)
|
||||
@@ -313,6 +330,7 @@ def patch_user(
|
||||
user_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimUserResource | JSONResponse:
|
||||
"""Partially update a user (RFC 7644 §3.5.2).
|
||||
@@ -330,11 +348,19 @@ def patch_user(
|
||||
|
||||
mapping = dal.get_user_mapping_by_user_id(user.id)
|
||||
external_id = mapping.external_id if mapping else None
|
||||
current_scim_username = mapping.scim_username if mapping else None
|
||||
|
||||
current = _user_to_scim(user, external_id)
|
||||
current = provider.build_user_resource(
|
||||
user,
|
||||
external_id,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=current_scim_username,
|
||||
)
|
||||
|
||||
try:
|
||||
patched = apply_user_patch(patch_request.Operations, current)
|
||||
patched = apply_user_patch(
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
|
||||
@@ -345,22 +371,40 @@ def patch_user(
|
||||
if seat_error:
|
||||
return _scim_error_response(403, seat_error)
|
||||
|
||||
# Track the scim_username — if userName was patched, update it
|
||||
new_scim_username = patched.userName.strip() if patched.userName else None
|
||||
|
||||
# If displayName was explicitly patched (different from the original), use
|
||||
# it as personal_name directly. Otherwise, derive from name components.
|
||||
personal_name: str | None
|
||||
if patched.displayName and patched.displayName != current.displayName:
|
||||
personal_name = patched.displayName
|
||||
else:
|
||||
personal_name = _scim_name_to_str(patched.name)
|
||||
|
||||
dal.update_user(
|
||||
user,
|
||||
email=(
|
||||
patched.userName.strip().lower()
|
||||
if patched.userName.lower() != user.email
|
||||
patched.userName.strip()
|
||||
if patched.userName.strip().lower() != user.email.lower()
|
||||
else None
|
||||
),
|
||||
is_active=patched.active if patched.active != user.is_active else None,
|
||||
personal_name=_scim_name_to_str(patched.name),
|
||||
personal_name=personal_name,
|
||||
)
|
||||
|
||||
dal.sync_user_external_id(user.id, patched.externalId)
|
||||
dal.sync_user_external_id(
|
||||
user.id, patched.externalId, scim_username=new_scim_username
|
||||
)
|
||||
|
||||
dal.commit()
|
||||
|
||||
return _user_to_scim(user, patched.externalId)
|
||||
return provider.build_user_resource(
|
||||
user,
|
||||
patched.externalId,
|
||||
groups=dal.get_user_groups(user.id),
|
||||
scim_username=new_scim_username,
|
||||
)
|
||||
|
||||
|
||||
@scim_router.delete("/Users/{user_id}", status_code=204, response_model=None)
|
||||
@@ -398,24 +442,6 @@ def delete_user(
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _group_to_scim(
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Convert an Onyx UserGroup to a SCIM Group resource."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
|
||||
def _fetch_group_or_404(group_id: str, dal: ScimDAL) -> UserGroup | JSONResponse:
|
||||
"""Parse *group_id* as int, look up the group, or return a 404 error."""
|
||||
try:
|
||||
@@ -474,6 +500,7 @@ def list_groups(
|
||||
startIndex: int = Query(1, ge=1),
|
||||
count: int = Query(100, ge=0, le=500),
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimListResponse | JSONResponse:
|
||||
"""List groups with optional SCIM filter and pagination."""
|
||||
@@ -491,7 +518,7 @@ def list_groups(
|
||||
return _scim_error_response(400, str(e))
|
||||
|
||||
resources: list[ScimUserResource | ScimGroupResource] = [
|
||||
_group_to_scim(group, dal.get_group_members(group.id), ext_id)
|
||||
provider.build_group_resource(group, dal.get_group_members(group.id), ext_id)
|
||||
for group, ext_id in groups_with_ext_ids
|
||||
]
|
||||
|
||||
@@ -507,6 +534,7 @@ def list_groups(
|
||||
def get_group(
|
||||
group_id: str,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Get a single group by ID."""
|
||||
@@ -521,13 +549,16 @@ def get_group(
|
||||
mapping = dal.get_group_mapping_by_group_id(group.id)
|
||||
members = dal.get_group_members(group.id)
|
||||
|
||||
return _group_to_scim(group, members, mapping.external_id if mapping else None)
|
||||
return provider.build_group_resource(
|
||||
group, members, mapping.external_id if mapping else None
|
||||
)
|
||||
|
||||
|
||||
@scim_router.post("/Groups", status_code=201, response_model=None)
|
||||
def create_group(
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Create a new group from a SCIM provisioning request."""
|
||||
@@ -565,7 +596,7 @@ def create_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(db_group.id)
|
||||
return _group_to_scim(db_group, members, external_id)
|
||||
return provider.build_group_resource(db_group, members, external_id)
|
||||
|
||||
|
||||
@scim_router.put("/Groups/{group_id}", response_model=None)
|
||||
@@ -573,6 +604,7 @@ def replace_group(
|
||||
group_id: str,
|
||||
group_resource: ScimGroupResource,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Replace a group entirely (RFC 7644 §3.5.1)."""
|
||||
@@ -595,7 +627,7 @@ def replace_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, group_resource.externalId)
|
||||
return provider.build_group_resource(group, members, group_resource.externalId)
|
||||
|
||||
|
||||
@scim_router.patch("/Groups/{group_id}", response_model=None)
|
||||
@@ -603,6 +635,7 @@ def patch_group(
|
||||
group_id: str,
|
||||
patch_request: ScimPatchRequest,
|
||||
_token: ScimToken = Depends(verify_scim_token),
|
||||
provider: ScimProvider = Depends(_get_provider),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ScimGroupResource | JSONResponse:
|
||||
"""Partially update a group (RFC 7644 §3.5.2).
|
||||
@@ -621,11 +654,11 @@ def patch_group(
|
||||
external_id = mapping.external_id if mapping else None
|
||||
|
||||
current_members = dal.get_group_members(group.id)
|
||||
current = _group_to_scim(group, current_members, external_id)
|
||||
current = provider.build_group_resource(group, current_members, external_id)
|
||||
|
||||
try:
|
||||
patched, added_ids, removed_ids = apply_group_patch(
|
||||
patch_request.Operations, current
|
||||
patch_request.Operations, current, provider.ignored_patch_paths
|
||||
)
|
||||
except ScimPatchError as e:
|
||||
return _scim_error_response(e.status, e.detail)
|
||||
@@ -652,7 +685,7 @@ def patch_group(
|
||||
dal.commit()
|
||||
|
||||
members = dal.get_group_members(group.id)
|
||||
return _group_to_scim(group, members, patched.externalId)
|
||||
return provider.build_group_resource(group, members, patched.externalId)
|
||||
|
||||
|
||||
@scim_router.delete("/Groups/{group_id}", status_code=204, response_model=None)
|
||||
|
||||
@@ -63,6 +63,13 @@ class ScimMeta(BaseModel):
|
||||
location: str | None = None
|
||||
|
||||
|
||||
class ScimUserGroupRef(BaseModel):
|
||||
"""Group reference within a User resource (RFC 7643 §4.1.2, read-only)."""
|
||||
|
||||
value: str
|
||||
display: str | None = None
|
||||
|
||||
|
||||
class ScimUserResource(BaseModel):
|
||||
"""SCIM User resource representation (RFC 7643 §4.1).
|
||||
|
||||
@@ -76,8 +83,10 @@ class ScimUserResource(BaseModel):
|
||||
externalId: str | None = None # IdP's identifier for this user
|
||||
userName: str # Typically the user's email address
|
||||
name: ScimName | None = None
|
||||
displayName: str | None = None
|
||||
emails: list[ScimEmail] = Field(default_factory=list)
|
||||
active: bool = True
|
||||
groups: list[ScimUserGroupRef] = Field(default_factory=list)
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
@@ -121,12 +130,40 @@ class ScimPatchOperationType(str, Enum):
|
||||
REMOVE = "remove"
|
||||
|
||||
|
||||
class ScimPatchResourceValue(BaseModel):
|
||||
"""Partial resource dict for path-less PATCH replace operations.
|
||||
|
||||
When an IdP sends a PATCH without a ``path``, the ``value`` is a dict
|
||||
of resource attributes to set. IdPs may include read-only fields
|
||||
(``id``, ``schemas``, ``meta``) alongside actual changes — these are
|
||||
stripped by the provider's ``ignored_patch_paths`` before processing.
|
||||
|
||||
``extra="allow"`` lets unknown attributes pass through so the patch
|
||||
handler can decide what to do with them (ignore or reject).
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
active: bool | None = None
|
||||
userName: str | None = None
|
||||
displayName: str | None = None
|
||||
externalId: str | None = None
|
||||
name: ScimName | None = None
|
||||
members: list[ScimGroupMember] | None = None
|
||||
id: str | None = None
|
||||
schemas: list[str] | None = None
|
||||
meta: ScimMeta | None = None
|
||||
|
||||
|
||||
ScimPatchValue = str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None
|
||||
|
||||
|
||||
class ScimPatchOperation(BaseModel):
|
||||
"""Single PATCH operation (RFC 7644 §3.5.2)."""
|
||||
|
||||
op: ScimPatchOperationType
|
||||
path: str | None = None
|
||||
value: str | list[dict[str, str]] | dict[str, str | bool] | bool | None = None
|
||||
value: ScimPatchValue = None
|
||||
|
||||
|
||||
class ScimPatchRequest(BaseModel):
|
||||
|
||||
@@ -16,9 +16,12 @@ from __future__ import annotations
|
||||
|
||||
import re
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
|
||||
|
||||
@@ -41,9 +44,15 @@ _MEMBER_FILTER_RE = re.compile(
|
||||
def apply_user_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimUserResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> ScimUserResource:
|
||||
"""Apply SCIM PATCH operations to a user resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current user resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns a new ``ScimUserResource`` with the modifications applied.
|
||||
The original object is not mutated.
|
||||
|
||||
@@ -55,9 +64,9 @@ def apply_user_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_user_replace(op, data, name_data)
|
||||
_apply_user_replace(op, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError(
|
||||
f"Unsupported operation '{op.op.value}' on User resource"
|
||||
@@ -71,30 +80,34 @@ def _apply_user_replace(
|
||||
op: ScimPatchOperation,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace/add operation to user data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
# No path — value is a dict of top-level attributes to set
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
_set_user_field(key.lower(), val, data, name_data)
|
||||
# No path — value is a resource dict of top-level attributes to set
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
for key, val in op.value.model_dump(exclude_unset=True).items():
|
||||
_set_user_field(key.lower(), val, data, name_data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
_set_user_field(path, op.value, data, name_data)
|
||||
_set_user_field(path, op.value, data, name_data, ignored_paths)
|
||||
|
||||
|
||||
def _set_user_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
name_data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on user data by SCIM path."""
|
||||
if path == "active":
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "active":
|
||||
data["active"] = value
|
||||
elif path == "username":
|
||||
data["userName"] = value
|
||||
@@ -107,7 +120,7 @@ def _set_user_field(
|
||||
elif path == "name.formatted":
|
||||
name_data["formatted"] = value
|
||||
elif path == "displayname":
|
||||
# Some IdPs send displayName on users; map to formatted name
|
||||
data["displayName"] = value
|
||||
name_data["formatted"] = value
|
||||
else:
|
||||
raise ScimPatchError(f"Unsupported path '{path}' for User PATCH")
|
||||
@@ -116,9 +129,15 @@ def _set_user_field(
|
||||
def apply_group_patch(
|
||||
operations: list[ScimPatchOperation],
|
||||
current: ScimGroupResource,
|
||||
ignored_paths: frozenset[str] = frozenset(),
|
||||
) -> tuple[ScimGroupResource, list[str], list[str]]:
|
||||
"""Apply SCIM PATCH operations to a group resource.
|
||||
|
||||
Args:
|
||||
operations: The PATCH operations to apply.
|
||||
current: The current group resource state.
|
||||
ignored_paths: SCIM attribute paths to silently skip (from provider).
|
||||
|
||||
Returns:
|
||||
A tuple of (modified group, added member IDs, removed member IDs).
|
||||
The caller uses the member ID lists to update the database.
|
||||
@@ -133,7 +152,9 @@ def apply_group_patch(
|
||||
|
||||
for op in operations:
|
||||
if op.op == ScimPatchOperationType.REPLACE:
|
||||
_apply_group_replace(op, data, current_members, added_ids, removed_ids)
|
||||
_apply_group_replace(
|
||||
op, data, current_members, added_ids, removed_ids, ignored_paths
|
||||
)
|
||||
elif op.op == ScimPatchOperationType.ADD:
|
||||
_apply_group_add(op, current_members, added_ids)
|
||||
elif op.op == ScimPatchOperationType.REMOVE:
|
||||
@@ -154,38 +175,48 @@ def _apply_group_replace(
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Apply a replace operation to group data."""
|
||||
path = (op.path or "").lower()
|
||||
|
||||
if not path:
|
||||
if isinstance(op.value, dict):
|
||||
for key, val in op.value.items():
|
||||
if isinstance(op.value, ScimPatchResourceValue):
|
||||
dumped = op.value.model_dump(exclude_unset=True)
|
||||
for key, val in dumped.items():
|
||||
if key.lower() == "members":
|
||||
_replace_members(val, current_members, added_ids, removed_ids)
|
||||
else:
|
||||
_set_group_field(key.lower(), val, data)
|
||||
_set_group_field(key.lower(), val, data, ignored_paths)
|
||||
else:
|
||||
raise ScimPatchError("Replace without path requires a dict value")
|
||||
return
|
||||
|
||||
if path == "members":
|
||||
_replace_members(op.value, current_members, added_ids, removed_ids)
|
||||
_replace_members(
|
||||
_members_to_dicts(op.value), current_members, added_ids, removed_ids
|
||||
)
|
||||
return
|
||||
|
||||
_set_group_field(path, op.value, data)
|
||||
_set_group_field(path, op.value, data, ignored_paths)
|
||||
|
||||
|
||||
def _members_to_dicts(
|
||||
value: str | bool | list[ScimGroupMember] | ScimPatchResourceValue | None,
|
||||
) -> list[dict]:
|
||||
"""Convert a member list value to a list of dicts for internal processing."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
return [m.model_dump(exclude_none=True) for m in value]
|
||||
|
||||
|
||||
def _replace_members(
|
||||
value: str | list | dict | bool | None,
|
||||
value: list[dict],
|
||||
current_members: list[dict],
|
||||
added_ids: list[str],
|
||||
removed_ids: list[str],
|
||||
) -> None:
|
||||
"""Replace the entire group member list."""
|
||||
if not isinstance(value, list):
|
||||
raise ScimPatchError("Replace members requires a list value")
|
||||
|
||||
old_ids = {m["value"] for m in current_members}
|
||||
new_ids = {m.get("value", "") for m in value}
|
||||
|
||||
@@ -197,11 +228,14 @@ def _replace_members(
|
||||
|
||||
def _set_group_field(
|
||||
path: str,
|
||||
value: str | bool | dict | list | None,
|
||||
value: ScimPatchValue,
|
||||
data: dict,
|
||||
ignored_paths: frozenset[str],
|
||||
) -> None:
|
||||
"""Set a single field on group data by SCIM path."""
|
||||
if path == "displayname":
|
||||
if path in ignored_paths:
|
||||
return
|
||||
elif path == "displayname":
|
||||
data["displayName"] = value
|
||||
elif path == "externalid":
|
||||
data["externalId"] = value
|
||||
@@ -223,8 +257,10 @@ def _apply_group_add(
|
||||
if not isinstance(op.value, list):
|
||||
raise ScimPatchError("Add members requires a list value")
|
||||
|
||||
member_dicts = [m.model_dump(exclude_none=True) for m in op.value]
|
||||
|
||||
existing_ids = {m["value"] for m in members}
|
||||
for member_data in op.value:
|
||||
for member_data in member_dicts:
|
||||
member_id = member_data.get("value", "")
|
||||
if member_id and member_id not in existing_ids:
|
||||
members.append(member_data)
|
||||
|
||||
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
0
backend/ee/onyx/server/scim/providers/__init__.py
Normal file
123
backend/ee/onyx/server/scim/providers/base.py
Normal file
123
backend/ee/onyx/server/scim/providers/base.py
Normal file
@@ -0,0 +1,123 @@
|
||||
"""Base SCIM provider abstraction."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from abc import ABC
|
||||
from abc import abstractmethod
|
||||
from uuid import UUID
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
|
||||
|
||||
class ScimProvider(ABC):
|
||||
"""Base class for provider-specific SCIM behavior.
|
||||
|
||||
Subclass this to handle IdP-specific quirks. The base class provides
|
||||
RFC 7643-compliant response builders that populate all standard fields.
|
||||
"""
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def name(self) -> str:
|
||||
"""Short identifier for this provider (e.g. ``"okta"``)."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
"""SCIM attribute paths to silently skip in PATCH value-object dicts.
|
||||
|
||||
IdPs may include read-only or meta fields alongside actual changes
|
||||
(e.g. Okta sends ``{"id": "...", "active": false}``). Paths listed
|
||||
here are silently dropped instead of raising an error.
|
||||
"""
|
||||
...
|
||||
|
||||
def build_user_resource(
|
||||
self,
|
||||
user: User,
|
||||
external_id: str | None = None,
|
||||
groups: list[tuple[int, str]] | None = None,
|
||||
scim_username: str | None = None,
|
||||
) -> ScimUserResource:
|
||||
"""Build a SCIM User response from an Onyx User.
|
||||
|
||||
Args:
|
||||
user: The Onyx user model.
|
||||
external_id: The IdP's external identifier for this user.
|
||||
groups: List of ``(group_id, group_name)`` tuples for the
|
||||
``groups`` read-only attribute. Pass ``None`` or ``[]``
|
||||
for newly-created users.
|
||||
scim_username: The original-case userName from the IdP. Falls
|
||||
back to ``user.email`` (lowercase) when not available.
|
||||
"""
|
||||
group_refs = [
|
||||
ScimUserGroupRef(value=str(gid), display=gname)
|
||||
for gid, gname in (groups or [])
|
||||
]
|
||||
|
||||
# Use original-case userName if stored, otherwise fall back to the
|
||||
# lowercased email from the User model.
|
||||
username = scim_username or user.email
|
||||
|
||||
return ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId=external_id,
|
||||
userName=username,
|
||||
name=self._build_scim_name(user),
|
||||
displayName=user.personal_name,
|
||||
emails=[ScimEmail(value=username, type="work", primary=True)],
|
||||
active=user.is_active,
|
||||
groups=group_refs,
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def build_group_resource(
|
||||
self,
|
||||
group: UserGroup,
|
||||
members: list[tuple[UUID, str | None]],
|
||||
external_id: str | None = None,
|
||||
) -> ScimGroupResource:
|
||||
"""Build a SCIM Group response from an Onyx UserGroup."""
|
||||
scim_members = [
|
||||
ScimGroupMember(value=str(uid), display=email) for uid, email in members
|
||||
]
|
||||
return ScimGroupResource(
|
||||
id=str(group.id),
|
||||
externalId=external_id,
|
||||
displayName=group.name,
|
||||
members=scim_members,
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _build_scim_name(user: User) -> ScimName | None:
|
||||
"""Extract SCIM name components from a user's personal name."""
|
||||
if not user.personal_name:
|
||||
return None
|
||||
parts = user.personal_name.split(" ", 1)
|
||||
return ScimName(
|
||||
givenName=parts[0],
|
||||
familyName=parts[1] if len(parts) > 1 else None,
|
||||
formatted=user.personal_name,
|
||||
)
|
||||
|
||||
|
||||
def get_default_provider() -> ScimProvider:
|
||||
"""Return the default SCIM provider.
|
||||
|
||||
Currently returns ``OktaProvider`` since Okta is the primary supported
|
||||
IdP. When provider detection is added (via token metadata or tenant
|
||||
config), this can be replaced with dynamic resolution.
|
||||
"""
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
return OktaProvider()
|
||||
25
backend/ee/onyx/server/scim/providers/okta.py
Normal file
25
backend/ee/onyx/server/scim/providers/okta.py
Normal file
@@ -0,0 +1,25 @@
|
||||
"""Okta SCIM provider."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
|
||||
|
||||
class OktaProvider(ScimProvider):
|
||||
"""Okta SCIM provider.
|
||||
|
||||
Okta behavioral notes:
|
||||
- Uses ``PATCH {"active": false}`` for deprovisioning (not DELETE)
|
||||
- Sends path-less PATCH with value dicts containing extra fields
|
||||
(``id``, ``schemas``)
|
||||
- Expects ``displayName`` and ``groups`` in user responses
|
||||
- Only uses ``eq`` operator for ``userName`` filter
|
||||
"""
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return "okta"
|
||||
|
||||
@property
|
||||
def ignored_patch_paths(self) -> frozenset[str]:
|
||||
return frozenset({"id", "schemas", "meta"})
|
||||
@@ -37,12 +37,15 @@ def list_user_groups(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[UserGroup]:
|
||||
if user.role == UserRole.ADMIN:
|
||||
user_groups = fetch_user_groups(db_session, only_up_to_date=False)
|
||||
user_groups = fetch_user_groups(
|
||||
db_session, only_up_to_date=False, eager_load_for_snapshot=True
|
||||
)
|
||||
else:
|
||||
user_groups = fetch_user_groups_for_user(
|
||||
db_session=db_session,
|
||||
user_id=user.id,
|
||||
only_curator_groups=user.role == UserRole.CURATOR,
|
||||
eager_load_for_snapshot=True,
|
||||
)
|
||||
return [UserGroup.from_model(user_group) for user_group in user_groups]
|
||||
|
||||
|
||||
@@ -53,7 +53,8 @@ class UserGroup(BaseModel):
|
||||
id=cc_pair_relationship.cc_pair.id,
|
||||
name=cc_pair_relationship.cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_relationship.cc_pair.connector
|
||||
cc_pair_relationship.cc_pair.connector,
|
||||
credential_ids=[cc_pair_relationship.cc_pair.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_relationship.cc_pair.credential
|
||||
|
||||
@@ -277,13 +277,32 @@ def verify_email_domain(email: str) -> None:
|
||||
detail="Email is not valid",
|
||||
)
|
||||
|
||||
domain = email.split("@")[-1].lower()
|
||||
local_part, domain = email.split("@")
|
||||
domain = domain.lower()
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
# Normalize googlemail.com to gmail.com (they deliver to the same inbox)
|
||||
if domain == "googlemail.com":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={"reason": "Please use @gmail.com instead of @googlemail.com."},
|
||||
)
|
||||
|
||||
if "+" in local_part and domain != "onyx.app":
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail={
|
||||
"reason": "Email addresses with '+' are not allowed. Please use your base email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check if email uses a disposable/temporary domain
|
||||
if is_disposable_email(email):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Disposable email addresses are not allowed. Please use a permanent email address.",
|
||||
detail={
|
||||
"reason": "Disposable email addresses are not allowed. Please use a permanent email address."
|
||||
},
|
||||
)
|
||||
|
||||
# Check domain whitelist if configured
|
||||
|
||||
@@ -190,7 +190,7 @@ def _build_user_information_section(
|
||||
if not sections:
|
||||
return ""
|
||||
|
||||
return USER_INFORMATION_HEADER + "".join(sections)
|
||||
return USER_INFORMATION_HEADER + "\n".join(sections)
|
||||
|
||||
|
||||
def build_system_prompt(
|
||||
@@ -228,23 +228,21 @@ def build_system_prompt(
|
||||
system_prompt += REQUIRE_CITATION_GUIDANCE
|
||||
|
||||
if include_all_guidance:
|
||||
system_prompt += (
|
||||
TOOL_SECTION_HEADER
|
||||
+ TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
+ INTERNAL_SEARCH_GUIDANCE
|
||||
+ WEB_SEARCH_GUIDANCE.format(
|
||||
tool_sections = [
|
||||
TOOL_DESCRIPTION_SEARCH_GUIDANCE,
|
||||
INTERNAL_SEARCH_GUIDANCE,
|
||||
WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
)
|
||||
+ OPEN_URLS_GUIDANCE
|
||||
+ PYTHON_TOOL_GUIDANCE
|
||||
+ GENERATE_IMAGE_GUIDANCE
|
||||
+ MEMORY_GUIDANCE
|
||||
)
|
||||
),
|
||||
OPEN_URLS_GUIDANCE,
|
||||
PYTHON_TOOL_GUIDANCE,
|
||||
GENERATE_IMAGE_GUIDANCE,
|
||||
MEMORY_GUIDANCE,
|
||||
]
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_sections)
|
||||
return system_prompt
|
||||
|
||||
if tools:
|
||||
system_prompt += TOOL_SECTION_HEADER
|
||||
|
||||
has_web_search = any(isinstance(tool, WebSearchTool) for tool in tools)
|
||||
has_internal_search = any(isinstance(tool, SearchTool) for tool in tools)
|
||||
has_open_urls = any(isinstance(tool, OpenURLTool) for tool in tools)
|
||||
@@ -254,12 +252,14 @@ def build_system_prompt(
|
||||
)
|
||||
has_memory = any(isinstance(tool, MemoryTool) for tool in tools)
|
||||
|
||||
tool_guidance_sections: list[str] = []
|
||||
|
||||
if has_web_search or has_internal_search or include_all_guidance:
|
||||
system_prompt += TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(TOOL_DESCRIPTION_SEARCH_GUIDANCE)
|
||||
|
||||
# These are not included at the Tool level because the ordering may matter.
|
||||
if has_internal_search or include_all_guidance:
|
||||
system_prompt += INTERNAL_SEARCH_GUIDANCE
|
||||
tool_guidance_sections.append(INTERNAL_SEARCH_GUIDANCE)
|
||||
|
||||
if has_web_search or include_all_guidance:
|
||||
site_disabled_guidance = ""
|
||||
@@ -269,20 +269,23 @@ def build_system_prompt(
|
||||
)
|
||||
if web_search_tool and not web_search_tool.supports_site_filter:
|
||||
site_disabled_guidance = WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
system_prompt += WEB_SEARCH_GUIDANCE.format(
|
||||
site_colon_disabled=site_disabled_guidance
|
||||
tool_guidance_sections.append(
|
||||
WEB_SEARCH_GUIDANCE.format(site_colon_disabled=site_disabled_guidance)
|
||||
)
|
||||
|
||||
if has_open_urls or include_all_guidance:
|
||||
system_prompt += OPEN_URLS_GUIDANCE
|
||||
tool_guidance_sections.append(OPEN_URLS_GUIDANCE)
|
||||
|
||||
if has_python or include_all_guidance:
|
||||
system_prompt += PYTHON_TOOL_GUIDANCE
|
||||
tool_guidance_sections.append(PYTHON_TOOL_GUIDANCE)
|
||||
|
||||
if has_generate_image or include_all_guidance:
|
||||
system_prompt += GENERATE_IMAGE_GUIDANCE
|
||||
tool_guidance_sections.append(GENERATE_IMAGE_GUIDANCE)
|
||||
|
||||
if has_memory or include_all_guidance:
|
||||
system_prompt += MEMORY_GUIDANCE
|
||||
tool_guidance_sections.append(MEMORY_GUIDANCE)
|
||||
|
||||
if tool_guidance_sections:
|
||||
system_prompt += TOOL_SECTION_HEADER + "\n".join(tool_guidance_sections)
|
||||
|
||||
return system_prompt
|
||||
|
||||
@@ -244,6 +244,12 @@ class SharepointConnectorCheckpoint(ConnectorCheckpoint):
|
||||
current_drive_name: str | None = None
|
||||
# Drive's web_url from the API - used as raw_node_id for DRIVE hierarchy nodes
|
||||
current_drive_web_url: str | None = None
|
||||
# Resolved drive ID — avoids re-resolving on checkpoint resume
|
||||
current_drive_id: str | None = None
|
||||
# Next delta API page URL for per-page checkpointing within a drive.
|
||||
# When set, Phase 3b fetches one page at a time so progress is persisted
|
||||
# between pages. None means BFS path or no active delta traversal.
|
||||
current_drive_delta_next_link: str | None = None
|
||||
|
||||
process_site_pages: bool = False
|
||||
|
||||
@@ -1403,6 +1409,87 @@ class SharepointConnector(
|
||||
if not page_url:
|
||||
break
|
||||
|
||||
def _build_delta_start_url(
|
||||
self,
|
||||
drive_id: str,
|
||||
start: datetime | None = None,
|
||||
page_size: int = 200,
|
||||
) -> str:
|
||||
"""Build the initial delta API URL with query parameters embedded.
|
||||
|
||||
Embeds ``$top`` (and optionally a timestamp ``token``) directly in the
|
||||
URL so that the returned string is fully self-contained and can be
|
||||
stored in a checkpoint without needing a separate params dict.
|
||||
"""
|
||||
base_url = f"{self.graph_api_base}/drives/{drive_id}/root/delta"
|
||||
params = [f"$top={page_size}"]
|
||||
if start is not None and start > _EPOCH:
|
||||
token = quote(start.isoformat(timespec="seconds"))
|
||||
params.append(f"token={token}")
|
||||
return f"{base_url}?{'&'.join(params)}"
|
||||
|
||||
def _fetch_one_delta_page(
|
||||
self,
|
||||
page_url: str,
|
||||
drive_id: str,
|
||||
start: datetime | None = None,
|
||||
end: datetime | None = None,
|
||||
page_size: int = 200,
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
"""Fetch a single page of delta API results.
|
||||
|
||||
Returns ``(items, next_page_url)``. *next_page_url* is ``None`` when
|
||||
the delta enumeration is complete (deltaLink with no nextLink).
|
||||
|
||||
On 410 Gone (expired token) returns ``([], full_resync_url)`` so
|
||||
the caller can store the resync URL in the checkpoint and retry on
|
||||
the next cycle.
|
||||
"""
|
||||
try:
|
||||
data = self._graph_api_get_json(page_url)
|
||||
except requests.HTTPError as e:
|
||||
if e.response is not None and e.response.status_code == 410:
|
||||
logger.warning(
|
||||
"Delta token expired (410 Gone) for drive '%s'. "
|
||||
"Will restart with full delta enumeration.",
|
||||
drive_id,
|
||||
)
|
||||
full_url = (
|
||||
f"{self.graph_api_base}/drives/{drive_id}/root/delta"
|
||||
f"?$top={page_size}"
|
||||
)
|
||||
return [], full_url
|
||||
raise
|
||||
|
||||
items: list[DriveItemData] = []
|
||||
for item in data.get("value", []):
|
||||
if "folder" in item or "deleted" in item:
|
||||
continue
|
||||
if start is not None or end is not None:
|
||||
raw_ts = item.get("lastModifiedDateTime")
|
||||
if raw_ts:
|
||||
mod_dt = datetime.fromisoformat(raw_ts.replace("Z", "+00:00"))
|
||||
if start is not None and mod_dt < start:
|
||||
continue
|
||||
if end is not None and mod_dt > end:
|
||||
continue
|
||||
items.append(DriveItemData.from_graph_json(item))
|
||||
|
||||
next_url = data.get("@odata.nextLink")
|
||||
if next_url:
|
||||
return items, next_url
|
||||
return items, None
|
||||
|
||||
@staticmethod
|
||||
def _clear_drive_checkpoint_state(
|
||||
checkpoint: "SharepointConnectorCheckpoint",
|
||||
) -> None:
|
||||
"""Reset all drive-level fields in the checkpoint."""
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_id = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
checkpoint.current_drive_delta_next_link = None
|
||||
|
||||
def _fetch_slim_documents_from_sharepoint(self) -> GenerateSlimDocumentOutput:
|
||||
site_descriptors = self.site_descriptors or self.fetch_sites()
|
||||
|
||||
@@ -1844,14 +1931,13 @@ class SharepointConnector(
|
||||
# Return checkpoint to allow persistence after drive initialization
|
||||
return checkpoint
|
||||
|
||||
# Phase 3: Process documents from current drive
|
||||
# Phase 3a: Initialize the next drive for processing
|
||||
if (
|
||||
checkpoint.current_site_descriptor
|
||||
and checkpoint.cached_drive_names
|
||||
and len(checkpoint.cached_drive_names) > 0
|
||||
and checkpoint.current_drive_name is None
|
||||
):
|
||||
|
||||
checkpoint.current_drive_name = checkpoint.cached_drive_names.popleft()
|
||||
|
||||
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
@@ -1859,7 +1945,8 @@ class SharepointConnector(
|
||||
site_descriptor = checkpoint.current_site_descriptor
|
||||
|
||||
logger.info(
|
||||
f"Processing drive '{checkpoint.current_drive_name}' in site: {site_descriptor.url}"
|
||||
f"Processing drive '{checkpoint.current_drive_name}' "
|
||||
f"in site: {site_descriptor.url}"
|
||||
)
|
||||
logger.debug(f"Time range: {start_dt} to {end_dt}")
|
||||
|
||||
@@ -1868,35 +1955,35 @@ class SharepointConnector(
|
||||
logger.warning("Current drive name is None, skipping")
|
||||
return checkpoint
|
||||
|
||||
driveitems: Iterable[DriveItemData] = iter(())
|
||||
drive_web_url: str | None = None
|
||||
try:
|
||||
logger.info(
|
||||
f"Fetching drive items for drive name: {current_drive_name}"
|
||||
)
|
||||
result = self._resolve_drive(site_descriptor, current_drive_name)
|
||||
if result is not None:
|
||||
drive_id, drive_web_url = result
|
||||
driveitems = self._get_drive_items_for_drive_id(
|
||||
site_descriptor, drive_id, start_dt, end_dt
|
||||
)
|
||||
checkpoint.current_drive_web_url = drive_web_url
|
||||
if result is None:
|
||||
logger.warning(f"Drive '{current_drive_name}' not found, skipping")
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
drive_id, drive_web_url = result
|
||||
checkpoint.current_drive_id = drive_id
|
||||
checkpoint.current_drive_web_url = drive_web_url
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to retrieve items from drive '{current_drive_name}' in site: {site_descriptor.url}: {e}"
|
||||
f"Failed to retrieve items from drive '{current_drive_name}' "
|
||||
f"in site: {site_descriptor.url}: {e}"
|
||||
)
|
||||
yield _create_entity_failure(
|
||||
f"{site_descriptor.url}|{current_drive_name}",
|
||||
f"Failed to access drive '{current_drive_name}' in site '{site_descriptor.url}': {str(e)}",
|
||||
f"Failed to access drive '{current_drive_name}' "
|
||||
f"in site '{site_descriptor.url}': {str(e)}",
|
||||
(start_dt, end_dt),
|
||||
e,
|
||||
)
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
# Normalize drive name (e.g., "Documents" -> "Shared Documents")
|
||||
current_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
display_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
current_drive_name, current_drive_name
|
||||
)
|
||||
|
||||
@@ -1904,10 +1991,74 @@ class SharepointConnector(
|
||||
yield from self._yield_drive_hierarchy_node(
|
||||
site_descriptor.url,
|
||||
drive_web_url,
|
||||
current_drive_name,
|
||||
display_drive_name,
|
||||
checkpoint,
|
||||
)
|
||||
|
||||
# For non-folder-scoped drives, use delta API with per-page
|
||||
# checkpointing. Build the initial URL and fall through to 3b.
|
||||
if not site_descriptor.folder_path:
|
||||
checkpoint.current_drive_delta_next_link = self._build_delta_start_url(
|
||||
drive_id, start_dt
|
||||
)
|
||||
# else: BFS path — delta_next_link stays None;
|
||||
# Phase 3b will use _iter_drive_items_paged.
|
||||
|
||||
# Phase 3b: Process items from the current drive
|
||||
if (
|
||||
checkpoint.current_site_descriptor
|
||||
and checkpoint.current_drive_name is not None
|
||||
and checkpoint.current_drive_id is not None
|
||||
):
|
||||
site_descriptor = checkpoint.current_site_descriptor
|
||||
start_dt = datetime.fromtimestamp(start, tz=timezone.utc)
|
||||
end_dt = datetime.fromtimestamp(end, tz=timezone.utc)
|
||||
current_drive_name = SHARED_DOCUMENTS_MAP.get(
|
||||
checkpoint.current_drive_name, checkpoint.current_drive_name
|
||||
)
|
||||
drive_web_url = checkpoint.current_drive_web_url
|
||||
|
||||
# --- determine item source ---
|
||||
driveitems: Iterable[DriveItemData]
|
||||
has_more_delta_pages = False
|
||||
|
||||
if checkpoint.current_drive_delta_next_link:
|
||||
# Delta path: fetch one page at a time for checkpointing
|
||||
try:
|
||||
page_items, next_url = self._fetch_one_delta_page(
|
||||
page_url=checkpoint.current_drive_delta_next_link,
|
||||
drive_id=checkpoint.current_drive_id,
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to fetch delta page for drive "
|
||||
f"'{current_drive_name}': {e}"
|
||||
)
|
||||
yield _create_entity_failure(
|
||||
f"{site_descriptor.url}|{current_drive_name}",
|
||||
f"Failed to fetch delta page for drive "
|
||||
f"'{current_drive_name}': {str(e)}",
|
||||
(start_dt, end_dt),
|
||||
e,
|
||||
)
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
return checkpoint
|
||||
|
||||
driveitems = page_items
|
||||
has_more_delta_pages = next_url is not None
|
||||
if next_url:
|
||||
checkpoint.current_drive_delta_next_link = next_url
|
||||
else:
|
||||
# BFS path (folder-scoped): process all items at once
|
||||
driveitems = self._iter_drive_items_paged(
|
||||
drive_id=checkpoint.current_drive_id,
|
||||
folder_path=site_descriptor.folder_path,
|
||||
start=start_dt,
|
||||
end=end_dt,
|
||||
)
|
||||
|
||||
item_count = 0
|
||||
for driveitem in driveitems:
|
||||
item_count += 1
|
||||
@@ -1949,8 +2100,6 @@ class SharepointConnector(
|
||||
if include_permissions:
|
||||
ctx = self._create_rest_client_context(site_descriptor.url)
|
||||
|
||||
# Re-acquire token in case it expired during a long traversal
|
||||
# MSAL has a cache that returns the same token while still valid.
|
||||
access_token = self._get_graph_access_token()
|
||||
doc_or_failure = _convert_driveitem_to_document_with_permissions(
|
||||
driveitem,
|
||||
@@ -1986,8 +2135,11 @@ class SharepointConnector(
|
||||
)
|
||||
|
||||
logger.info(f"Processed {item_count} items in drive '{current_drive_name}'")
|
||||
checkpoint.current_drive_name = None
|
||||
checkpoint.current_drive_web_url = None
|
||||
|
||||
if has_more_delta_pages:
|
||||
return checkpoint
|
||||
|
||||
self._clear_drive_checkpoint_state(checkpoint)
|
||||
|
||||
# Phase 4: Progression logic - determine next step
|
||||
# If we have more drives in current site, continue with current site
|
||||
|
||||
@@ -32,6 +32,7 @@ from onyx.context.search.federated.slack_search_utils import should_include_mess
|
||||
from onyx.context.search.models import ChunkIndexRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.db.document import DocumentSource
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.document_index.document_index_utils import (
|
||||
get_multipass_config,
|
||||
@@ -905,13 +906,15 @@ def convert_slack_score(slack_score: float) -> float:
|
||||
def slack_retrieval(
|
||||
query: ChunkIndexRequest,
|
||||
access_token: str,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
connector: FederatedConnectorDetail | None = None, # noqa: ARG001
|
||||
entities: dict[str, Any] | None = None,
|
||||
limit: int | None = None,
|
||||
slack_event_context: SlackContext | None = None,
|
||||
bot_token: str | None = None, # Add bot token parameter
|
||||
team_id: str | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB query (no session needed)
|
||||
search_settings: SearchSettings | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
"""
|
||||
Main entry point for Slack federated search with entity filtering.
|
||||
@@ -925,7 +928,7 @@ def slack_retrieval(
|
||||
Args:
|
||||
query: Search query object
|
||||
access_token: User OAuth access token
|
||||
db_session: Database session
|
||||
db_session: Database session (optional if search_settings provided)
|
||||
connector: Federated connector detail (unused, kept for backwards compat)
|
||||
entities: Connector-level config (entity filtering configuration)
|
||||
limit: Maximum number of results
|
||||
@@ -1153,7 +1156,10 @@ def slack_retrieval(
|
||||
|
||||
# chunk index docs into doc aware chunks
|
||||
# a single index doc can get split into multiple chunks
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
if search_settings is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or search_settings must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedder = DefaultIndexingEmbedder.from_db_search_settings(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
@@ -18,8 +18,10 @@ from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.english_stopwords import strip_stopwords
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.secondary_llm_flows.source_filter import extract_source_filter
|
||||
from onyx.secondary_llm_flows.time_filter import extract_time_filter
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -41,7 +43,7 @@ def _build_index_filters(
|
||||
user_file_ids: list[UUID] | None,
|
||||
persona_document_sets: list[str] | None,
|
||||
persona_time_cutoff: datetime | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
query: str | None = None,
|
||||
llm: LLM | None = None,
|
||||
@@ -49,18 +51,19 @@ def _build_index_filters(
|
||||
# Assistant knowledge filters
|
||||
attached_document_ids: list[str] | None = None,
|
||||
hierarchy_node_ids: list[int] | None = None,
|
||||
# Pre-fetched ACL filters (skips DB query when provided)
|
||||
acl_filters: list[str] | None = None,
|
||||
) -> IndexFilters:
|
||||
if auto_detect_filters and (llm is None or query is None):
|
||||
raise RuntimeError("LLM and query are required for auto detect filters")
|
||||
|
||||
base_filters = user_provided_filters or BaseFilters()
|
||||
|
||||
if (
|
||||
user_provided_filters
|
||||
and user_provided_filters.document_set is None
|
||||
and persona_document_sets is not None
|
||||
):
|
||||
base_filters.document_set = persona_document_sets
|
||||
document_set_filter = (
|
||||
base_filters.document_set
|
||||
if base_filters.document_set is not None
|
||||
else persona_document_sets
|
||||
)
|
||||
|
||||
time_filter = base_filters.time_cutoff or persona_time_cutoff
|
||||
source_filter = base_filters.source_type
|
||||
@@ -103,15 +106,20 @@ def _build_index_filters(
|
||||
source_filter = list(source_filter) + [DocumentSource.USER_FILE]
|
||||
logger.debug("Added USER_FILE to source_filter for user knowledge search")
|
||||
|
||||
user_acl_filters = (
|
||||
None if bypass_acl else build_access_filters_for_user(user, db_session)
|
||||
)
|
||||
if bypass_acl:
|
||||
user_acl_filters = None
|
||||
elif acl_filters is not None:
|
||||
user_acl_filters = acl_filters
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or acl_filters must be provided")
|
||||
user_acl_filters = build_access_filters_for_user(user, db_session)
|
||||
|
||||
final_filters = IndexFilters(
|
||||
user_file_ids=user_file_ids,
|
||||
project_id=project_id,
|
||||
source_type=source_filter,
|
||||
document_set=persona_document_sets,
|
||||
document_set=document_set_filter,
|
||||
time_cutoff=time_filter,
|
||||
tags=base_filters.tags,
|
||||
access_control_list=user_acl_filters,
|
||||
@@ -252,11 +260,15 @@ def search_pipeline(
|
||||
user: User,
|
||||
# Used for default filters and settings
|
||||
persona: Persona | None,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
auto_detect_filters: bool = False,
|
||||
llm: LLM | None = None,
|
||||
# If a project ID is provided, it will be exclusively scoped to that project
|
||||
project_id: int | None = None,
|
||||
# Pre-fetched data — when provided, avoids DB queries (no session needed)
|
||||
acl_filters: list[str] | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
user_uploaded_persona_files: list[UUID] | None = (
|
||||
[user_file.id for user_file in persona.user_files] if persona else None
|
||||
@@ -297,6 +309,7 @@ def search_pipeline(
|
||||
bypass_acl=chunk_search_request.bypass_acl,
|
||||
attached_document_ids=attached_document_ids,
|
||||
hierarchy_node_ids=hierarchy_node_ids,
|
||||
acl_filters=acl_filters,
|
||||
)
|
||||
|
||||
query_keywords = strip_stopwords(chunk_search_request.query)
|
||||
@@ -315,6 +328,8 @@ def search_pipeline(
|
||||
user_id=user.id if user else None,
|
||||
document_index=document_index,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
prefetched_federated_retrieval_infos=prefetched_federated_retrieval_infos,
|
||||
)
|
||||
|
||||
# For some specific connectors like Salesforce, a user that has access to an object doesn't mean
|
||||
|
||||
@@ -14,9 +14,11 @@ from onyx.context.search.utils import get_query_embedding
|
||||
from onyx.context.search.utils import inference_section_from_chunks
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import VespaChunkRequest
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.federated_connectors.federated_retrieval import (
|
||||
get_federated_retrieval_functions,
|
||||
)
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -50,9 +52,14 @@ def combine_retrieval_results(
|
||||
def _embed_and_search(
|
||||
query_request: ChunkIndexRequest,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
query_embedding = get_query_embedding(query_request.query, db_session)
|
||||
query_embedding = get_query_embedding(
|
||||
query_request.query,
|
||||
db_session=db_session,
|
||||
embedding_model=embedding_model,
|
||||
)
|
||||
|
||||
hybrid_alpha = query_request.hybrid_alpha or HYBRID_ALPHA
|
||||
|
||||
@@ -78,7 +85,9 @@ def search_chunks(
|
||||
query_request: ChunkIndexRequest,
|
||||
user_id: UUID | None,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
prefetched_federated_retrieval_infos: list[FederatedRetrievalInfo] | None = None,
|
||||
) -> list[InferenceChunk]:
|
||||
run_queries: list[tuple[Callable, tuple]] = []
|
||||
|
||||
@@ -88,14 +97,22 @@ def search_chunks(
|
||||
else None
|
||||
)
|
||||
|
||||
# Federated retrieval
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
# Federated retrieval — use pre-fetched if available, otherwise query DB
|
||||
if prefetched_federated_retrieval_infos is not None:
|
||||
federated_retrieval_infos = prefetched_federated_retrieval_infos
|
||||
else:
|
||||
if db_session is None:
|
||||
raise ValueError(
|
||||
"Either db_session or prefetched_federated_retrieval_infos "
|
||||
"must be provided"
|
||||
)
|
||||
federated_retrieval_infos = get_federated_retrieval_functions(
|
||||
db_session=db_session,
|
||||
user_id=user_id,
|
||||
source_types=list(source_filters) if source_filters else None,
|
||||
document_set_names=query_request.filters.document_set,
|
||||
user_file_ids=query_request.filters.user_file_ids,
|
||||
)
|
||||
|
||||
federated_sources = set(
|
||||
federated_retrieval_info.source.to_non_federated_source()
|
||||
@@ -114,7 +131,10 @@ def search_chunks(
|
||||
|
||||
if normal_search_enabled:
|
||||
run_queries.append(
|
||||
(_embed_and_search, (query_request, document_index, db_session))
|
||||
(
|
||||
_embed_and_search,
|
||||
(query_request, document_index, db_session, embedding_model),
|
||||
)
|
||||
)
|
||||
|
||||
parallel_search_results = run_functions_tuples_in_parallel(run_queries)
|
||||
|
||||
@@ -64,23 +64,34 @@ def inference_section_from_single_chunk(
|
||||
)
|
||||
|
||||
|
||||
def get_query_embeddings(queries: list[str], db_session: Session) -> list[Embedding]:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
def get_query_embeddings(
|
||||
queries: list[str],
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> list[Embedding]:
|
||||
if embedding_model is None:
|
||||
if db_session is None:
|
||||
raise ValueError("Either db_session or embedding_model must be provided")
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
embedding_model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
query_embedding = embedding_model.encode(queries, text_type=EmbedTextType.QUERY)
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True, debug_only=True)
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
return get_query_embeddings([query], db_session)[0]
|
||||
def get_query_embedding(
|
||||
query: str,
|
||||
db_session: Session | None = None,
|
||||
embedding_model: EmbeddingModel | None = None,
|
||||
) -> Embedding:
|
||||
return get_query_embeddings(
|
||||
[query], db_session=db_session, embedding_model=embedding_model
|
||||
)[0]
|
||||
|
||||
|
||||
def convert_inference_sections_to_search_docs(
|
||||
|
||||
@@ -4,6 +4,7 @@ from fastapi_users.password import PasswordHelper
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import joinedload
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.api_key import ApiKeyDescriptor
|
||||
@@ -54,6 +55,7 @@ async def fetch_user_for_api_key(
|
||||
select(User)
|
||||
.join(ApiKey, ApiKey.user_id == User.id)
|
||||
.where(ApiKey.hashed_api_key == hashed_api_key)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -13,6 +13,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import Select
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.schemas import UserRole
|
||||
@@ -97,6 +98,11 @@ async def get_user_count(only_admin_users: bool = False) -> int:
|
||||
|
||||
# Need to override this because FastAPI Users doesn't give flexibility for backend field creation logic in OAuth flow
|
||||
class SQLAlchemyUserAdminDB(SQLAlchemyUserDatabase[UP, ID]):
|
||||
async def _get_user(self, statement: Select) -> UP | None:
|
||||
statement = statement.options(selectinload(User.memories))
|
||||
results = await self.session.execute(statement)
|
||||
return results.unique().scalar_one_or_none()
|
||||
|
||||
async def create(
|
||||
self,
|
||||
create_dict: Dict[str, Any],
|
||||
|
||||
@@ -116,12 +116,15 @@ def get_connector_credential_pairs_for_user(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
"""Get connector credential pairs for a user.
|
||||
|
||||
Args:
|
||||
processing_mode: Filter by processing mode. Defaults to REGULAR to hide
|
||||
FILE_SYSTEM connectors from standard admin UI. Pass None to get all.
|
||||
defer_connector_config: If True, skips loading Connector.connector_specific_config
|
||||
to avoid fetching large JSONB blobs when they aren't needed.
|
||||
"""
|
||||
if eager_load_user:
|
||||
assert (
|
||||
@@ -130,7 +133,10 @@ def get_connector_credential_pairs_for_user(
|
||||
stmt = select(ConnectorCredentialPair).distinct()
|
||||
|
||||
if eager_load_connector:
|
||||
stmt = stmt.options(selectinload(ConnectorCredentialPair.connector))
|
||||
connector_load = selectinload(ConnectorCredentialPair.connector)
|
||||
if defer_connector_config:
|
||||
connector_load = connector_load.defer(Connector.connector_specific_config)
|
||||
stmt = stmt.options(connector_load)
|
||||
|
||||
if eager_load_credential:
|
||||
load_opts = selectinload(ConnectorCredentialPair.credential)
|
||||
@@ -170,6 +176,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc: bool = False,
|
||||
source: DocumentSource | None = None,
|
||||
processing_mode: ProcessingMode | None = ProcessingMode.REGULAR,
|
||||
defer_connector_config: bool = False,
|
||||
) -> list[ConnectorCredentialPair]:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
return get_connector_credential_pairs_for_user(
|
||||
@@ -183,6 +190,7 @@ def get_connector_credential_pairs_for_user_parallel(
|
||||
order_by_desc=order_by_desc,
|
||||
source=source,
|
||||
processing_mode=processing_mode,
|
||||
defer_connector_config=defer_connector_config,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -554,10 +554,19 @@ def fetch_all_document_sets_for_user(
|
||||
stmt = (
|
||||
select(DocumentSetDBModel)
|
||||
.distinct()
|
||||
.options(selectinload(DocumentSetDBModel.federated_connectors))
|
||||
.options(
|
||||
selectinload(DocumentSetDBModel.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSetDBModel.users),
|
||||
selectinload(DocumentSetDBModel.groups),
|
||||
selectinload(DocumentSetDBModel.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
)
|
||||
)
|
||||
stmt = _add_user_filters(stmt, user, get_editable=get_editable)
|
||||
return db_session.scalars(stmt).all()
|
||||
return db_session.scalars(stmt).unique().all()
|
||||
|
||||
|
||||
def fetch_documents_for_document_set_paginated(
|
||||
|
||||
@@ -1,11 +1,102 @@
|
||||
from sqlalchemy import text
|
||||
|
||||
from onyx.db.engine.sql_engine import get_session_with_shared_schema
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
|
||||
def get_schemas_needing_migration(
    tenant_schemas: list[str], head_rev: str
) -> list[str]:
    """Filter ``tenant_schemas`` down to those not yet at ``head_rev``.

    Each requested schema's ``alembic_version.version_num`` is copied into a
    temp table by a server-side PL/pgSQL loop, one schema per iteration,
    rather than by a single large UNION ALL query. Schemas for which no
    version row was collected (missing ``alembic_version`` table or missing
    schema) are reported as needing migration.

    Args:
        tenant_schemas: Schema names to inspect.
        head_rev: Alembic revision regarded as up to date.

    Returns:
        The schemas, in input order, whose recorded revision differs from
        ``head_rev`` or is absent. Empty when ``tenant_schemas`` is empty.
    """
    if not tenant_schemas:
        return []

    drop_snapshot_sql = "DROP TABLE IF EXISTS _alembic_version_snapshot"
    drop_input_sql = "DROP TABLE IF EXISTS _tenant_schemas_input"

    # The loop reads its work list from _tenant_schemas_input, so only the
    # requested schemas are visited. undefined_table / invalid_schema_name
    # are swallowed per schema: no row is written and the schema ends up in
    # the returned list.
    collect_versions_sql = """
            DO $$
            DECLARE
                s text;
                schemas text[];
            BEGIN
                SELECT array_agg(schema_name) INTO schemas
                FROM _tenant_schemas_input;

                IF schemas IS NULL THEN
                    RAISE NOTICE 'No tenant schemas found.';
                    RETURN;
                END IF;

                FOREACH s IN ARRAY schemas LOOP
                    BEGIN
                        EXECUTE format(
                            'INSERT INTO _alembic_version_snapshot
                             SELECT %L, version_num FROM %I.alembic_version',
                            s, s
                        );
                    EXCEPTION
                        -- undefined_table: schema exists but has no alembic_version
                        -- table yet (new tenant, not yet migrated).
                        -- invalid_schema_name: tenant is registered but its
                        -- PostgreSQL schema does not exist yet (e.g. provisioning
                        -- incomplete). Both cases mean no version is available and
                        -- the schema will be included in the migration list.
                        WHEN undefined_table THEN NULL;
                        WHEN invalid_schema_name THEN NULL;
                    END;
                END LOOP;
            END;
            $$
            """

    with SqlEngine.get_engine().connect() as conn:
        # Start from a clean slate in case earlier temp tables are still around.
        conn.execute(text(drop_snapshot_sql))
        conn.execute(text(drop_input_sql))

        conn.execute(text("CREATE TEMP TABLE _tenant_schemas_input (schema_name text)"))
        conn.execute(
            text(
                "INSERT INTO _tenant_schemas_input (schema_name) "
                "SELECT unnest(CAST(:schemas AS text[]))"
            ),
            {"schemas": tenant_schemas},
        )
        conn.execute(
            text(
                "CREATE TEMP TABLE _alembic_version_snapshot "
                "(schema_name text, version_num text)"
            )
        )

        conn.execute(text(collect_versions_sql))

        snapshot_rows = conn.execute(
            text("SELECT schema_name, version_num FROM _alembic_version_snapshot")
        )
        # Materialise the mapping before the temp tables are dropped.
        version_by_schema = {name: version for name, version in snapshot_rows}

        conn.execute(text(drop_snapshot_sql))
        conn.execute(text(drop_input_sql))

    # A schema absent from the mapping yields None, which never equals
    # head_rev, so it is kept.
    pending: list[str] = []
    for schema in tenant_schemas:
        if version_by_schema.get(schema) != head_rev:
            pending.append(schema)
    return pending
|
||||
|
||||
|
||||
def get_all_tenant_ids() -> list[str]:
|
||||
"""Returning [None] means the only tenant is the 'public' or self hosted tenant."""
|
||||
|
||||
|
||||
@@ -287,7 +287,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
|
||||
# relationships
|
||||
credentials: Mapped[list["Credential"]] = relationship(
|
||||
"Credential", back_populates="user", lazy="joined"
|
||||
"Credential", back_populates="user"
|
||||
)
|
||||
chat_sessions: Mapped[list["ChatSession"]] = relationship(
|
||||
"ChatSession", back_populates="user"
|
||||
@@ -321,7 +321,6 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
"Memory",
|
||||
back_populates="user",
|
||||
cascade="all, delete-orphan",
|
||||
lazy="selectin",
|
||||
order_by="desc(Memory.id)",
|
||||
)
|
||||
oauth_user_tokens: Mapped[list["OAuthUserToken"]] = relationship(
|
||||
@@ -4979,3 +4978,12 @@ class ScimGroupMapping(Base):
|
||||
user_group: Mapped[UserGroup] = relationship(
|
||||
"UserGroup", foreign_keys=[user_group_id]
|
||||
)
|
||||
|
||||
|
||||
class CodeInterpreterServer(Base):
|
||||
"""Details about the code interpreter server"""
|
||||
|
||||
__tablename__ = "code_interpreter_server"
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
server_enabled: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
@@ -8,6 +8,7 @@ from uuid import UUID
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import selectinload
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.pat import build_displayable_pat
|
||||
@@ -31,55 +32,61 @@ async def fetch_user_for_pat(
|
||||
|
||||
NOTE: This is async since it's used during auth (which is necessarily async due to FastAPI Users).
|
||||
NOTE: Expired includes both naturally expired and user-revoked tokens (revocation sets expires_at=NOW()).
|
||||
|
||||
Uses select(User) as primary entity so that joined-eager relationships (e.g. oauth_accounts)
|
||||
are loaded correctly — matching the pattern in fetch_user_for_api_key.
|
||||
"""
|
||||
# Single joined query with all filters pushed to database
|
||||
now = datetime.now(timezone.utc)
|
||||
result = await async_db_session.execute(
|
||||
select(PersonalAccessToken, User)
|
||||
.join(User, PersonalAccessToken.user_id == User.id)
|
||||
|
||||
user = await async_db_session.scalar(
|
||||
select(User)
|
||||
.join(PersonalAccessToken, PersonalAccessToken.user_id == User.id)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.where(User.is_active) # type: ignore
|
||||
.where(
|
||||
(PersonalAccessToken.expires_at.is_(None))
|
||||
| (PersonalAccessToken.expires_at > now)
|
||||
)
|
||||
.limit(1)
|
||||
.options(selectinload(User.memories))
|
||||
)
|
||||
row = result.first()
|
||||
|
||||
if not row:
|
||||
if not user:
|
||||
return None
|
||||
|
||||
pat, user = row
|
||||
|
||||
# Throttle last_used_at updates to reduce DB load (5-minute granularity sufficient for auditing)
|
||||
# For request-level auditing, use application logs or a dedicated audit table
|
||||
should_update = (
|
||||
pat.last_used_at is None or (now - pat.last_used_at).total_seconds() > 300
|
||||
)
|
||||
|
||||
if should_update:
|
||||
# Update in separate session to avoid transaction coupling (fire-and-forget)
|
||||
async def _update_last_used() -> None:
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
async with get_async_session_context_manager(
|
||||
tenant_id
|
||||
) as separate_session:
|
||||
await separate_session.execute(
|
||||
update(PersonalAccessToken)
|
||||
.where(PersonalAccessToken.hashed_token == hashed_token)
|
||||
.values(last_used_at=now)
|
||||
)
|
||||
await separate_session.commit()
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to update last_used_at for PAT: {e}")
|
||||
|
||||
asyncio.create_task(_update_last_used())
|
||||
|
||||
_schedule_pat_last_used_update(hashed_token, now)
|
||||
return user
|
||||
|
||||
|
||||
# Strong references to in-flight last_used_at update tasks. The event loop keeps
# only weak references to tasks, so a task nobody else references may be
# garbage-collected before it finishes.
_PAT_LAST_USED_UPDATE_TASKS: set[asyncio.Task] = set()


def _schedule_pat_last_used_update(hashed_token: str, now: datetime) -> None:
    """Fire-and-forget update of last_used_at, throttled to 5-minute granularity.

    Schedules a background task on the running event loop and returns at
    once. The task opens its own session, so the write is not tied to the
    caller's transaction. Failures inside the task are logged as warnings
    and never propagate to the caller.

    Args:
        hashed_token: Hash identifying the personal access token row.
        now: Timestamp to store as ``last_used_at``; also the reference
            point for the 5-minute throttle.
    """

    async def _update() -> None:
        try:
            tenant_id = get_current_tenant_id()
            async with get_async_session_context_manager(tenant_id) as session:
                pat = await session.scalar(
                    select(PersonalAccessToken).where(
                        PersonalAccessToken.hashed_token == hashed_token
                    )
                )
                if not pat:
                    return
                # Skip the write if the token was already stamped within the
                # last 300 seconds.
                if (
                    pat.last_used_at is not None
                    and (now - pat.last_used_at).total_seconds() <= 300
                ):
                    return
                await session.execute(
                    update(PersonalAccessToken)
                    .where(PersonalAccessToken.hashed_token == hashed_token)
                    .values(last_used_at=now)
                )
                await session.commit()
        except Exception as e:
            logger.warning(f"Failed to update last_used_at for PAT: {e}")

    task = asyncio.create_task(_update())
    # Hold the task until it completes, then drop the reference.
    _PAT_LAST_USED_UPDATE_TASKS.add(task)
    task.add_done_callback(_PAT_LAST_USED_UPDATE_TASKS.discard)
|
||||
|
||||
|
||||
def create_pat(
|
||||
db_session: Session,
|
||||
user_id: UUID,
|
||||
|
||||
@@ -28,6 +28,7 @@ from onyx.db.document_access import get_accessible_documents_by_ids
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.db.models import Document
|
||||
from onyx.db.models import DocumentSet
|
||||
from onyx.db.models import FederatedConnector__DocumentSet
|
||||
from onyx.db.models import HierarchyNode
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import Persona__User
|
||||
@@ -420,9 +421,16 @@ def get_minimal_persona_snapshots_for_user(
|
||||
stmt = stmt.options(
|
||||
selectinload(Persona.tools),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.hierarchy_nodes),
|
||||
selectinload(Persona.attached_documents).selectinload(
|
||||
Document.parent_hierarchy_node
|
||||
@@ -453,7 +461,16 @@ def get_persona_snapshots_for_user(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
@@ -550,9 +567,16 @@ def get_minimal_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets)
|
||||
.selectinload(DocumentSet.connector_credential_pairs)
|
||||
.selectinload(ConnectorCredentialPair.connector),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
)
|
||||
|
||||
@@ -611,7 +635,16 @@ def get_persona_snapshots_paginated(
|
||||
Document.parent_hierarchy_node
|
||||
),
|
||||
selectinload(Persona.labels),
|
||||
selectinload(Persona.document_sets),
|
||||
selectinload(Persona.document_sets).options(
|
||||
selectinload(DocumentSet.connector_credential_pairs).selectinload(
|
||||
ConnectorCredentialPair.connector
|
||||
),
|
||||
selectinload(DocumentSet.users),
|
||||
selectinload(DocumentSet.groups),
|
||||
selectinload(DocumentSet.federated_connectors).selectinload(
|
||||
FederatedConnector__DocumentSet.federated_connector
|
||||
),
|
||||
),
|
||||
selectinload(Persona.user),
|
||||
selectinload(Persona.user_files),
|
||||
selectinload(Persona.users),
|
||||
|
||||
@@ -554,10 +554,9 @@ class VespaDocumentIndex(DocumentIndex):
|
||||
num_to_retrieve: int,
|
||||
) -> list[InferenceChunk]:
|
||||
vespa_where_clauses = build_vespa_filters(filters)
|
||||
# Needs to be at least as much as the rerank-count value set in the
|
||||
# Vespa schema config. Otherwise we would be getting fewer results than
|
||||
# expected for reranking.
|
||||
target_hits = max(10 * num_to_retrieve, RERANK_COUNT)
|
||||
# Avoid over-fetching a very large candidate set for global-phase reranking.
|
||||
# Keep enough headroom for quality while capping cost on larger indices.
|
||||
target_hits = min(max(4 * num_to_retrieve, 100), RERANK_COUNT)
|
||||
|
||||
yql = (
|
||||
YQL_BASE.format(index_name=self._index_name)
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
import json
|
||||
import pathlib
|
||||
import threading
|
||||
import time
|
||||
|
||||
from onyx.llm.constants import LlmProviderNames
|
||||
from onyx.llm.constants import PROVIDER_DISPLAY_NAMES
|
||||
@@ -23,6 +25,11 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_RECOMMENDATIONS_CACHE_TTL_SECONDS = 300
|
||||
_recommendations_cache_lock = threading.Lock()
|
||||
_cached_recommendations: LLMRecommendations | None = None
|
||||
_cached_recommendations_time: float = 0.0
|
||||
|
||||
|
||||
def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
"""Lazy-load provider model mappings to avoid importing litellm at module level.
|
||||
@@ -41,19 +48,40 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
}
|
||||
|
||||
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations from the GitHub config."""
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
if recommendations_from_github:
|
||||
return recommendations_from_github
|
||||
|
||||
# Fall back to json bundled with code
|
||||
def _load_bundled_recommendations() -> LLMRecommendations:
|
||||
json_path = pathlib.Path(__file__).parent / "recommended-models.json"
|
||||
with open(json_path, "r") as f:
|
||||
json_config = json.load(f)
|
||||
return LLMRecommendations.model_validate(json_config)
|
||||
|
||||
recommendations_from_json = LLMRecommendations.model_validate(json_config)
|
||||
return recommendations_from_json
|
||||
|
||||
def get_recommendations() -> LLMRecommendations:
|
||||
"""Get the recommendations, with an in-memory cache to avoid
|
||||
hitting GitHub on every API request."""
|
||||
global _cached_recommendations, _cached_recommendations_time
|
||||
|
||||
now = time.monotonic()
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (now - _cached_recommendations_time) < _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
with _recommendations_cache_lock:
|
||||
# Double-check after acquiring lock
|
||||
if (
|
||||
_cached_recommendations is not None
|
||||
and (time.monotonic() - _cached_recommendations_time)
|
||||
< _RECOMMENDATIONS_CACHE_TTL_SECONDS
|
||||
):
|
||||
return _cached_recommendations
|
||||
|
||||
recommendations_from_github = fetch_llm_recommendations_from_github()
|
||||
result = recommendations_from_github or _load_bundled_recommendations()
|
||||
|
||||
_cached_recommendations = result
|
||||
_cached_recommendations_time = time.monotonic()
|
||||
return result
|
||||
|
||||
|
||||
def is_obsolete_model(model_name: str, provider: str) -> bool:
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
# If there are any tools, this section is included, the sections below are for the available tools
|
||||
TOOL_SECTION_HEADER = "\n\n# Tools\n"
|
||||
TOOL_SECTION_HEADER = "\n# Tools\n\n"
|
||||
|
||||
|
||||
# This section is included if there are search type tools, currently internal_search and web_search
|
||||
@@ -16,11 +16,10 @@ When searching for information, if the initial results cannot fully answer the u
|
||||
Do not repeat the same or very similar queries if it already has been run in the chat history.
|
||||
|
||||
If it is unclear which tool to use, consider using multiple in parallel to be efficient with time.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
|
||||
INTERNAL_SEARCH_GUIDANCE = """
|
||||
|
||||
## internal_search
|
||||
Use the `internal_search` tool to search connected applications for information. Some examples of when to use `internal_search` include:
|
||||
- Internal information: any time where there may be some information stored in internal applications that could help better answer the query.
|
||||
@@ -28,34 +27,31 @@ Use the `internal_search` tool to search connected applications for information.
|
||||
- Keyword Queries: queries that are heavily keyword based are often internal document search queries.
|
||||
- Ambiguity: questions about something that is not widely known or understood.
|
||||
Never provide more than 3 queries at once to `internal_search`.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
|
||||
WEB_SEARCH_GUIDANCE = """
|
||||
|
||||
## web_search
|
||||
Use the `web_search` tool to access up-to-date information from the web. Some examples of when to use `web_search` include:
|
||||
- Freshness: when the answer might be enhanced by up-to-date information on a topic. Very important for topics that are changing or evolving.
|
||||
- Accuracy: if the cost of outdated/inaccurate information is high.
|
||||
- Niche Information: when detailed info is not widely known or understood (but is likely found on the internet).{site_colon_disabled}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
WEB_SEARCH_SITE_DISABLED_GUIDANCE = """
|
||||
Do not use the "site:" operator in your web search queries.
|
||||
""".rstrip()
|
||||
""".lstrip()
|
||||
|
||||
|
||||
OPEN_URLS_GUIDANCE = """
|
||||
|
||||
## open_url
|
||||
Use the `open_url` tool to read the content of one or more URLs. Use this tool to access the contents of the most promising web pages from your web searches or user specified URLs. \
|
||||
You can open many URLs at once by passing multiple URLs in the array if multiple pages seem promising. Prioritize the most promising pages and reputable sources. \
|
||||
Do not open URLs that are image files like .png, .jpg, etc.
|
||||
You should almost always use open_url after a web_search call. Use this tool when a user asks about a specific provided URL.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
PYTHON_TOOL_GUIDANCE = """
|
||||
|
||||
## python
|
||||
Use the `python` tool to execute Python code in an isolated sandbox. The tool will respond with the output of the execution or time out after 60.0 seconds.
|
||||
Any files uploaded to the chat will be automatically be available in the execution environment's current directory. \
|
||||
@@ -64,23 +60,21 @@ Use this to give the user a way to download the file OR to display generated ima
|
||||
Internet access for this session is disabled. Do not make external web requests or API calls as they will fail.
|
||||
Use `openpyxl` to read and write Excel files. You have access to libraries like numpy, pandas, scipy, matplotlib, and PIL.
|
||||
IMPORTANT: each call to this tool is independent. Variables from previous calls will NOT be available in the current call.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
GENERATE_IMAGE_GUIDANCE = """
|
||||
|
||||
## generate_image
|
||||
NEVER use generate_image unless the user specifically requests an image.
|
||||
For edits/variations of a previously generated image, pass `reference_image_file_ids` with
|
||||
the `file_id` values returned by earlier `generate_image` tool results.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
MEMORY_GUIDANCE = """
|
||||
|
||||
## add_memory
|
||||
Use the `add_memory` tool for facts shared by the user that should be remembered for future conversations. \
|
||||
Only add memories that are specific, likely to remain true, and likely to be useful later. \
|
||||
Focus on enduring preferences, long-term goals, stable constraints, and explicit "remember this" type requests.
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
TOOL_CALL_FAILURE_PROMPT = """
|
||||
LLM attempted to call a tool but failed. Most likely the tool name or arguments were misspelled.
|
||||
|
||||
@@ -1,40 +1,36 @@
|
||||
# ruff: noqa: E501, W605 start
|
||||
USER_INFORMATION_HEADER = "\n\n# User Information\n"
|
||||
USER_INFORMATION_HEADER = "\n# User Information\n\n"
|
||||
|
||||
BASIC_INFORMATION_PROMPT = """
|
||||
|
||||
## Basic Information
|
||||
User name: {user_name}
|
||||
User email: {user_email}{user_role}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# This line only shows up if the user has configured their role.
|
||||
USER_ROLE_PROMPT = """
|
||||
User role: {user_role}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# Team information should be a paragraph style description of the user's team.
|
||||
TEAM_INFORMATION_PROMPT = """
|
||||
|
||||
## Team Information
|
||||
{team_information}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# User preferences should be a paragraph style description of the user's preferences.
|
||||
USER_PREFERENCES_PROMPT = """
|
||||
|
||||
## User Preferences
|
||||
{user_preferences}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# User memories should look something like:
|
||||
# - Memory 1
|
||||
# - Memory 2
|
||||
# - Memory 3
|
||||
USER_MEMORIES_PROMPT = """
|
||||
|
||||
## User Memories
|
||||
{user_memories}
|
||||
"""
|
||||
""".lstrip()
|
||||
|
||||
# ruff: noqa: E501, W605 end
|
||||
|
||||
@@ -988,6 +988,7 @@ def get_connector_status(
|
||||
user=user,
|
||||
eager_load_connector=True,
|
||||
eager_load_credential=True,
|
||||
eager_load_user=True,
|
||||
get_editable=False,
|
||||
)
|
||||
|
||||
@@ -1001,11 +1002,23 @@ def get_connector_status(
|
||||
relationship.user_group_id
|
||||
)
|
||||
|
||||
# Pre-compute credential_ids per connector to avoid N+1 lazy loads
|
||||
connector_to_credential_ids: dict[int, list[int]] = {}
|
||||
for cc_pair in cc_pairs:
|
||||
connector_to_credential_ids.setdefault(cc_pair.connector_id, []).append(
|
||||
cc_pair.credential_id
|
||||
)
|
||||
|
||||
return [
|
||||
ConnectorStatus(
|
||||
cc_pair_id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(cc_pair.connector),
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector,
|
||||
credential_ids=connector_to_credential_ids.get(
|
||||
cc_pair.connector_id, []
|
||||
),
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(cc_pair.credential),
|
||||
access_type=cc_pair.access_type,
|
||||
groups=group_cc_pair_relationships_dict.get(cc_pair.id, []),
|
||||
@@ -1060,15 +1073,27 @@ def get_connector_indexing_status(
|
||||
parallel_functions: list[tuple[CallableProtocol, tuple[Any, ...]]] = [
|
||||
# Get editable connector/credential pairs
|
||||
(
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, True, None, True, True, True, True, request.source),
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, True, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get federated connectors
|
||||
(fetch_all_federated_connectors_parallel, ()),
|
||||
# Get most recent index attempts
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, False)),
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, False
|
||||
),
|
||||
(),
|
||||
),
|
||||
# Get most recent finished index attempts
|
||||
(get_latest_index_attempts_parallel, (request.secondary_index, True, True)),
|
||||
(
|
||||
lambda: get_latest_index_attempts_parallel(
|
||||
request.secondary_index, True, True
|
||||
),
|
||||
(),
|
||||
),
|
||||
]
|
||||
|
||||
if user and user.role == UserRole.ADMIN:
|
||||
@@ -1085,8 +1110,10 @@ def get_connector_indexing_status(
|
||||
parallel_functions.append(
|
||||
# Get non-editable connector/credential pairs
|
||||
(
|
||||
get_connector_credential_pairs_for_user_parallel,
|
||||
(user, False, None, True, True, True, True, request.source),
|
||||
lambda: get_connector_credential_pairs_for_user_parallel(
|
||||
user, False, None, True, True, False, True, request.source
|
||||
),
|
||||
(),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1912,6 +1939,7 @@ Tenant ID: {tenant_id}
|
||||
class BasicCCPairInfo(BaseModel):
|
||||
has_successful_run: bool
|
||||
source: DocumentSource
|
||||
status: ConnectorCredentialPairStatus
|
||||
|
||||
|
||||
@router.get("/connector-status", tags=PUBLIC_API_TAGS)
|
||||
@@ -1931,6 +1959,7 @@ def get_basic_connector_indexing_status(
|
||||
BasicCCPairInfo(
|
||||
has_successful_run=cc_pair.last_successful_index_time is not None,
|
||||
source=cc_pair.connector.source,
|
||||
status=cc_pair.status,
|
||||
)
|
||||
for cc_pair in cc_pairs
|
||||
if cc_pair.connector.source != DocumentSource.INGESTION_API
|
||||
|
||||
@@ -365,7 +365,8 @@ class CCPairFullInfo(BaseModel):
|
||||
in_repeated_error_state=cc_pair_model.in_repeated_error_state,
|
||||
num_docs_indexed=num_docs_indexed,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair_model.connector
|
||||
cc_pair_model.connector,
|
||||
credential_ids=[cc_pair_model.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair_model.credential
|
||||
|
||||
@@ -111,7 +111,8 @@ class DocumentSet(BaseModel):
|
||||
id=cc_pair.id,
|
||||
name=cc_pair.name,
|
||||
connector=ConnectorSnapshot.from_connector_db_model(
|
||||
cc_pair.connector
|
||||
cc_pair.connector,
|
||||
credential_ids=[cc_pair.credential_id],
|
||||
),
|
||||
credential=CredentialSnapshot.from_credential_db_model(
|
||||
cc_pair.credential
|
||||
|
||||
@@ -57,6 +57,7 @@ class Settings(BaseModel):
|
||||
anonymous_user_enabled: bool | None = None
|
||||
invite_only_enabled: bool = False
|
||||
deep_research_enabled: bool | None = None
|
||||
search_ui_enabled: bool | None = None
|
||||
|
||||
# Enterprise features flag - set by license enforcement at runtime
|
||||
# When LICENSE_ENFORCEMENT_ENABLED=true, this reflects license status
|
||||
|
||||
@@ -171,10 +171,8 @@ def construct_tools(
|
||||
if not search_tool_config:
|
||||
search_tool_config = SearchToolConfig()
|
||||
|
||||
# TODO concerning passing the db_session here.
|
||||
search_tool = SearchTool(
|
||||
tool_id=db_tool_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
@@ -422,7 +420,6 @@ def construct_tools(
|
||||
|
||||
search_tool = SearchTool(
|
||||
tool_id=search_tool_db_model.id,
|
||||
db_session=db_session,
|
||||
emitter=emitter,
|
||||
user=user,
|
||||
persona=persona,
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
|
||||
from onyx.file_processing.html_utils import ParsedHTML
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
@@ -21,10 +22,22 @@ from onyx.utils.web_content import title_from_url
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
DEFAULT_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_READ_TIMEOUT_SECONDS = 15
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS = 5
|
||||
DEFAULT_USER_AGENT = "OnyxWebCrawler/1.0 (+https://www.onyx.app)"
|
||||
DEFAULT_MAX_PDF_SIZE_BYTES = 50 * 1024 * 1024 # 50 MB
|
||||
DEFAULT_MAX_HTML_SIZE_BYTES = 20 * 1024 * 1024 # 20 MB
|
||||
DEFAULT_MAX_WORKERS = 5
|
||||
|
||||
|
||||
def _failed_result(url: str) -> WebContent:
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
|
||||
|
||||
class OnyxWebCrawler(WebContentProvider):
|
||||
@@ -37,12 +50,14 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
timeout_seconds: int = DEFAULT_TIMEOUT_SECONDS,
|
||||
timeout_seconds: int = DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
connect_timeout_seconds: int = DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
user_agent: str = DEFAULT_USER_AGENT,
|
||||
max_pdf_size_bytes: int | None = None,
|
||||
max_html_size_bytes: int | None = None,
|
||||
) -> None:
|
||||
self._timeout_seconds = timeout_seconds
|
||||
self._read_timeout_seconds = timeout_seconds
|
||||
self._connect_timeout_seconds = connect_timeout_seconds
|
||||
self._max_pdf_size_bytes = max_pdf_size_bytes
|
||||
self._max_html_size_bytes = max_html_size_bytes
|
||||
self._headers = {
|
||||
@@ -51,75 +66,68 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
}
|
||||
|
||||
def contents(self, urls: Sequence[str]) -> list[WebContent]:
|
||||
results: list[WebContent] = []
|
||||
for url in urls:
|
||||
results.append(self._fetch_url(url))
|
||||
return results
|
||||
if not urls:
|
||||
return []
|
||||
|
||||
max_workers = min(DEFAULT_MAX_WORKERS, len(urls))
|
||||
with ThreadPoolExecutor(max_workers=max_workers) as executor:
|
||||
return list(executor.map(self._fetch_url_safe, urls))
|
||||
|
||||
def _fetch_url_safe(self, url: str) -> WebContent:
|
||||
"""Wrapper that catches all exceptions so one bad URL doesn't kill the batch."""
|
||||
try:
|
||||
return self._fetch_url(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler unexpected error for %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
def _fetch_url(self, url: str) -> WebContent:
|
||||
try:
|
||||
# Use SSRF-safe request to prevent DNS rebinding attacks
|
||||
response = ssrf_safe_get(
|
||||
url, headers=self._headers, timeout=self._timeout_seconds
|
||||
url,
|
||||
headers=self._headers,
|
||||
timeout=(self._connect_timeout_seconds, self._read_timeout_seconds),
|
||||
)
|
||||
except SSRFException as exc:
|
||||
logger.error(
|
||||
"SSRF protection blocked request to %s: %s",
|
||||
"SSRF protection blocked request to %s (%s)",
|
||||
url,
|
||||
str(exc),
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
except Exception as exc: # pragma: no cover - network failures vary
|
||||
return _failed_result(url)
|
||||
except Exception as exc:
|
||||
logger.warning(
|
||||
"Onyx crawler failed to fetch %s (%s)",
|
||||
url,
|
||||
exc.__class__.__name__,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
if response.status_code >= 400:
|
||||
logger.warning("Onyx crawler received %s for %s", response.status_code, url)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
content_type = response.headers.get("Content-Type", "")
|
||||
content_sniff = response.content[:1024] if response.content else None
|
||||
content = response.content
|
||||
|
||||
content_sniff = content[:1024] if content else None
|
||||
if is_pdf_resource(url, content_type, content_sniff):
|
||||
if (
|
||||
self._max_pdf_size_bytes is not None
|
||||
and len(response.content) > self._max_pdf_size_bytes
|
||||
and len(content) > self._max_pdf_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"PDF content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_pdf_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
text_content, metadata = extract_pdf_text(response.content)
|
||||
return _failed_result(url)
|
||||
text_content, metadata = extract_pdf_text(content)
|
||||
title = title_from_pdf_metadata(metadata) or title_from_url(url)
|
||||
return WebContent(
|
||||
title=title,
|
||||
@@ -131,25 +139,19 @@ class OnyxWebCrawler(WebContentProvider):
|
||||
|
||||
if (
|
||||
self._max_html_size_bytes is not None
|
||||
and len(response.content) > self._max_html_size_bytes
|
||||
and len(content) > self._max_html_size_bytes
|
||||
):
|
||||
logger.warning(
|
||||
"HTML content too large (%d bytes) for %s, max is %d",
|
||||
len(response.content),
|
||||
len(content),
|
||||
url,
|
||||
self._max_html_size_bytes,
|
||||
)
|
||||
return WebContent(
|
||||
title="",
|
||||
link=url,
|
||||
full_content="",
|
||||
published_date=None,
|
||||
scrape_successful=False,
|
||||
)
|
||||
return _failed_result(url)
|
||||
|
||||
try:
|
||||
decoded_html = decode_html_bytes(
|
||||
response.content,
|
||||
content,
|
||||
content_type=content_type,
|
||||
fallback_encoding=response.apparent_encoding or response.encoding,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -146,7 +146,7 @@ MAX_REDIRECTS = 10
|
||||
def _make_ssrf_safe_request(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
"""
|
||||
@@ -204,7 +204,7 @@ def _make_ssrf_safe_request(
|
||||
def ssrf_safe_get(
|
||||
url: str,
|
||||
headers: dict[str, str] | None = None,
|
||||
timeout: int = 15,
|
||||
timeout: float | tuple[float, float] = 15,
|
||||
follow_redirects: bool = True,
|
||||
**kwargs: Any,
|
||||
) -> requests.Response:
|
||||
|
||||
@@ -243,12 +243,12 @@ USAGE_LIMIT_CHUNKS_INDEXED_PAID = int(
|
||||
)
|
||||
|
||||
# Per-week API calls using API keys or Personal Access Tokens
|
||||
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "400"))
|
||||
USAGE_LIMIT_API_CALLS_TRIAL = int(os.environ.get("USAGE_LIMIT_API_CALLS_TRIAL", "0"))
|
||||
USAGE_LIMIT_API_CALLS_PAID = int(os.environ.get("USAGE_LIMIT_API_CALLS_PAID", "40000"))
|
||||
|
||||
# Per-week non-streaming API calls (more expensive, so lower limits)
|
||||
USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL = int(
|
||||
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "80")
|
||||
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_TRIAL", "0")
|
||||
)
|
||||
USAGE_LIMIT_NON_STREAMING_CALLS_PAID = int(
|
||||
os.environ.get("USAGE_LIMIT_NON_STREAMING_CALLS_PAID", "160")
|
||||
|
||||
@@ -2,6 +2,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from contextlib import contextmanager
|
||||
from typing import Any
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -12,9 +13,13 @@ from onyx.context.search.models import ChunkSearchRequest
|
||||
from onyx.context.search.models import InferenceChunk
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.db.models import Persona
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import User
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.federated_connectors.federated_retrieval import FederatedRetrievalInfo
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
|
||||
|
||||
def run_functions_tuples_sequential(
|
||||
@@ -135,13 +140,25 @@ def use_mock_search_pipeline(
|
||||
document_index: DocumentIndex, # noqa: ARG001
|
||||
user: User | None, # noqa: ARG001
|
||||
persona: Persona | None, # noqa: ARG001
|
||||
db_session: Session, # noqa: ARG001
|
||||
db_session: Session | None = None, # noqa: ARG001
|
||||
auto_detect_filters: bool = False, # noqa: ARG001
|
||||
llm: LLM | None = None, # noqa: ARG001
|
||||
project_id: int | None = None, # noqa: ARG001
|
||||
# Pre-fetched data (used by SearchTool to avoid DB access in parallel)
|
||||
acl_filters: list[str] | None = None, # noqa: ARG001
|
||||
embedding_model: EmbeddingModel | None = None, # noqa: ARG001
|
||||
prefetched_federated_retrieval_infos: ( # noqa: ARG001
|
||||
list[FederatedRetrievalInfo] | None
|
||||
) = None,
|
||||
) -> list[InferenceChunk]:
|
||||
return controller.get_search_results(chunk_search_request.query)
|
||||
|
||||
# Mock the pre-fetch session and DB queries in SearchTool.run() so
|
||||
# tests don't need a fully initialised DB with search settings.
|
||||
@contextmanager
|
||||
def mock_get_session() -> Generator[MagicMock, None, None]:
|
||||
yield MagicMock(spec=Session)
|
||||
|
||||
with (
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.search_pipeline",
|
||||
@@ -183,5 +200,31 @@ def use_mock_search_pipeline(
|
||||
"onyx.db.connector.fetch_unique_document_sources",
|
||||
new=mock_fetch_unique_document_sources,
|
||||
),
|
||||
# Mock the pre-fetch phase of SearchTool.run()
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_session_with_current_tenant",
|
||||
new=mock_get_session,
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.build_access_filters_for_user",
|
||||
return_value=[],
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_current_search_settings",
|
||||
return_value=MagicMock(spec=SearchSettings),
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.EmbeddingModel.from_db_model",
|
||||
return_value=MagicMock(spec=EmbeddingModel),
|
||||
),
|
||||
patch(
|
||||
"onyx.tools.tool_implementations.search.search_tool.get_federated_retrieval_functions",
|
||||
return_value=[],
|
||||
),
|
||||
patch.object(
|
||||
SearchTool,
|
||||
"_prefetch_slack_data",
|
||||
return_value=(None, None, {}),
|
||||
),
|
||||
):
|
||||
yield controller
|
||||
|
||||
@@ -106,13 +106,13 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin user for tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_admin1+{unique}@example.com",
|
||||
email=f"discord_admin1_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create admin user for tenant 2
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_admin2+{unique}@example.com",
|
||||
email=f"discord_admin2_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
@@ -170,10 +170,10 @@ class TestGuildDataIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_list1+{unique}@example.com",
|
||||
email=f"discord_list1_{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_list2+{unique}@example.com",
|
||||
email=f"discord_list2_{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create 1 guild in tenant 1
|
||||
@@ -350,10 +350,10 @@ class TestGuildAccessIsolation:
|
||||
|
||||
# Create admin users for two tenants
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"discord_access1+{unique}@example.com",
|
||||
email=f"discord_access1_{unique}@example.com",
|
||||
)
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"discord_access2+{unique}@example.com",
|
||||
email=f"discord_access2_{unique}@example.com",
|
||||
)
|
||||
|
||||
# Create a guild in tenant 1
|
||||
|
||||
@@ -21,7 +21,7 @@ def test_admin_can_invite_users(reset_multitenant: None) -> None: # noqa: ARG00
|
||||
|
||||
# Admin user invites the previously registered and non-registered user
|
||||
UserManager.invite_user(invited_user.email, admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}+{unique}@example.com", admin_user)
|
||||
UserManager.invite_user(f"{INVITED_BASIC_USER}_{unique}@example.com", admin_user)
|
||||
|
||||
# Verify users are in the invited users list
|
||||
invited_users = UserManager.get_invited_users(admin_user)
|
||||
@@ -40,7 +40,7 @@ def test_non_registered_user_gets_basic_role(
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Admin user invites a non-registered user
|
||||
invited_email = f"{INVITED_BASIC_USER}+{unique}@example.com"
|
||||
invited_email = f"{INVITED_BASIC_USER}_{unique}@example.com"
|
||||
UserManager.invite_user(invited_email, admin_user)
|
||||
|
||||
# Non-registered user registers
|
||||
@@ -58,7 +58,7 @@ def test_user_can_accept_invitation(reset_multitenant: None) -> None: # noqa: A
|
||||
assert UserManager.is_role(admin_user, UserRole.ADMIN)
|
||||
|
||||
# Create a user to be invited
|
||||
invited_user_email = f"invited_user+{unique}@example.com"
|
||||
invited_user_email = f"invited_user_{unique}@example.com"
|
||||
|
||||
# User registers with the same email as the invitation
|
||||
invited_user: DATestUser = UserManager.create(
|
||||
|
||||
@@ -20,13 +20,13 @@ def setup_test_tenants(reset_multitenant: None) -> dict[str, Any]: # noqa: ARG0
|
||||
unique = uuid4().hex
|
||||
# Creating an admin user for Tenant 1
|
||||
admin_user1: DATestUser = UserManager.create(
|
||||
email=f"admin+{unique}@example.com",
|
||||
email=f"admin_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user1, UserRole.ADMIN)
|
||||
|
||||
# Create Tenant 2 and its Admin User
|
||||
admin_user2: DATestUser = UserManager.create(
|
||||
email=f"admin2+{unique}@example.com",
|
||||
email=f"admin2_{unique}@example.com",
|
||||
)
|
||||
assert UserManager.is_role(admin_user2, UserRole.ADMIN)
|
||||
|
||||
|
||||
@@ -0,0 +1,167 @@
|
||||
"""
|
||||
Integration tests for onyx.db.engine.tenant_utils.get_schemas_needing_migration.
|
||||
|
||||
These tests require a live database and exercise the function directly,
|
||||
independent of the alembic migration runner script.
|
||||
|
||||
Usage:
|
||||
pytest tests/integration/multitenant_tests/test_get_schemas_needing_migration.py -v
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import subprocess
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import text
|
||||
from sqlalchemy.engine import Engine
|
||||
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.engine.tenant_utils import get_schemas_needing_migration
|
||||
|
||||
# Directory that contains this file's top-level "tests" package; used as the
# working directory when invoking alembic. rindex (not index) is used so that a
# checkout living underneath some unrelated ".../tests/..." directory still
# resolves to the innermost match rather than the first one in the path.
_BACKEND_DIR = __file__[: __file__.rindex("/tests/")]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
def engine() -> Engine:
    """Provide the SQLAlchemy engine returned by ``SqlEngine.get_engine()``."""
    sql_engine = SqlEngine.get_engine()
    return sql_engine
|
||||
|
||||
|
||||
@pytest.fixture
def current_head_rev() -> str:
    """Return the first revision id printed by ``alembic heads``.

    Runs ``alembic heads --resolve-dependencies`` with ``_BACKEND_DIR`` as the
    working directory and returns the first whitespace-delimited token of its
    standard output.

    Raises:
        AssertionError: if alembic exits non-zero or prints no revision.
    """
    result = subprocess.run(
        ["alembic", "heads", "--resolve-dependencies"],
        cwd=_BACKEND_DIR,
        stdout=subprocess.PIPE,
        # Keep stderr separate: anything written there (e.g. log lines) must
        # not be mistaken for the revision id when stdout is tokenised below.
        stderr=subprocess.PIPE,
        text=True,
    )
    assert result.returncode == 0, (
        f"alembic heads failed (exit {result.returncode}):\n"
        f"{result.stdout}\n{result.stderr}"
    )

    # split() on empty output yields [], so fail with a readable message
    # instead of an IndexError.
    tokens = result.stdout.split()
    assert tokens, f"alembic heads printed no revision:\n{result.stderr}"
    return tokens[0]
|
||||
|
||||
|
||||
@pytest.fixture
def tenant_schema_at_head(
    engine: Engine, current_head_rev: str
) -> Generator[str, None, None]:
    """Yield a throwaway schema whose alembic_version row equals the head rev.

    The schema, its ``alembic_version`` table and the single version row are
    created before the yield; the schema is dropped (CASCADE) afterwards.
    """
    schema = f"tenant_test_{uuid.uuid4().hex[:12]}"

    create_schema = text(f'CREATE SCHEMA "{schema}"')
    create_version_table = text(
        f'CREATE TABLE "{schema}".alembic_version '
        f"(version_num VARCHAR(32) NOT NULL)"
    )
    insert_head = text(
        f'INSERT INTO "{schema}".alembic_version (version_num) VALUES (:rev)'
    )

    with engine.connect() as setup_conn:
        setup_conn.execute(create_schema)
        setup_conn.execute(create_version_table)
        setup_conn.execute(insert_head, {"rev": current_head_rev})
        setup_conn.commit()

    yield schema

    # Teardown: remove the schema and everything inside it.
    drop_schema = text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE')
    with engine.connect() as teardown_conn:
        teardown_conn.execute(drop_schema)
        teardown_conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def tenant_schema_empty(engine: Engine) -> Generator[str, None, None]:
    """Yield a throwaway schema that contains no tables at all.

    The schema is created before the yield and dropped (CASCADE) afterwards.
    """
    schema = f"tenant_test_{uuid.uuid4().hex[:12]}"

    create_schema = text(f'CREATE SCHEMA "{schema}"')
    with engine.connect() as setup_conn:
        setup_conn.execute(create_schema)
        setup_conn.commit()

    yield schema

    # Teardown: remove the schema.
    drop_schema = text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE')
    with engine.connect() as teardown_conn:
        teardown_conn.execute(drop_schema)
        teardown_conn.commit()
|
||||
|
||||
|
||||
@pytest.fixture
def tenant_schema_stale_rev(engine: Engine) -> Generator[str, None, None]:
    """Yield a throwaway schema whose alembic_version row is a fixed fake rev.

    The stored revision is the literal ``stalerev000000000000``. The schema is
    dropped (CASCADE) after the yield.
    """
    schema = f"tenant_test_{uuid.uuid4().hex[:12]}"

    create_schema = text(f'CREATE SCHEMA "{schema}"')
    create_version_table = text(
        f'CREATE TABLE "{schema}".alembic_version '
        f"(version_num VARCHAR(32) NOT NULL)"
    )
    insert_stale = text(
        f'INSERT INTO "{schema}".alembic_version (version_num) '
        f"VALUES ('stalerev000000000000')"
    )

    with engine.connect() as setup_conn:
        for statement in (create_schema, create_version_table, insert_stale):
            setup_conn.execute(statement)
        setup_conn.commit()

    yield schema

    # Teardown: remove the schema and everything inside it.
    drop_schema = text(f'DROP SCHEMA IF EXISTS "{schema}" CASCADE')
    with engine.connect() as teardown_conn:
        teardown_conn.execute(drop_schema)
        teardown_conn.commit()
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_classifies_all_cases(
    current_head_rev: str,
    tenant_schema_at_head: str,
    tenant_schema_empty: str,
    tenant_schema_stale_rev: str,
) -> None:
    """All three fixture schemas are classified in a single call:

    - at head   → not returned
    - no table  → returned (needs migration)
    - stale rev → returned (needs migration)
    """
    needing_migration = get_schemas_needing_migration(
        [tenant_schema_at_head, tenant_schema_empty, tenant_schema_stale_rev],
        current_head_rev,
    )

    assert tenant_schema_at_head not in needing_migration
    for expected_schema in (tenant_schema_empty, tenant_schema_stale_rev):
        assert expected_schema in needing_migration
|
||||
|
||||
|
||||
def test_idempotent(
    current_head_rev: str,
    tenant_schema_at_head: str,
    tenant_schema_empty: str,
) -> None:
    """Two back-to-back calls with the same input return equal results.

    NOTE(review): the previous docstring attributed this to DROP TABLE IF
    EXISTS guards cleaning up temp tables inside the function under test;
    that implementation detail is not visible from this test module.
    """
    schemas = [tenant_schema_at_head, tenant_schema_empty]

    results = [
        get_schemas_needing_migration(schemas, current_head_rev) for _ in range(2)
    ]

    assert results[0] == results[1]
|
||||
|
||||
|
||||
def test_empty_input(current_head_rev: str) -> None:
    """An empty schema list produces an empty result list.

    NOTE(review): presumably the function short-circuits before any DB access;
    this test only checks the returned value.
    """
    outcome = get_schemas_needing_migration([], current_head_rev)
    assert outcome == []
|
||||
@@ -3,6 +3,7 @@ from fastapi import HTTPException
|
||||
|
||||
import onyx.auth.users as users
|
||||
from onyx.auth.users import verify_email_domain
|
||||
from onyx.configs.constants import AuthType
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_case_insensitive_match(
|
||||
@@ -35,3 +36,37 @@ def test_verify_email_domain_invalid_email_format(
|
||||
verify_email_domain("userexample.com") # missing '@'
|
||||
assert exc.value.status_code == 400
|
||||
assert "Email is not valid" in exc.value.detail
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_plus_addressing(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no domain allow-list and CLOUD auth, a '+' gmail address is a 400."""
    overrides = (("VALID_EMAIL_DOMAINS", []), ("AUTH_TYPE", AuthType.CLOUD))
    for attr_name, attr_value in overrides:
        monkeypatch.setattr(users, attr_name, attr_value, raising=False)

    with pytest.raises(HTTPException) as raised:
        verify_email_domain("user+tag@gmail.com")

    error = raised.value
    assert error.status_code == 400
    # The error detail is expected to mention the offending character.
    assert "'+'" in str(error.detail)
|
||||
|
||||
|
||||
def test_verify_email_domain_allows_plus_for_onyx_app(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no domain allow-list and CLOUD auth, '+' at onyx.app is accepted."""
    overrides = (("VALID_EMAIL_DOMAINS", []), ("AUTH_TYPE", AuthType.CLOUD))
    for attr_name, attr_value in overrides:
        monkeypatch.setattr(users, attr_name, attr_value, raising=False)

    # Passing means no exception is raised for the onyx.app domain.
    verify_email_domain("user+tag@onyx.app")
|
||||
|
||||
|
||||
def test_verify_email_domain_rejects_googlemail(
    monkeypatch: pytest.MonkeyPatch,
) -> None:
    """With no domain allow-list and CLOUD auth, googlemail.com is a 400."""
    overrides = (("VALID_EMAIL_DOMAINS", []), ("AUTH_TYPE", AuthType.CLOUD))
    for attr_name, attr_value in overrides:
        monkeypatch.setattr(users, attr_name, attr_value, raising=False)

    with pytest.raises(HTTPException) as raised:
        verify_email_domain("user@googlemail.com")

    error = raised.value
    assert error.status_code == 400
    # The error detail is expected to mention gmail.com.
    assert "gmail.com" in str(error.detail)
|
||||
|
||||
@@ -0,0 +1,459 @@
|
||||
"""Tests for per-page delta checkpointing in the SharePoint connector (P1-1).
|
||||
|
||||
Validates that:
|
||||
- Delta drives process one page per _load_from_checkpoint call
|
||||
- Checkpoints persist the delta next_link for resumption
|
||||
- Crash + resume skips already-processed pages
|
||||
- BFS (folder-scoped) drives process all items in one call
|
||||
- 410 Gone triggers a full-resync URL in the checkpoint
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections import deque
|
||||
from collections.abc import Generator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentSource
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.sharepoint.connector import DriveItemData
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnector
|
||||
from onyx.connectors.sharepoint.connector import SharepointConnectorCheckpoint
|
||||
from onyx.connectors.sharepoint.connector import SiteDescriptor
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
SITE_URL = "https://example.sharepoint.com/sites/sample"
|
||||
DRIVE_WEB_URL = f"{SITE_URL}/Shared Documents"
|
||||
DRIVE_ID = "fake-drive-id"
|
||||
|
||||
# Use a fixed, non-epoch start time so delta URLs include a timestamp token
|
||||
_START_TS = datetime(2025, 6, 1, tzinfo=timezone.utc).timestamp()
|
||||
_END_TS = datetime(2026, 1, 1, tzinfo=timezone.utc).timestamp()
|
||||
|
||||
# For BFS tests we use epoch so no token is generated
|
||||
_EPOCH_START: float = 0.0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_item(item_id: str, name: str = "doc.pdf") -> DriveItemData:
    """Build a DriveItemData for *item_id* living under SITE_URL / DRIVE_ID."""
    item_url = f"{SITE_URL}/{name}"
    item = DriveItemData(
        id=item_id,
        name=name,
        web_url=item_url,
        parent_reference_path="/drives/d1/root:",
        drive_id=DRIVE_ID,
    )
    return item
|
||||
|
||||
|
||||
def _make_document(item: DriveItemData) -> Document:
    """Wrap *item* in a SharePoint Document with one fixed-text section."""
    section = TextSection(link=item.web_url, text="content")
    document = Document(
        id=item.id,
        source=DocumentSource.SHAREPOINT,
        semantic_identifier=item.name,
        metadata={},
        sections=[section],
    )
    return document
|
||||
|
||||
|
||||
def _consume_generator(
|
||||
gen: Generator[Any, None, SharepointConnectorCheckpoint],
|
||||
) -> tuple[list[Any], SharepointConnectorCheckpoint]:
|
||||
"""Exhaust a _load_from_checkpoint generator.
|
||||
|
||||
Returns (yielded_items, returned_checkpoint).
|
||||
"""
|
||||
yielded: list[Any] = []
|
||||
try:
|
||||
while True:
|
||||
yielded.append(next(gen))
|
||||
except StopIteration as e:
|
||||
return yielded, e.value
|
||||
|
||||
|
||||
def _docs_from(yielded: list[Any]) -> list[Document]:
    """Return, in order, the entries of *yielded* that are Document instances."""
    documents: list[Document] = []
    for entry in yielded:
        if isinstance(entry, Document):
            documents.append(entry)
    return documents
|
||||
|
||||
|
||||
def _failures_from(yielded: list[Any]) -> list[ConnectorFailure]:
    """Return, in order, the entries of *yielded* that are ConnectorFailures."""
    failures: list[ConnectorFailure] = []
    for entry in yielded:
        if isinstance(entry, ConnectorFailure):
            failures.append(entry)
    return failures
|
||||
|
||||
|
||||
def _build_ready_checkpoint(
    drive_names: list[str] | None = None,
    folder_path: str | None = None,
) -> SharepointConnectorCheckpoint:
    """Build a checkpoint with one current site and a queue of drive names.

    Args:
        drive_names: Drive names to queue; a falsy value (``None`` or an empty
            list) falls back to ``["Documents"]``.
        folder_path: Stored on the current site descriptor.

    Returns:
        A checkpoint with ``has_more=True``, no further cached sites and
        site-page processing switched off.
    """
    queued_drives = drive_names or ["Documents"]
    descriptor = SiteDescriptor(
        url=SITE_URL,
        drive_name=None,
        folder_path=folder_path,
    )

    checkpoint = SharepointConnectorCheckpoint(has_more=True)
    checkpoint.cached_site_descriptors = deque()
    checkpoint.current_site_descriptor = descriptor
    checkpoint.cached_drive_names = deque(queued_drives)
    checkpoint.process_site_pages = False
    return checkpoint
|
||||
|
||||
|
||||
def _setup_connector(monkeypatch: pytest.MonkeyPatch) -> SharepointConnector:
    """Create a SharepointConnector with drive resolution and token lookup stubbed.

    ``_resolve_drive`` and ``_get_graph_access_token`` are patched on the class,
    so the stubs apply to every connector instance until monkeypatch undoes them.
    """

    def _stub_resolve_drive(
        self: SharepointConnector,  # noqa: ARG001
        site_descriptor: SiteDescriptor,  # noqa: ARG001
        drive_name: str,  # noqa: ARG001
    ) -> tuple[str, str | None]:
        return (DRIVE_ID, DRIVE_WEB_URL)

    def _stub_get_access_token(self: SharepointConnector) -> str:  # noqa: ARG001
        return "fake-access-token"

    connector = SharepointConnector()
    # A placeholder object stands in for the graph client.
    connector._graph_client = object()
    connector.include_site_pages = False

    class_level_stubs = (
        ("_resolve_drive", _stub_resolve_drive),
        ("_get_graph_access_token", _stub_get_access_token),
    )
    for method_name, stub in class_level_stubs:
        monkeypatch.setattr(SharepointConnector, method_name, stub)

    return connector
|
||||
|
||||
|
||||
def _mock_convert(monkeypatch: pytest.MonkeyPatch) -> None:
    """Patch the connector module's drive-item converter with a simple stub.

    The stub ignores everything except ``driveitem`` and returns
    ``_make_document(driveitem)``.
    """

    def _stub_convert(
        driveitem: DriveItemData,
        drive_name: str,  # noqa: ARG001
        ctx: Any = None,  # noqa: ARG001
        graph_client: Any = None,  # noqa: ARG001
        graph_api_base: str = "",  # noqa: ARG001
        include_permissions: bool = False,  # noqa: ARG001
        parent_hierarchy_raw_node_id: str | None = None,  # noqa: ARG001
        access_token: str | None = None,  # noqa: ARG001
    ) -> Document:
        return _make_document(driveitem)

    patch_target = (
        "onyx.connectors.sharepoint.connector"
        "._convert_driveitem_to_document_with_permissions"
    )
    monkeypatch.setattr(patch_target, _stub_convert)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDeltaPerPageCheckpointing:
    """Expectations for delta (non-folder-scoped) drives: every
    _load_from_checkpoint call handles a single API page, and the returned
    checkpoint carries the next-link between calls."""

    def test_processes_one_page_per_cycle(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Three stubbed pages are drained over three separate cycles."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        next2 = "https://graph.microsoft.com/next2"
        next3 = "https://graph.microsoft.com/next3"
        # (items, next_link) in the order the stub serves them.
        pages: list[tuple[list[DriveItemData], str | None]] = [
            ([_make_item("a"), _make_item("b")], next2),
            ([_make_item("c")], next3),
            ([_make_item("d"), _make_item("e")], None),
        ]
        fetches = 0

        def fake_fetch_page(
            self: SharepointConnector,  # noqa: ARG001
            page_url: str,  # noqa: ARG001
            drive_id: str,  # noqa: ARG001
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,  # noqa: ARG001
        ) -> tuple[list[DriveItemData], str | None]:
            nonlocal fetches
            fetches += 1
            # 1st call -> page 1, 2nd -> page 2, anything later -> page 3.
            return pages[min(fetches, len(pages)) - 1]

        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        def run_cycle(
            cp: SharepointConnectorCheckpoint,
        ) -> tuple[list[Any], SharepointConnectorCheckpoint]:
            return _consume_generator(
                connector._load_from_checkpoint(
                    _START_TS, _END_TS, cp, include_permissions=False
                )
            )

        checkpoint = _build_ready_checkpoint()

        # Cycle 1: drive is initialised and page 1 is processed.
        yielded, checkpoint = run_cycle(checkpoint)
        assert len(_docs_from(yielded)) == 2
        assert checkpoint.current_drive_delta_next_link == next2
        assert checkpoint.current_drive_id == DRIVE_ID
        assert checkpoint.has_more is True

        # Cycle 2: page 2.
        yielded, checkpoint = run_cycle(checkpoint)
        assert len(_docs_from(yielded)) == 1
        assert checkpoint.current_drive_delta_next_link == next3

        # Cycle 3: final page; per-drive state must be cleared afterwards.
        yielded, checkpoint = run_cycle(checkpoint)
        assert len(_docs_from(yielded)) == 2
        assert checkpoint.current_drive_name is None
        assert checkpoint.current_drive_id is None
        assert checkpoint.current_drive_delta_next_link is None

    def test_resume_after_simulated_crash(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """Round-trip the checkpoint through JSON after page 1, build a second
        connector, and check that page 2 is requested via the saved next-link."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        next2 = "https://graph.microsoft.com/next2"
        requested_urls: list[str] = []
        fetches = 0

        def fake_fetch_page(
            self: SharepointConnector,  # noqa: ARG001
            page_url: str,
            drive_id: str,  # noqa: ARG001
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,  # noqa: ARG001
        ) -> tuple[list[DriveItemData], str | None]:
            nonlocal fetches
            fetches += 1
            requested_urls.append(page_url)
            is_first_fetch = fetches == 1
            if not is_first_fetch:
                return [_make_item("b")], None
            return [_make_item("a")], next2

        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        # Page 1 with the first connector.
        first_gen = connector._load_from_checkpoint(
            _START_TS, _END_TS, _build_ready_checkpoint(), include_permissions=False
        )
        _, checkpoint = _consume_generator(first_gen)
        assert checkpoint.current_drive_delta_next_link == next2

        # --- Simulate crash: serialise & deserialise checkpoint ---
        restored = SharepointConnectorCheckpoint.model_validate_json(
            checkpoint.model_dump_json()
        )

        # Second connector instance (as if the process had restarted).
        connector2 = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)
        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        # Resume: the fetch must use the restored next-link.
        resumed_gen = connector2._load_from_checkpoint(
            _START_TS, _END_TS, restored, include_permissions=False
        )
        yielded, final_cp = _consume_generator(resumed_gen)

        docs = _docs_from(yielded)
        assert len(docs) == 1
        assert docs[0].id == "b"
        assert requested_urls[-1] == next2
        assert final_cp.current_drive_name is None
        assert final_cp.current_drive_delta_next_link is None

    def test_single_page_drive_completes_in_one_cycle(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """When the only delta page has no next-link, one call both yields the
        document and leaves the per-drive checkpoint fields cleared."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        def fake_fetch_page(
            self: SharepointConnector,  # noqa: ARG001
            page_url: str,  # noqa: ARG001
            drive_id: str,  # noqa: ARG001
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,  # noqa: ARG001
        ) -> tuple[list[DriveItemData], str | None]:
            only_page = [_make_item("only")]
            return only_page, None

        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        gen = connector._load_from_checkpoint(
            _START_TS, _END_TS, _build_ready_checkpoint(), include_permissions=False
        )
        yielded, final_cp = _consume_generator(gen)

        assert len(_docs_from(yielded)) == 1
        assert final_cp.current_drive_name is None
        assert final_cp.current_drive_id is None
        assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestBfsPathNoCheckpointing:
    """Expectation for folder-scoped (BFS) drives: a single
    _load_from_checkpoint call handles every item, since the BFS queue is not
    something that can be cheaply serialised."""

    def test_bfs_processes_all_at_once(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """Three items from the paged iterator all surface in one cycle."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        items = [_make_item(item_id) for item_id in ("x", "y", "z")]

        def fake_iter_paged(
            self: SharepointConnector,  # noqa: ARG001
            drive_id: str,  # noqa: ARG001
            folder_path: str | None = None,  # noqa: ARG001
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,  # noqa: ARG001
        ) -> Generator[DriveItemData, None, None]:
            for drive_item in items:
                yield drive_item

        monkeypatch.setattr(
            SharepointConnector, "_iter_drive_items_paged", fake_iter_paged
        )

        folder_scoped_cp = _build_ready_checkpoint(folder_path="Engineering/Docs")
        gen = connector._load_from_checkpoint(
            _EPOCH_START, _END_TS, folder_scoped_cp, include_permissions=False
        )
        yielded, final_cp = _consume_generator(gen)

        assert len(_docs_from(yielded)) == 3
        assert final_cp.current_drive_name is None
        assert final_cp.current_drive_id is None
        assert final_cp.current_drive_delta_next_link is None
|
||||
|
||||
|
||||
class TestDelta410GoneResync:
    """Expectation after a 410 Gone: the checkpoint holds a full-resync URL
    and the following cycle re-enumerates from scratch."""

    def test_410_stores_full_resync_url(self, monkeypatch: pytest.MonkeyPatch) -> None:
        """An empty page plus a token-less URL is followed by a normal page."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        fetches = 0

        def fake_fetch_page(
            self: SharepointConnector,  # noqa: ARG001
            page_url: str,  # noqa: ARG001
            drive_id: str,
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,
        ) -> tuple[list[DriveItemData], str | None]:
            nonlocal fetches
            fetches += 1
            if fetches != 1:
                return [_make_item("recovered")], None
            # Simulate the 410 handler returning a full-resync URL
            full_url = (
                f"https://graph.microsoft.com/v1.0/drives/{drive_id}"
                f"/root/delta?$top={page_size}"
            )
            return [], full_url

        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        def run_cycle(
            cp: SharepointConnectorCheckpoint,
        ) -> tuple[list[Any], SharepointConnectorCheckpoint]:
            return _consume_generator(
                connector._load_from_checkpoint(
                    _START_TS, _END_TS, cp, include_permissions=False
                )
            )

        # Cycle 1: drive init, then an empty page that carries the resync URL.
        yielded, checkpoint = run_cycle(_build_ready_checkpoint())
        assert len(_docs_from(yielded)) == 0
        assert checkpoint.current_drive_delta_next_link is not None
        assert "token=" not in checkpoint.current_drive_delta_next_link

        # Cycle 2: the full resync is processed.
        yielded, checkpoint = run_cycle(checkpoint)
        docs = _docs_from(yielded)
        assert len(docs) == 1
        assert docs[0].id == "recovered"
        assert checkpoint.current_drive_name is None
|
||||
|
||||
|
||||
class TestDeltaPageFetchFailure:
    """Expectation when _fetch_one_delta_page raises: the drive is abandoned
    with a ConnectorFailure and the per-drive checkpoint fields are cleared so
    the next drive can start."""

    def test_page_fetch_error_yields_failure_and_clears_state(
        self, monkeypatch: pytest.MonkeyPatch
    ) -> None:
        """A RuntimeError from the page fetch becomes exactly one failure."""
        connector = _setup_connector(monkeypatch)
        _mock_convert(monkeypatch)

        def fake_fetch_page(
            self: SharepointConnector,  # noqa: ARG001
            page_url: str,  # noqa: ARG001
            drive_id: str,  # noqa: ARG001
            start: datetime | None = None,  # noqa: ARG001
            end: datetime | None = None,  # noqa: ARG001
            page_size: int = 200,  # noqa: ARG001
        ) -> tuple[list[DriveItemData], str | None]:
            raise RuntimeError("network blip")

        monkeypatch.setattr(
            SharepointConnector, "_fetch_one_delta_page", fake_fetch_page
        )

        gen = connector._load_from_checkpoint(
            _START_TS, _END_TS, _build_ready_checkpoint(), include_permissions=False
        )
        yielded, final_cp = _consume_generator(gen)

        failures = _failures_from(yielded)
        assert len(failures) == 1
        (failure,) = failures
        assert "network blip" in failure.failure_message
        assert final_cp.current_drive_name is None
        assert final_cp.current_drive_id is None
        assert final_cp.current_drive_delta_next_link is None
|
||||
@@ -192,14 +192,15 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
|
||||
"https://example.sharepoint.com/sites/sample/Documents",
|
||||
)
|
||||
|
||||
def fake_get_drive_items(
|
||||
def fake_fetch_one_delta_page(
|
||||
self: SharepointConnector, # noqa: ARG001
|
||||
site_descriptor: SiteDescriptor, # noqa: ARG001
|
||||
page_url: str, # noqa: ARG001
|
||||
drive_id: str, # noqa: ARG001
|
||||
start: datetime | None, # noqa: ARG001
|
||||
end: datetime | None, # noqa: ARG001
|
||||
) -> Generator[DriveItemData, None, None]:
|
||||
yield sample_item
|
||||
start: datetime | None = None, # noqa: ARG001
|
||||
end: datetime | None = None, # noqa: ARG001
|
||||
page_size: int = 200, # noqa: ARG001
|
||||
) -> tuple[list[DriveItemData], str | None]:
|
||||
return [sample_item], None
|
||||
|
||||
def fake_convert(
|
||||
driveitem: DriveItemData, # noqa: ARG001
|
||||
@@ -230,8 +231,8 @@ def test_load_from_checkpoint_maps_drive_name(monkeypatch: pytest.MonkeyPatch) -
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
SharepointConnector,
|
||||
"_get_drive_items_for_drive_id",
|
||||
fake_get_drive_items,
|
||||
"_fetch_one_delta_page",
|
||||
fake_fetch_one_delta_page,
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"onyx.connectors.sharepoint.connector._convert_driveitem_to_document_with_permissions",
|
||||
|
||||
@@ -97,6 +97,7 @@ class TestScimDALUserMappings:
|
||||
assert model_attrs(added_obj) == {
|
||||
"external_id": "ext-1",
|
||||
"user_id": user_id,
|
||||
"scim_username": None,
|
||||
}
|
||||
|
||||
def test_delete_user_mapping(
|
||||
|
||||
@@ -15,7 +15,10 @@ from sqlalchemy.orm import Session
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
from onyx.db.models import ScimToken
|
||||
from onyx.db.models import ScimUserMapping
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import UserGroup
|
||||
from onyx.db.models import UserRole
|
||||
@@ -35,6 +38,12 @@ def mock_token() -> MagicMock:
|
||||
return token
|
||||
|
||||
|
||||
@pytest.fixture
def provider() -> ScimProvider:
    """Supply a freshly constructed OktaProvider to endpoint tests."""
    okta_provider = OktaProvider()
    return okta_provider
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dal() -> Generator[MagicMock, None, None]:
|
||||
"""Patch ScimDAL construction in api module and yield the mock instance."""
|
||||
@@ -53,6 +62,9 @@ def mock_dal() -> Generator[MagicMock, None, None]:
|
||||
dal.get_group_mapping_by_external_id.return_value = None
|
||||
dal.get_group_members.return_value = []
|
||||
dal.list_groups.return_value = ([], 0)
|
||||
# User-group relationship defaults
|
||||
dal.get_user_groups.return_value = []
|
||||
dal.get_users_groups_batch.return_value = {}
|
||||
yield dal
|
||||
|
||||
|
||||
@@ -96,6 +108,16 @@ def make_db_group(**kwargs: Any) -> MagicMock:
|
||||
return group
|
||||
|
||||
|
||||
def make_user_mapping(**kwargs: Any) -> MagicMock:
    """Create a ScimUserMapping-spec'd MagicMock, overridable via keyword args.

    Recognised keys and their fallbacks: ``id`` (1), ``external_id``
    ("ext-default"), ``user_id`` (a new uuid4) and ``scim_username`` (None).
    """
    fallbacks: dict[str, Any] = {
        "id": 1,
        "external_id": "ext-default",
        "user_id": uuid4(),
        "scim_username": None,
    }
    mapping = MagicMock(spec=ScimUserMapping)
    for field_name, fallback in fallbacks.items():
        setattr(mapping, field_name, kwargs.get(field_name, fallback))
    return mapping
|
||||
|
||||
|
||||
def assert_scim_error(result: object, expected_status: int) -> None:
|
||||
"""Assert *result* is a JSONResponse with the given status code."""
|
||||
assert isinstance(result, JSONResponse)
|
||||
|
||||
@@ -21,6 +21,7 @@ from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_group
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_group
|
||||
@@ -34,6 +35,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_groups.return_value = ([], 0)
|
||||
|
||||
@@ -42,6 +44,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -54,6 +57,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_groups.side_effect = ValueError(
|
||||
"Unsupported filter attribute: userName"
|
||||
@@ -64,6 +68,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -74,6 +79,7 @@ class TestListGroups:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
uid = uuid4()
|
||||
@@ -85,6 +91,7 @@ class TestListGroups:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -106,6 +113,7 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Engineering")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -114,6 +122,7 @@ class TestGetGroup:
|
||||
result = get_group(
|
||||
group_id="5",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -126,10 +135,12 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = get_group(
|
||||
group_id="not-a-number",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -140,12 +151,14 @@ class TestGetGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
result = get_group(
|
||||
group_id="999",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -162,6 +175,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
@@ -172,6 +186,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -185,6 +200,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = make_db_group()
|
||||
resource = make_scim_group()
|
||||
@@ -192,6 +208,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -204,6 +221,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], "Invalid member ID: bad-uuid")
|
||||
@@ -213,6 +231,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -225,6 +244,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
uid = uuid4()
|
||||
@@ -235,6 +255,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -247,6 +268,7 @@ class TestCreateGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group_by_name.return_value = None
|
||||
mock_validate.return_value = ([], None)
|
||||
@@ -257,6 +279,7 @@ class TestCreateGroup:
|
||||
result = create_group(
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -274,6 +297,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -286,6 +310,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -299,6 +324,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
@@ -306,6 +332,7 @@ class TestReplaceGroup:
|
||||
group_id="999",
|
||||
group_resource=make_scim_group(),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -318,6 +345,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -329,6 +357,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -341,6 +370,7 @@ class TestReplaceGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -353,6 +383,7 @@ class TestReplaceGroup:
|
||||
group_id="5",
|
||||
group_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -369,6 +400,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5, name="Old Name")
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -391,6 +423,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -402,6 +435,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_group.return_value = None
|
||||
|
||||
@@ -419,6 +453,7 @@ class TestPatchGroup:
|
||||
group_id="999",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -431,6 +466,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -452,6 +488,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -464,6 +501,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -483,7 +521,7 @@ class TestPatchGroup:
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": uid}],
|
||||
value=[ScimGroupMember(value=uid)],
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -492,6 +530,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -506,6 +545,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -525,7 +565,7 @@ class TestPatchGroup:
|
||||
ScimPatchOperation(
|
||||
op=ScimPatchOperationType.ADD,
|
||||
path="members",
|
||||
value=[{"value": str(uid)}],
|
||||
value=[ScimGroupMember(value=str(uid))],
|
||||
)
|
||||
]
|
||||
)
|
||||
@@ -534,6 +574,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -546,6 +587,7 @@ class TestPatchGroup:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
group = make_db_group(id=5)
|
||||
mock_dal.get_group.return_value = group
|
||||
@@ -568,6 +610,7 @@ class TestPatchGroup:
|
||||
group_id="5",
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -2,13 +2,19 @@ import pytest
|
||||
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimPatchOperation
|
||||
from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchResourceValue
|
||||
from ee.onyx.server.scim.models import ScimPatchValue
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import apply_group_patch
|
||||
from ee.onyx.server.scim.patch import apply_user_patch
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
_OKTA_IGNORED = OktaProvider().ignored_patch_paths
|
||||
|
||||
|
||||
def _make_user(**kwargs: object) -> ScimUserResource:
|
||||
@@ -29,14 +35,14 @@ def _make_group(**kwargs: object) -> ScimGroupResource:
|
||||
|
||||
def _replace_op(
|
||||
path: str | None = None,
|
||||
value: str | bool | dict | list | None = None,
|
||||
value: ScimPatchValue = None,
|
||||
) -> ScimPatchOperation:
|
||||
return ScimPatchOperation(op=ScimPatchOperationType.REPLACE, path=path, value=value)
|
||||
|
||||
|
||||
def _add_op(
|
||||
path: str | None = None,
|
||||
value: str | bool | dict | list | None = None,
|
||||
value: ScimPatchValue = None,
|
||||
) -> ScimPatchOperation:
|
||||
return ScimPatchOperation(op=ScimPatchOperationType.ADD, path=path, value=value)
|
||||
|
||||
@@ -80,7 +86,12 @@ class TestApplyUserPatch:
|
||||
def test_replace_without_path_uses_dict(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, {"active": False, "userName": "new@example.com"})],
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(active=False, userName="new@example.com"),
|
||||
)
|
||||
],
|
||||
user,
|
||||
)
|
||||
assert result.active is False
|
||||
@@ -119,6 +130,86 @@ class TestApplyUserPatch:
|
||||
with pytest.raises(ScimPatchError, match="Unsupported operation"):
|
||||
apply_user_patch([_remove_op("active")], user)
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Okta sends 'id' alongside actual changes — it should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op(None, ScimPatchResourceValue(active=False, id="some-uuid"))],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
"""The 'schemas' key in a value dict should be silently ignored."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
active=False,
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
|
||||
def test_okta_deactivation_payload(self) -> None:
|
||||
"""Exact Okta deactivation payload: path-less replace with id + active."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(id="abc-123", active=False),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_replace_displayname(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[_replace_op("displayName", "New Display Name")], user
|
||||
)
|
||||
assert result.displayName == "New Display Name"
|
||||
assert result.name is not None
|
||||
assert result.name.formatted == "New Display Name"
|
||||
|
||||
def test_replace_without_path_complex_value_dict(self) -> None:
|
||||
"""Okta sends id/schemas/meta alongside actual changes — complex types
|
||||
(lists, nested dicts) must not cause Pydantic validation errors."""
|
||||
user = _make_user()
|
||||
result = apply_user_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
active=False,
|
||||
id="some-uuid",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:User"],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
),
|
||||
)
|
||||
],
|
||||
user,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.active is False
|
||||
assert result.userName == "test@example.com"
|
||||
|
||||
def test_add_operation_works_like_replace(self) -> None:
|
||||
user = _make_user()
|
||||
result = apply_user_patch([_add_op("externalId", "ext-456")], user)
|
||||
assert result.externalId == "ext-456"
|
||||
|
||||
|
||||
class TestApplyGroupPatch:
|
||||
"""Tests for SCIM group PATCH operations."""
|
||||
@@ -135,7 +226,12 @@ class TestApplyGroupPatch:
|
||||
def test_add_members(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, removed = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
[
|
||||
_add_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -145,7 +241,7 @@ class TestApplyGroupPatch:
|
||||
def test_add_members_without_path(self) -> None:
|
||||
group = _make_group()
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op(None, [{"value": "user-1"}])],
|
||||
[_add_op(None, [ScimGroupMember(value="user-1")])],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 1
|
||||
@@ -154,7 +250,12 @@ class TestApplyGroupPatch:
|
||||
def test_add_duplicate_member_skipped(self) -> None:
|
||||
group = _make_group(members=[ScimGroupMember(value="user-1")])
|
||||
result, added, _ = apply_group_patch(
|
||||
[_add_op("members", [{"value": "user-1"}, {"value": "user-2"}])],
|
||||
[
|
||||
_add_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-1"), ScimGroupMember(value="user-2")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -190,7 +291,7 @@ class TestApplyGroupPatch:
|
||||
result, added, removed = apply_group_patch(
|
||||
[
|
||||
_replace_op("displayName", "Renamed"),
|
||||
_add_op("members", [{"value": "user-2"}]),
|
||||
_add_op("members", [ScimGroupMember(value="user-2")]),
|
||||
_remove_op('members[value eq "user-1"]'),
|
||||
],
|
||||
group,
|
||||
@@ -221,7 +322,12 @@ class TestApplyGroupPatch:
|
||||
]
|
||||
)
|
||||
result, added, removed = apply_group_patch(
|
||||
[_replace_op("members", [{"value": "user-2"}, {"value": "user-3"}])],
|
||||
[
|
||||
_replace_op(
|
||||
"members",
|
||||
[ScimGroupMember(value="user-2"), ScimGroupMember(value="user-3")],
|
||||
)
|
||||
],
|
||||
group,
|
||||
)
|
||||
assert len(result.members) == 2
|
||||
@@ -256,3 +362,55 @@ class TestApplyGroupPatch:
|
||||
group = _make_group()
|
||||
apply_group_patch([_replace_op("displayName", "Changed")], group)
|
||||
assert group.displayName == "Engineering"
|
||||
|
||||
def test_replace_without_path_ignores_id(self) -> None:
|
||||
"""Group replace with 'id' in value dict should be silently ignored."""
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None, ScimPatchResourceValue(displayName="Updated", id="some-id")
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
def test_replace_without_path_ignores_schemas(self) -> None:
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
displayName="Updated",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
),
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
def test_replace_without_path_complex_value_dict(self) -> None:
|
||||
"""Group PATCH with complex types in value dict (lists, nested dicts)
|
||||
must not cause Pydantic validation errors."""
|
||||
group = _make_group()
|
||||
result, _, _ = apply_group_patch(
|
||||
[
|
||||
_replace_op(
|
||||
None,
|
||||
ScimPatchResourceValue(
|
||||
displayName="Updated",
|
||||
id="123",
|
||||
schemas=["urn:ietf:params:scim:schemas:core:2.0:Group"],
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
),
|
||||
)
|
||||
],
|
||||
group,
|
||||
ignored_paths=_OKTA_IGNORED,
|
||||
)
|
||||
assert result.displayName == "Updated"
|
||||
|
||||
167
backend/tests/unit/onyx/server/scim/test_providers.py
Normal file
167
backend/tests/unit/onyx/server/scim/test_providers.py
Normal file
@@ -0,0 +1,167 @@
|
||||
from unittest.mock import MagicMock
|
||||
from uuid import UUID
|
||||
from uuid import uuid4
|
||||
|
||||
from ee.onyx.server.scim.models import ScimEmail
|
||||
from ee.onyx.server.scim.models import ScimGroupMember
|
||||
from ee.onyx.server.scim.models import ScimGroupResource
|
||||
from ee.onyx.server.scim.models import ScimMeta
|
||||
from ee.onyx.server.scim.models import ScimName
|
||||
from ee.onyx.server.scim.models import ScimUserGroupRef
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.providers.base import get_default_provider
|
||||
from ee.onyx.server.scim.providers.okta import OktaProvider
|
||||
|
||||
|
||||
def _make_mock_user(
|
||||
user_id: UUID | None = None,
|
||||
email: str = "test@example.com",
|
||||
personal_name: str | None = "Test User",
|
||||
is_active: bool = True,
|
||||
) -> MagicMock:
|
||||
user = MagicMock()
|
||||
user.id = user_id or uuid4()
|
||||
user.email = email
|
||||
user.personal_name = personal_name
|
||||
user.is_active = is_active
|
||||
return user
|
||||
|
||||
|
||||
def _make_mock_group(group_id: int = 42, name: str = "Engineering") -> MagicMock:
|
||||
group = MagicMock()
|
||||
group.id = group_id
|
||||
group.name = name
|
||||
return group
|
||||
|
||||
|
||||
class TestOktaProvider:
|
||||
def test_name(self) -> None:
|
||||
assert OktaProvider().name == "okta"
|
||||
|
||||
def test_ignored_patch_paths(self) -> None:
|
||||
assert OktaProvider().ignored_patch_paths == frozenset(
|
||||
{"id", "schemas", "meta"}
|
||||
)
|
||||
|
||||
def test_build_user_resource_basic(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
|
||||
assert result == ScimUserResource(
|
||||
id=str(user.id),
|
||||
externalId="ext-123",
|
||||
userName="test@example.com",
|
||||
name=ScimName(givenName="Test", familyName="User", formatted="Test User"),
|
||||
displayName="Test User",
|
||||
emails=[ScimEmail(value="test@example.com", type="work", primary=True)],
|
||||
active=True,
|
||||
groups=[],
|
||||
meta=ScimMeta(resourceType="User"),
|
||||
)
|
||||
|
||||
def test_build_user_resource_with_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
groups = [(1, "Engineering"), (2, "Design")]
|
||||
result = provider.build_user_resource(user, "ext-123", groups=groups)
|
||||
|
||||
assert result.groups == [
|
||||
ScimUserGroupRef(value="1", display="Engineering"),
|
||||
ScimUserGroupRef(value="2", display="Design"),
|
||||
]
|
||||
|
||||
def test_build_user_resource_empty_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123", groups=[])
|
||||
|
||||
assert result.groups == []
|
||||
|
||||
def test_build_user_resource_no_groups(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user()
|
||||
result = provider.build_user_resource(user, "ext-123")
|
||||
|
||||
assert result.groups == []
|
||||
|
||||
def test_build_user_resource_name_parsing(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name="Jane Doe")
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Jane", familyName="Doe", formatted="Jane Doe"
|
||||
)
|
||||
|
||||
def test_build_user_resource_single_name(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name="Madonna")
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name == ScimName(
|
||||
givenName="Madonna", familyName=None, formatted="Madonna"
|
||||
)
|
||||
|
||||
def test_build_user_resource_no_name(self) -> None:
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(personal_name=None)
|
||||
result = provider.build_user_resource(user, None)
|
||||
|
||||
assert result.name is None
|
||||
assert result.displayName is None
|
||||
|
||||
def test_build_user_resource_scim_username_preserves_case(self) -> None:
|
||||
"""When scim_username is set, userName and emails use original case."""
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(email="alice@example.com")
|
||||
result = provider.build_user_resource(
|
||||
user, "ext-1", scim_username="Alice@Example.com"
|
||||
)
|
||||
|
||||
assert result.userName == "Alice@Example.com"
|
||||
assert result.emails[0].value == "Alice@Example.com"
|
||||
|
||||
def test_build_user_resource_scim_username_none_falls_back(self) -> None:
|
||||
"""When scim_username is None, userName falls back to user.email."""
|
||||
provider = OktaProvider()
|
||||
user = _make_mock_user(email="alice@example.com")
|
||||
result = provider.build_user_resource(user, "ext-1", scim_username=None)
|
||||
|
||||
assert result.userName == "alice@example.com"
|
||||
assert result.emails[0].value == "alice@example.com"
|
||||
|
||||
def test_build_group_resource(self) -> None:
|
||||
provider = OktaProvider()
|
||||
group = _make_mock_group()
|
||||
uid1, uid2 = uuid4(), uuid4()
|
||||
members: list[tuple[UUID, str | None]] = [
|
||||
(uid1, "alice@example.com"),
|
||||
(uid2, "bob@example.com"),
|
||||
]
|
||||
|
||||
result = provider.build_group_resource(group, members, "ext-g-1")
|
||||
|
||||
assert result == ScimGroupResource(
|
||||
id="42",
|
||||
externalId="ext-g-1",
|
||||
displayName="Engineering",
|
||||
members=[
|
||||
ScimGroupMember(value=str(uid1), display="alice@example.com"),
|
||||
ScimGroupMember(value=str(uid2), display="bob@example.com"),
|
||||
],
|
||||
meta=ScimMeta(resourceType="Group"),
|
||||
)
|
||||
|
||||
def test_build_group_resource_empty_members(self) -> None:
|
||||
provider = OktaProvider()
|
||||
group = _make_mock_group()
|
||||
result = provider.build_group_resource(group, [])
|
||||
|
||||
assert result.members == []
|
||||
|
||||
|
||||
class TestGetDefaultProvider:
|
||||
def test_returns_okta(self) -> None:
|
||||
provider = get_default_provider()
|
||||
assert isinstance(provider, OktaProvider)
|
||||
@@ -9,6 +9,7 @@ from uuid import uuid4
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.server.scim.api import _scim_name_to_str
|
||||
from ee.onyx.server.scim.api import create_user
|
||||
from ee.onyx.server.scim.api import delete_user
|
||||
from ee.onyx.server.scim.api import get_user
|
||||
@@ -22,9 +23,11 @@ from ee.onyx.server.scim.models import ScimPatchOperationType
|
||||
from ee.onyx.server.scim.models import ScimPatchRequest
|
||||
from ee.onyx.server.scim.models import ScimUserResource
|
||||
from ee.onyx.server.scim.patch import ScimPatchError
|
||||
from ee.onyx.server.scim.providers.base import ScimProvider
|
||||
from tests.unit.onyx.server.scim.conftest import assert_scim_error
|
||||
from tests.unit.onyx.server.scim.conftest import make_db_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_scim_user
|
||||
from tests.unit.onyx.server.scim.conftest import make_user_mapping
|
||||
|
||||
|
||||
class TestListUsers:
|
||||
@@ -35,6 +38,7 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_users.return_value = ([], 0)
|
||||
|
||||
@@ -43,6 +47,7 @@ class TestListUsers:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -55,15 +60,20 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com", personal_name="Alice Smith")
|
||||
mock_dal.list_users.return_value = ([(user, "ext-abc")], 1)
|
||||
mapping = make_user_mapping(
|
||||
external_id="ext-abc", user_id=user.id, scim_username="Alice@example.com"
|
||||
)
|
||||
mock_dal.list_users.return_value = ([(user, mapping)], 1)
|
||||
|
||||
result = list_users(
|
||||
filter=None,
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -72,7 +82,7 @@ class TestListUsers:
|
||||
assert len(result.Resources) == 1
|
||||
resource = result.Resources[0]
|
||||
assert isinstance(resource, ScimUserResource)
|
||||
assert resource.userName == "alice@example.com"
|
||||
assert resource.userName == "Alice@example.com"
|
||||
assert resource.externalId == "ext-abc"
|
||||
|
||||
def test_unsupported_filter_attribute_returns_400(
|
||||
@@ -80,6 +90,7 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.list_users.side_effect = ValueError(
|
||||
"Unsupported filter attribute: emails"
|
||||
@@ -90,6 +101,7 @@ class TestListUsers:
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -100,12 +112,14 @@ class TestListUsers:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = list_users(
|
||||
filter="not a valid filter",
|
||||
startIndex=1,
|
||||
count=100,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -120,6 +134,7 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="alice@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -127,6 +142,7 @@ class TestGetUser:
|
||||
result = get_user(
|
||||
user_id=str(user.id),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -139,10 +155,12 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
result = get_user(
|
||||
user_id="not-a-uuid",
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -153,12 +171,14 @@ class TestGetUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
result = get_user(
|
||||
user_id=str(uuid4()),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -175,6 +195,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(userName="new@example.com")
|
||||
@@ -182,6 +203,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -195,12 +217,14 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
resource = make_scim_user(externalId=None)
|
||||
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -213,6 +237,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = make_db_user()
|
||||
resource = make_scim_user()
|
||||
@@ -220,6 +245,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -232,6 +258,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
mock_dal.add_user.side_effect = IntegrityError("dup", {}, Exception())
|
||||
@@ -240,6 +267,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -253,6 +281,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock, # noqa: ARG002
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_seats.return_value = "Seat limit reached"
|
||||
resource = make_scim_user()
|
||||
@@ -260,6 +289,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -272,6 +302,7 @@ class TestCreateUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user_by_email.return_value = None
|
||||
resource = make_scim_user(externalId="ext-123")
|
||||
@@ -279,6 +310,7 @@ class TestCreateUser:
|
||||
result = create_user(
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -295,6 +327,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(email="old@example.com")
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -307,6 +340,7 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -319,6 +353,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
|
||||
@@ -326,6 +361,7 @@ class TestReplaceUser:
|
||||
user_id=str(uuid4()),
|
||||
user_resource=make_scim_user(),
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -338,6 +374,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=False)
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -348,6 +385,7 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -359,6 +397,7 @@ class TestReplaceUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -369,11 +408,14 @@ class TestReplaceUser:
|
||||
user_id=str(user.id),
|
||||
user_resource=resource,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert isinstance(result, ScimUserResource)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(user.id, None)
|
||||
mock_dal.sync_user_external_id.assert_called_once_with(
|
||||
user.id, None, scim_username="test@example.com"
|
||||
)
|
||||
|
||||
|
||||
class TestPatchUser:
|
||||
@@ -384,6 +426,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user(is_active=True)
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -401,6 +444,7 @@ class TestPatchUser:
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -412,6 +456,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
mock_dal.get_user.return_value = None
|
||||
patch_req = ScimPatchRequest(
|
||||
@@ -428,11 +473,45 @@ class TestPatchUser:
|
||||
user_id=str(uuid4()),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
def test_patch_displayname_persists(
    self,
    mock_db_session: MagicMock,
    mock_token: MagicMock,
    mock_dal: MagicMock,
    provider: ScimProvider,
) -> None:
    """A PATCH replacing ``displayName`` must reach the DAL as ``personal_name``."""
    new_display_name = "New Display Name"
    existing_user = make_db_user(personal_name="Old Name")
    mock_dal.get_user.return_value = existing_user

    # Single REPLACE operation targeting the displayName attribute.
    replace_display_name = ScimPatchOperation(
        op=ScimPatchOperationType.REPLACE,
        path="displayName",
        value=new_display_name,
    )

    response = patch_user(
        user_id=str(existing_user.id),
        patch_request=ScimPatchRequest(Operations=[replace_display_name]),
        _token=mock_token,
        provider=provider,
        db_session=mock_db_session,
    )

    assert isinstance(response, ScimUserResource)
    # Index 1 of call_args holds the keyword arguments of the update_user call.
    update_kwargs = mock_dal.update_user.call_args[1]
    assert update_kwargs["personal_name"] == new_display_name
|
||||
|
||||
@patch("ee.onyx.server.scim.api.apply_user_patch")
|
||||
def test_patch_error_returns_error_response(
|
||||
self,
|
||||
@@ -440,6 +519,7 @@ class TestPatchUser:
|
||||
mock_db_session: MagicMock,
|
||||
mock_token: MagicMock,
|
||||
mock_dal: MagicMock,
|
||||
provider: ScimProvider,
|
||||
) -> None:
|
||||
user = make_db_user()
|
||||
mock_dal.get_user.return_value = user
|
||||
@@ -457,6 +537,7 @@ class TestPatchUser:
|
||||
user_id=str(user.id),
|
||||
patch_request=patch_req,
|
||||
_token=mock_token,
|
||||
provider=provider,
|
||||
db_session=mock_db_session,
|
||||
)
|
||||
|
||||
@@ -519,3 +600,87 @@ class TestDeleteUser:
|
||||
)
|
||||
|
||||
assert_scim_error(result, 404)
|
||||
|
||||
|
||||
class TestScimNameToStr:
    """Unit tests for how ``_scim_name_to_str`` flattens a ``ScimName``."""

    def test_prefers_given_family_over_formatted(self) -> None:
        """Structured name parts win over ``formatted``.

        Okta may send a stale ``formatted`` value while updating
        givenName/familyName.
        """
        stale_formatted = ScimName(
            givenName="Jane", familyName="Smith", formatted="Old Name"
        )
        assert _scim_name_to_str(stale_formatted) == "Jane Smith"

    def test_given_name_only(self) -> None:
        """A lone givenName is returned unchanged."""
        assert _scim_name_to_str(ScimName(givenName="Jane")) == "Jane"

    def test_family_name_only(self) -> None:
        """A lone familyName is returned unchanged."""
        assert _scim_name_to_str(ScimName(familyName="Smith")) == "Smith"

    def test_falls_back_to_formatted(self) -> None:
        """With no structured parts, ``formatted`` is used."""
        assert _scim_name_to_str(ScimName(formatted="Display Name")) == "Display Name"

    def test_none_returns_none(self) -> None:
        """``None`` in, ``None`` out."""
        assert _scim_name_to_str(None) is None

    def test_empty_name_returns_none(self) -> None:
        """A ``ScimName`` with no fields set yields ``None``."""
        assert _scim_name_to_str(ScimName()) is None
|
||||
|
||||
|
||||
class TestEmailCasePreservation:
    """SCIM endpoints must echo ``userName`` in the casing that was supplied."""

    @patch("ee.onyx.server.scim.api._check_seat_availability", return_value=None)
    def test_create_preserves_username_case(
        self,
        mock_seats: MagicMock,  # noqa: ARG002
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        provider: ScimProvider,
    ) -> None:
        """Creating a user with a mixed-case userName returns that exact casing."""
        mixed_case_email = "Alice@Example.COM"
        # No existing account with this email.
        mock_dal.get_user_by_email.return_value = None

        created = create_user(
            user_resource=make_scim_user(userName=mixed_case_email),
            _token=mock_token,
            provider=provider,
            db_session=mock_db_session,
        )

        assert isinstance(created, ScimUserResource)
        assert created.userName == mixed_case_email
        assert created.emails[0].value == mixed_case_email

    def test_get_preserves_username_case(
        self,
        mock_db_session: MagicMock,
        mock_token: MagicMock,
        mock_dal: MagicMock,
        provider: ScimProvider,
    ) -> None:
        """Fetching a user returns the casing stored on the SCIM mapping.

        The DB user carries a lower-case email while the mapping carries the
        original mixed-case ``scim_username``; the response must use the latter.
        """
        mixed_case_email = "Alice@Example.COM"
        lowercase_user = make_db_user(email="alice@example.com")
        mock_dal.get_user.return_value = lowercase_user
        mock_dal.get_user_mapping_by_user_id.return_value = make_user_mapping(
            external_id="ext-1",
            user_id=lowercase_user.id,
            scim_username=mixed_case_email,
        )

        fetched = get_user(
            user_id=str(lowercase_user.id),
            _token=mock_token,
            provider=provider,
            db_session=mock_db_session,
        )

        assert isinstance(fetched, ScimUserResource)
        assert fetched.userName == mixed_case_email
        assert fetched.emails[0].value == mixed_case_email
|
||||
|
||||
@@ -1,9 +1,19 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import time
|
||||
from unittest.mock import MagicMock
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
import onyx.tools.tool_implementations.open_url.onyx_web_crawler as crawler_module
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_CONNECT_TIMEOUT_SECONDS,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import (
|
||||
DEFAULT_READ_TIMEOUT_SECONDS,
|
||||
)
|
||||
from onyx.tools.tool_implementations.open_url.onyx_web_crawler import OnyxWebCrawler
|
||||
|
||||
|
||||
@@ -181,3 +191,163 @@ def test_fetch_url_html_within_size_limit(monkeypatch: pytest.MonkeyPatch) -> No
|
||||
|
||||
assert "hello world" in result.full_content
|
||||
assert result.scrape_successful is True
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers for parallel / failure-isolation / timeout tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_mock_response(
|
||||
*,
|
||||
status_code: int = 200,
|
||||
content: bytes = b"<html><body>Hello</body></html>",
|
||||
content_type: str = "text/html",
|
||||
delay: float = 0.0,
|
||||
) -> MagicMock:
|
||||
"""Create a mock response that behaves like a requests.Response."""
|
||||
resp = MagicMock()
|
||||
resp.status_code = status_code
|
||||
resp.headers = {"Content-Type": content_type}
|
||||
|
||||
if delay:
|
||||
original_content = content
|
||||
|
||||
@property # type: ignore[misc]
|
||||
def _delayed_content(_self: object) -> bytes:
|
||||
time.sleep(delay)
|
||||
return original_content
|
||||
|
||||
type(resp).content = _delayed_content
|
||||
else:
|
||||
resp.content = content
|
||||
|
||||
resp.apparent_encoding = None
|
||||
resp.encoding = None
|
||||
|
||||
return resp
|
||||
|
||||
|
||||
class TestParallelExecution:
    """Check that ``contents()`` does not fetch its URLs one after another."""

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_multiple_urls_fetched_concurrently(self, mock_get: MagicMock) -> None:
        """Five slow responses must finish well under the sequential total."""
        per_url_delay = 0.3
        num_urls = 5
        mock_get.return_value = _make_mock_response(delay=per_url_delay)
        page_urls = [f"http://example.com/page{i}" for i in range(num_urls)]

        web_crawler = OnyxWebCrawler()
        started_at = time.monotonic()
        fetched = web_crawler.contents(page_urls)
        elapsed = time.monotonic() - started_at

        # Sequential would take ~1.5s; allow at most 70% of that.
        time_budget = per_url_delay * num_urls * 0.7
        assert elapsed < time_budget
        assert len(fetched) == num_urls
        assert all(item.scrape_successful for item in fetched)

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_empty_urls_returns_empty(self, mock_get: MagicMock) -> None:
        """An empty URL list yields an empty result and no fetch call."""
        fetched = OnyxWebCrawler().contents([])

        assert fetched == []
        mock_get.assert_not_called()

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_single_url(self, mock_get: MagicMock) -> None:
        """A single URL produces exactly one successful result."""
        mock_get.return_value = _make_mock_response()

        fetched = OnyxWebCrawler().contents(["http://example.com"])

        assert len(fetched) == 1
        assert fetched[0].scrape_successful
|
||||
|
||||
|
||||
class TestFailureIsolation:
    """A failing URL must not prevent the other URLs in the batch from succeeding."""

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_one_failure_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
        """An HTTP 500 in the middle only fails that one entry."""
        ok_response = _make_mock_response()
        server_error = _make_mock_response(status_code=500)
        # First and third URLs succeed, second fails
        mock_get.side_effect = [ok_response, server_error, ok_response]

        outcomes = OnyxWebCrawler().contents(
            ["http://a.com", "http://b.com", "http://c.com"]
        )

        assert len(outcomes) == 3
        assert [bool(o.scrape_successful) for o in outcomes] == [True, False, True]

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_exception_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
        """An exception raised for the second URL only fails that one entry."""
        first_ok = _make_mock_response()
        # Second URL raises an exception
        mock_get.side_effect = [
            first_ok,
            RuntimeError("connection reset"),
            _make_mock_response(),
        ]

        outcomes = OnyxWebCrawler().contents(
            ["http://a.com", "http://b.com", "http://c.com"]
        )

        assert len(outcomes) == 3
        assert [bool(o.scrape_successful) for o in outcomes] == [True, False, True]

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_ssrf_exception_doesnt_kill_batch(self, mock_get: MagicMock) -> None:
        """An ``SSRFException`` for the second URL only fails that one entry."""
        from onyx.utils.url import SSRFException

        first_ok = _make_mock_response()
        mock_get.side_effect = [
            first_ok,
            SSRFException("blocked"),
            _make_mock_response(),
        ]

        outcomes = OnyxWebCrawler().contents(
            ["http://a.com", "http://internal.local", "http://c.com"]
        )

        assert len(outcomes) == 3
        assert [bool(o.scrape_successful) for o in outcomes] == [True, False, True]
|
||||
|
||||
|
||||
class TestTupleTimeout:
    """Check the ``timeout`` keyword that ``contents()`` passes to ``ssrf_safe_get``."""

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_default_tuple_timeout(self, mock_get: MagicMock) -> None:
        """Without overrides, the default (connect, read) pair is forwarded."""
        mock_get.return_value = _make_mock_response()

        OnyxWebCrawler().contents(["http://example.com"])

        expected_timeout = (
            DEFAULT_CONNECT_TIMEOUT_SECONDS,
            DEFAULT_READ_TIMEOUT_SECONDS,
        )
        assert mock_get.call_args.kwargs["timeout"] == expected_timeout

    @patch("onyx.tools.tool_implementations.open_url.onyx_web_crawler.ssrf_safe_get")
    def test_custom_tuple_timeout(self, mock_get: MagicMock) -> None:
        """Constructor overrides are forwarded as (connect, read)."""
        mock_get.return_value = _make_mock_response()

        custom_crawler = OnyxWebCrawler(timeout_seconds=30, connect_timeout_seconds=3)
        custom_crawler.contents(["http://example.com"])

        assert mock_get.call_args.kwargs["timeout"] == (3, 30)
|
||||
|
||||
@@ -291,7 +291,7 @@ class TestSsrfSafeGet:
|
||||
assert call_args[1]["headers"]["User-Agent"] == "TestBot/1.0"
|
||||
|
||||
def test_passes_timeout(self) -> None:
|
||||
"""Test that timeout is passed through."""
|
||||
"""Test that timeout is passed through, including tuple form."""
|
||||
mock_response = MagicMock()
|
||||
mock_response.is_redirect = False
|
||||
|
||||
@@ -301,7 +301,7 @@ class TestSsrfSafeGet:
|
||||
with patch("onyx.utils.url.requests.get") as mock_get:
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
ssrf_safe_get("http://example.com/", timeout=30)
|
||||
ssrf_safe_get("http://example.com/", timeout=(5, 15))
|
||||
|
||||
call_args = mock_get.call_args
|
||||
assert call_args[1]["timeout"] == 30
|
||||
assert call_args[1]["timeout"] == (5, 15)
|
||||
|
||||
@@ -147,7 +147,7 @@ Add clear comments:
|
||||
|
||||
## Trunk-based development and feature flags
|
||||
|
||||
- **PRs should contain no more than 500 lines of real change**
|
||||
- **PRs should contain no more than 500 lines of real change.**
|
||||
- **Merge to main frequently.** Avoid long-lived feature branches—they create merge conflicts and integration pain.
|
||||
- **Use feature flags for incremental rollout.**
|
||||
- Large features should be merged in small, shippable increments behind a flag.
|
||||
@@ -155,3 +155,11 @@ Add clear comments:
|
||||
- **Keep flags short-lived.** Once a feature is fully rolled out, remove the flag and dead code paths promptly.
|
||||
- **Flag at the right level.** Prefer flagging at API/UI entry points rather than deep in business logic.
|
||||
- **Test both flag states.** Ensure the codebase works correctly with the flag on and off.
|
||||
|
||||
---
|
||||
|
||||
## Misc
|
||||
|
||||
- Any TODOs you add in the code must be accompanied by either the name/username
|
||||
of the owner of that TODO, or an issue number for an issue referencing that
|
||||
piece of work.
|
||||
|
||||
@@ -490,7 +490,6 @@ func createCherryPickPR(headBranch, baseBranch, title string, commitSHAs, commit
|
||||
|
||||
// Add standard checklist
|
||||
body += "\n\n"
|
||||
body += "- [x] [Required] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n"
|
||||
body += "- [x] [Optional] Override Linear Check\n"
|
||||
|
||||
cmd := exec.Command("gh", "pr", "create",
|
||||
|
||||
@@ -118,7 +118,7 @@ func runCI(cmd *cobra.Command, args []string, opts *RunCIOptions) {
|
||||
// Create the CI branch
|
||||
ciBranch := fmt.Sprintf("run-ci/%s", prNumber)
|
||||
prTitle := fmt.Sprintf("chore: [Running GitHub actions for #%s]", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] I have considered whether this PR needs to be cherry-picked to the latest beta branch.\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
prBody := fmt.Sprintf("This PR runs GitHub Actions CI for #%s.\n\n- [x] Override Linear Check\n\n**This PR should be closed (not merged) after CI completes.**", prNumber)
|
||||
|
||||
// Fetch the fork's branch
|
||||
if forkRepo == "" {
|
||||
|
||||
@@ -105,6 +105,18 @@ const nextConfig = {
|
||||
destination: "/app",
|
||||
permanent: true,
|
||||
},
|
||||
// NRF routes: Redirect to /nrf which doesn't require auth
|
||||
// (NRFPage handles unauthenticated users gracefully with a login modal)
|
||||
{
|
||||
source: "/app/nrf/side-panel",
|
||||
destination: "/nrf/side-panel",
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/app/nrf",
|
||||
destination: "/nrf",
|
||||
permanent: true,
|
||||
},
|
||||
{
|
||||
source: "/chat/:path*",
|
||||
destination: "/app/:path*",
|
||||
|
||||
@@ -31,6 +31,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgEdit, SvgKey, SvgRefreshCw } from "@opal/icons";
|
||||
import { useCloudSubscription } from "@/hooks/useCloudSubscription";
|
||||
|
||||
function Main() {
|
||||
const {
|
||||
@@ -39,6 +40,8 @@ function Main() {
|
||||
error,
|
||||
} = useSWR<APIKey[]>("/api/admin/api-key", errorHandlingFetcher);
|
||||
|
||||
const canCreateKeys = useCloudSubscription();
|
||||
|
||||
const [fullApiKey, setFullApiKey] = useState<string | null>(null);
|
||||
const [keyIsGenerating, setKeyIsGenerating] = useState(false);
|
||||
const [showCreateUpdateForm, setShowCreateUpdateForm] = useState(false);
|
||||
@@ -70,12 +73,23 @@ function Main() {
|
||||
const introSection = (
|
||||
<div className="flex flex-col items-start gap-4">
|
||||
<Text as="p">
|
||||
API Keys allow you to access Onyx APIs programmatically. Click the
|
||||
button below to generate a new API Key.
|
||||
API Keys allow you to access Onyx APIs programmatically.
|
||||
{canCreateKeys
|
||||
? " Click the button below to generate a new API Key."
|
||||
: ""}
|
||||
</Text>
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
{canCreateKeys ? (
|
||||
<CreateButton onClick={() => setShowCreateUpdateForm(true)}>
|
||||
Create API Key
|
||||
</CreateButton>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2 rounded-lg bg-background-tint-02 p-4">
|
||||
<Text as="p" text04>
|
||||
This feature requires an active paid subscription.
|
||||
</Text>
|
||||
<Button href="/admin/billing">Upgrade Plan</Button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -109,7 +123,7 @@ function Main() {
|
||||
title="New API Key"
|
||||
icon={SvgKey}
|
||||
onClose={() => setFullApiKey(null)}
|
||||
description="Make sure you copy your new API key. You won’t be able to see this key again."
|
||||
description="Make sure you copy your new API key. You won't be able to see this key again."
|
||||
/>
|
||||
<Modal.Body>
|
||||
<Text as="p" className="break-all flex-1">
|
||||
@@ -124,88 +138,94 @@ function Main() {
|
||||
|
||||
{introSection}
|
||||
|
||||
<Separator />
|
||||
{canCreateKeys && (
|
||||
<>
|
||||
<Separator />
|
||||
|
||||
<Title className="mt-6">Existing API Keys</Title>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Role</TableHead>
|
||||
<TableHead>Regenerate</TableHead>
|
||||
<TableHead>Delete</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredApiKeys.map((apiKey) => (
|
||||
<TableRow key={apiKey.api_key_id}>
|
||||
<TableCell>
|
||||
<Button
|
||||
internal
|
||||
onClick={() => handleEdit(apiKey)}
|
||||
leftIcon={SvgEdit}
|
||||
>
|
||||
{apiKey.api_key_name || <i>null</i>}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_display}
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_role.toUpperCase()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
internal
|
||||
leftIcon={SvgRefreshCw}
|
||||
onClick={async () => {
|
||||
setKeyIsGenerating(true);
|
||||
const response = await regenerateApiKey(apiKey);
|
||||
setKeyIsGenerating(false);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to regenerate API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
const newKey = (await response.json()) as APIKey;
|
||||
setFullApiKey(newKey.api_key);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
>
|
||||
Refresh
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DeleteButton
|
||||
onClick={async () => {
|
||||
const response = await deleteApiKey(apiKey.api_key_id);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to delete API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
<Title className="mt-6">Existing API Keys</Title>
|
||||
<Table className="overflow-visible">
|
||||
<TableHeader>
|
||||
<TableRow>
|
||||
<TableHead>Name</TableHead>
|
||||
<TableHead>API Key</TableHead>
|
||||
<TableHead>Role</TableHead>
|
||||
<TableHead>Regenerate</TableHead>
|
||||
<TableHead>Delete</TableHead>
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{filteredApiKeys.map((apiKey) => (
|
||||
<TableRow key={apiKey.api_key_id}>
|
||||
<TableCell>
|
||||
<Button
|
||||
internal
|
||||
onClick={() => handleEdit(apiKey)}
|
||||
leftIcon={SvgEdit}
|
||||
>
|
||||
{apiKey.api_key_name || <i>null</i>}
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_display}
|
||||
</TableCell>
|
||||
<TableCell className="max-w-64">
|
||||
{apiKey.api_key_role.toUpperCase()}
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<Button
|
||||
internal
|
||||
leftIcon={SvgRefreshCw}
|
||||
onClick={async () => {
|
||||
setKeyIsGenerating(true);
|
||||
const response = await regenerateApiKey(apiKey);
|
||||
setKeyIsGenerating(false);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(
|
||||
`Failed to regenerate API Key: ${errorMsg}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
const newKey = (await response.json()) as APIKey;
|
||||
setFullApiKey(newKey.api_key);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
>
|
||||
Refresh
|
||||
</Button>
|
||||
</TableCell>
|
||||
<TableCell>
|
||||
<DeleteButton
|
||||
onClick={async () => {
|
||||
const response = await deleteApiKey(apiKey.api_key_id);
|
||||
if (!response.ok) {
|
||||
const errorMsg = await response.text();
|
||||
toast.error(`Failed to delete API Key: ${errorMsg}`);
|
||||
return;
|
||||
}
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
/>
|
||||
</TableCell>
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
|
||||
{showCreateUpdateForm && (
|
||||
<OnyxApiKeyForm
|
||||
onCreateApiKey={(apiKey) => {
|
||||
setFullApiKey(apiKey.api_key);
|
||||
}}
|
||||
onClose={() => {
|
||||
setShowCreateUpdateForm(false);
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
{showCreateUpdateForm && (
|
||||
<OnyxApiKeyForm
|
||||
onCreateApiKey={(apiKey) => {
|
||||
setFullApiKey(apiKey.api_key);
|
||||
}}
|
||||
onClose={() => {
|
||||
setShowCreateUpdateForm(false);
|
||||
setSelectedApiKey(undefined);
|
||||
mutate("/api/admin/api-key");
|
||||
}}
|
||||
apiKey={selectedApiKey}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
)}
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -112,18 +112,6 @@ function MainContent({
|
||||
<CreateButton href="/app/agents/create?admin=true">
|
||||
Create Your First Assistant
|
||||
</CreateButton>
|
||||
<div className="mt-6 pt-6 border-t border-border">
|
||||
<Text className="text-subtle text-sm">
|
||||
OR go{" "}
|
||||
<a
|
||||
href="/admin/configuration/default-assistant"
|
||||
className="text-link underline"
|
||||
>
|
||||
here
|
||||
</a>{" "}
|
||||
to adjust the Default Assistant
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -5,7 +5,6 @@ import { SlackBot, ValidSources } from "@/lib/types";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { useState, useEffect, useRef } from "react";
|
||||
import { updateSlackBotField } from "@/lib/updateSlackBotField";
|
||||
import { Checkbox } from "@/app/admin/settings/SettingsForm";
|
||||
import { SlackTokensForm } from "./SlackTokensForm";
|
||||
import { SourceIcon } from "@/components/SourceIcon";
|
||||
import { EditableStringFieldDisplay } from "@/components/EditableStringFieldDisplay";
|
||||
@@ -15,6 +14,28 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { SvgChevronDownSmall, SvgTrash } from "@opal/icons";
|
||||
|
||||
interface CheckboxProps {
  label: string;
  checked: boolean;
  onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
}

/**
 * Native checkbox input wrapped in a clickable label.
 *
 * The component is fully controlled: it renders `checked` as given and
 * forwards every change event to `onChange`.
 */
function Checkbox(props: CheckboxProps) {
  const { label, checked, onChange } = props;

  return (
    <label className="flex text-xs cursor-pointer">
      <input
        type="checkbox"
        className="mr-2 w-3.5 h-3.5 my-auto"
        checked={checked}
        onChange={onChange}
      />
      <span className="block font-medium text-text-700 text-sm">{label}</span>
    </label>
  );
}
|
||||
|
||||
export const ExistingSlackBotForm = ({
|
||||
existingSlackBot,
|
||||
refreshSlackBot,
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import ChatPreferencesPage from "@/refresh-pages/admin/ChatPreferencesPage";
|
||||
|
||||
/** Route entry point; all rendering is delegated to `ChatPreferencesPage`. */
function Page() {
  return <ChatPreferencesPage />;
}

export default Page;
|
||||
@@ -1,307 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { Formik, Form } from "formik";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
import { useRouter } from "next/navigation";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useAgents } from "@/hooks/useAgents";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { SubLabel } from "@/components/Field";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import Link from "next/link";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import { ToolSnapshot, MCPServersResponse } from "@/lib/tools/interfaces";
|
||||
import { ToolSelector } from "@/components/admin/assistants/ToolSelector";
|
||||
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Info } from "lucide-react";
|
||||
import { SvgOnyxLogo } from "@opal/icons";
|
||||
|
||||
/** Payload read from `/api/admin/default-assistant/configuration`. */
interface DefaultAssistantConfiguration {
  /** IDs of the tools currently enabled for the default assistant. */
  tool_ids: number[];
  /** Custom system prompt; `null` means the default prompt is in use. */
  system_prompt: string | null;
  /** Prompt displayed in the form when `system_prompt` is `null`. */
  default_system_prompt: string;
}
|
||||
|
||||
/**
 * Body of the PATCH sent to `/api/admin/default-assistant`.
 *
 * Both fields are optional, so a field can be left out of the request
 * entirely.
 */
interface DefaultAssistantUpdateRequest {
  /** Tool IDs that should be enabled. */
  tool_ids?: number[];
  /** New custom prompt, or `null` to go back to the default prompt. */
  system_prompt?: string | null;
}
|
||||
|
||||
function DefaultAssistantConfig() {
|
||||
const router = useRouter();
|
||||
const { refresh: refreshAgents } = useAgents();
|
||||
const combinedSettings = useSettingsContext();
|
||||
|
||||
const {
|
||||
data: config,
|
||||
isLoading,
|
||||
error,
|
||||
} = useSWR<DefaultAssistantConfiguration>(
|
||||
"/api/admin/default-assistant/configuration",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
// Use the same endpoint as regular assistant editor
|
||||
const { data: tools } = useSWR<ToolSnapshot[]>(
|
||||
"/api/tool",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const { data: mcpServersResponse } = useSWR<MCPServersResponse>(
|
||||
"/api/admin/mcp/servers",
|
||||
errorHandlingFetcher
|
||||
);
|
||||
|
||||
const [isSubmitting, setIsSubmitting] = useState(false);
|
||||
|
||||
const persistConfiguration = async (
|
||||
updates: DefaultAssistantUpdateRequest
|
||||
) => {
|
||||
const response = await fetch("/api/admin/default-assistant", {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify(updates),
|
||||
});
|
||||
if (!response.ok) {
|
||||
const errorText = await response.text();
|
||||
throw new Error(errorText || "Failed to update assistant");
|
||||
}
|
||||
};
|
||||
|
||||
if (isLoading) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
if (error) {
|
||||
return (
|
||||
<ErrorCallout
|
||||
errorTitle="Failed to load configuration"
|
||||
errorMsg="Unable to fetch the default assistant configuration."
|
||||
/>
|
||||
);
|
||||
}
|
||||
|
||||
if (combinedSettings?.settings?.disable_default_assistant) {
|
||||
return (
|
||||
<div>
|
||||
<Callout type="notice">
|
||||
<p className="mb-3">
|
||||
The default assistant is currently disabled in your workspace
|
||||
settings.
|
||||
</p>
|
||||
<p>
|
||||
To configure the default assistant, you must first enable it in{" "}
|
||||
<Link href="/admin/settings" className="text-link font-medium">
|
||||
Workspace Settings
|
||||
</Link>
|
||||
.
|
||||
</p>
|
||||
</Callout>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
if (!config || !tools) {
|
||||
return <ThreeDotsLoader />;
|
||||
}
|
||||
|
||||
const enabledToolsMap: { [key: number]: boolean } = {};
|
||||
tools.forEach((tool) => {
|
||||
enabledToolsMap[tool.id] = config.tool_ids.includes(tool.id);
|
||||
});
|
||||
|
||||
return (
|
||||
<div>
|
||||
<Formik
|
||||
enableReinitialize
|
||||
initialValues={{
|
||||
enabled_tools_map: enabledToolsMap,
|
||||
// Display the default prompt when system_prompt is null
|
||||
system_prompt: config.system_prompt ?? config.default_system_prompt,
|
||||
// Track if we're using the default (null in DB)
|
||||
isUsingDefault: config.system_prompt === null,
|
||||
}}
|
||||
onSubmit={async (values) => {
|
||||
setIsSubmitting(true);
|
||||
try {
|
||||
const enabledToolIds = Object.keys(values.enabled_tools_map)
|
||||
.map((id) => Number(id))
|
||||
.filter((id) => values.enabled_tools_map[id]);
|
||||
|
||||
const updates: DefaultAssistantUpdateRequest = {
|
||||
tool_ids: enabledToolIds,
|
||||
};
|
||||
|
||||
// Determine if we need to send system_prompt
|
||||
// Use config directly since it reflects the original DB state
|
||||
const wasUsingDefault = config.system_prompt === null;
|
||||
const initialPrompt =
|
||||
config.system_prompt ?? config.default_system_prompt;
|
||||
const isNowUsingDefault = values.isUsingDefault;
|
||||
const promptChanged = values.system_prompt !== initialPrompt;
|
||||
|
||||
if (wasUsingDefault && isNowUsingDefault && !promptChanged) {
|
||||
// Was default, still default, no changes - don't send
|
||||
} else if (isNowUsingDefault) {
|
||||
// User clicked reset - send null to set DB to null (use default)
|
||||
updates.system_prompt = null;
|
||||
} else if (promptChanged || wasUsingDefault !== isNowUsingDefault) {
|
||||
// Prompt changed or switched from default to custom
|
||||
updates.system_prompt = values.system_prompt;
|
||||
}
|
||||
|
||||
await persistConfiguration(updates);
|
||||
|
||||
await mutate("/api/admin/default-assistant/configuration");
|
||||
router.refresh();
|
||||
await refreshAgents();
|
||||
|
||||
toast.success("Default assistant updated successfully!");
|
||||
} catch (error: any) {
|
||||
toast.error(error.message || "Failed to update assistant");
|
||||
} finally {
|
||||
setIsSubmitting(false);
|
||||
}
|
||||
}}
|
||||
>
|
||||
{({ values, setFieldValue }) => (
|
||||
<Form>
|
||||
<div className="space-y-6">
|
||||
<div className="mt-4">
|
||||
<Text as="p" className="text-text-dark">
|
||||
Configure which capabilities are enabled for the default
|
||||
assistant in chat. These settings apply to all users who
|
||||
haven't customized their assistant preferences.
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="max-w-4xl">
|
||||
<div className="flex gap-x-2 items-center">
|
||||
<Text
|
||||
as="p"
|
||||
mainUiBody
|
||||
text04
|
||||
className="font-medium text-sm"
|
||||
>
|
||||
Instructions
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex items-start gap-1.5 mb-1">
|
||||
<SubLabel>
|
||||
Add instructions to tailor the behavior of the assistant.
|
||||
</SubLabel>
|
||||
<HoverPopup
|
||||
mainContent={
|
||||
<Info className="h-3.5 w-3.5 text-text-400 cursor-help" />
|
||||
}
|
||||
popupContent={
|
||||
<div className="text-xs space-y-1.5 max-w-xs bg-background-neutral-dark-03 text-text-light-05">
|
||||
<div>You can use placeholders in your prompt:</div>
|
||||
<div>
|
||||
<span className="font-mono font-semibold">
|
||||
{"{{CURRENT_DATETIME}}"}
|
||||
</span>{" "}
|
||||
- Injects the current date and day of the week in a
|
||||
human/LLM readable format.
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-mono font-semibold">
|
||||
{"{{CITATION_GUIDANCE}}"}
|
||||
</span>{" "}
|
||||
- Injects instructions to provide citations for facts
|
||||
found from search tools. This is not included if no
|
||||
search tools are called.
|
||||
</div>
|
||||
<div>
|
||||
<span className="font-mono font-semibold">
|
||||
{"{{REMINDER_TAG_DESCRIPTION}}"}
|
||||
</span>{" "}
|
||||
- Injects instructions for how the Agent should handle
|
||||
system reminder tags.
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
direction="bottom"
|
||||
/>
|
||||
</div>
|
||||
<div>
|
||||
<InputTextArea
|
||||
rows={8}
|
||||
value={values.system_prompt}
|
||||
onChange={(event) => {
|
||||
setFieldValue("system_prompt", event.target.value);
|
||||
// Mark as no longer using default when user edits
|
||||
if (values.isUsingDefault) {
|
||||
setFieldValue("isUsingDefault", false);
|
||||
}
|
||||
}}
|
||||
placeholder="You are a professional email writing assistant that always uses a polite enthusiastic tone, emphasizes action items, and leaves blanks for the human to fill in when you have unknowns"
|
||||
/>
|
||||
<div className="flex justify-between items-center mt-2">
|
||||
<button
|
||||
type="button"
|
||||
className="text-sm text-link hover:underline disabled:opacity-50 disabled:cursor-not-allowed"
|
||||
disabled={values.isUsingDefault}
|
||||
onClick={() => {
|
||||
setFieldValue(
|
||||
"system_prompt",
|
||||
config.default_system_prompt
|
||||
);
|
||||
setFieldValue("isUsingDefault", true);
|
||||
}}
|
||||
>
|
||||
Reset to Default
|
||||
</button>
|
||||
<Text as="p" mainUiMuted text03 className="text-sm">
|
||||
{values.system_prompt.length} characters
|
||||
</Text>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<ToolSelector
|
||||
tools={tools}
|
||||
mcpServers={mcpServersResponse?.mcp_servers}
|
||||
enabledToolsMap={values.enabled_tools_map}
|
||||
setFieldValue={setFieldValue}
|
||||
hideSearchTool={
|
||||
combinedSettings?.settings.vector_db_enabled === false
|
||||
}
|
||||
/>
|
||||
|
||||
<div className="flex justify-end pt-4">
|
||||
<Button type="submit" disabled={isSubmitting}>
|
||||
{isSubmitting ? "Saving..." : "Save Changes"}
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
</Form>
|
||||
)}
|
||||
</Formik>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
/**
 * Page component for the default assistant: shows the page title with the
 * Onyx logo, followed by the default assistant configuration form.
 */
export default function Page() {
  // Built up-front so the title element below stays on a single line.
  const titleIcon = (
    <SvgOnyxLogo size={32} className="my-auto stroke-text-04" />
  );

  return (
    <>
      <AdminPageTitle title="Default Assistant" icon={titleIcon} />
      <DefaultAssistantConfig />
    </>
  );
}
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
import Image from "next/image";
|
||||
import { useMemo, useState, useReducer } from "react";
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import * as SettingsLayouts from "@/layouts/settings-layouts";
|
||||
import { Content } from "@opal/layouts";
|
||||
import useSWR from "swr";
|
||||
import { errorHandlingFetcher, FetchError } from "@/lib/fetcher";
|
||||
import { ThreeDotsLoader } from "@/components/Loading";
|
||||
@@ -22,7 +22,6 @@ import {
|
||||
SvgOnyxLogo,
|
||||
SvgX,
|
||||
} from "@opal/icons";
|
||||
|
||||
import { WebProviderSetupModal } from "@/app/admin/configuration/web-search/WebProviderSetupModal";
|
||||
import {
|
||||
SEARCH_PROVIDERS_URL,
|
||||
@@ -402,36 +401,40 @@ export default function Page() {
|
||||
: undefined);
|
||||
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Web Search"
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
includeDivider={false}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
<Callout type="danger" title="Failed to load web search settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</>
|
||||
<SettingsLayouts.Body>
|
||||
<Callout type="danger" title="Failed to load web search settings">
|
||||
{message}
|
||||
{detail && (
|
||||
<Text as="p" className="mt-2 text-text-03" mainContentBody text03>
|
||||
{detail}
|
||||
</Text>
|
||||
)}
|
||||
</Callout>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
if (isLoading) {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle
|
||||
title="Web Search"
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
includeDivider={false}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
<div className="mt-8">
|
||||
<SettingsLayouts.Body>
|
||||
<ThreeDotsLoader />
|
||||
</div>
|
||||
</>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -827,32 +830,22 @@ export default function Page() {
|
||||
|
||||
return (
|
||||
<>
|
||||
<>
|
||||
<AdminPageTitle icon={SvgGlobe} title="Web Search" />
|
||||
<div className="pt-4 pb-4">
|
||||
<Text as="p" className="text-text-dark">
|
||||
Search settings for external search across the internet.
|
||||
</Text>
|
||||
</div>
|
||||
<SettingsLayouts.Root>
|
||||
<SettingsLayouts.Header
|
||||
icon={SvgGlobe}
|
||||
title="Web Search"
|
||||
description="Search settings for external search across the internet."
|
||||
separator
|
||||
/>
|
||||
|
||||
<Separator />
|
||||
|
||||
<div className="flex w-full flex-col gap-8 pb-6">
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainContentEmphasis text05>
|
||||
Search Engine
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className="flex items-start gap-[2px] self-stretch text-text-03"
|
||||
secondaryBody
|
||||
text03
|
||||
>
|
||||
External search engine API used for web search result URLs,
|
||||
snippets, and metadata.
|
||||
</Text>
|
||||
</div>
|
||||
<SettingsLayouts.Body>
|
||||
<div className="flex w-full flex-col gap-3">
|
||||
<Content
|
||||
title="Search Engine"
|
||||
description="External search engine API used for web search result URLs, snippets, and metadata."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
{activationError && (
|
||||
<Callout type="danger" title="Unable to update default provider">
|
||||
@@ -974,14 +967,12 @@ export default function Page() {
|
||||
size: 16,
|
||||
isHighlighted,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text05>
|
||||
{label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{isConfigured && (
|
||||
@@ -1045,20 +1036,13 @@ export default function Page() {
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="flex w-full max-w-[960px] flex-col gap-3">
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainContentEmphasis text05>
|
||||
Web Crawler
|
||||
</Text>
|
||||
<Text
|
||||
as="p"
|
||||
className="flex items-start gap-[2px] self-stretch text-text-03"
|
||||
secondaryBody
|
||||
text03
|
||||
>
|
||||
Used to read the full contents of search result pages.
|
||||
</Text>
|
||||
</div>
|
||||
<div className="flex w-full flex-col gap-3">
|
||||
<Content
|
||||
title="Web Crawler"
|
||||
description="Used to read the full contents of search result pages."
|
||||
sizePreset="main-content"
|
||||
variant="section"
|
||||
/>
|
||||
|
||||
{contentActivationError && (
|
||||
<Callout type="danger" title="Unable to update crawler">
|
||||
@@ -1173,14 +1157,12 @@ export default function Page() {
|
||||
size: 16,
|
||||
isHighlighted: isCurrentCrawler,
|
||||
})}
|
||||
<div className="flex flex-col gap-0.5">
|
||||
<Text as="p" mainUiAction text05>
|
||||
{label}
|
||||
</Text>
|
||||
<Text as="p" secondaryBody text03>
|
||||
{subtitle}
|
||||
</Text>
|
||||
</div>
|
||||
<Content
|
||||
title={label}
|
||||
description={subtitle}
|
||||
sizePreset="main-ui"
|
||||
variant="section"
|
||||
/>
|
||||
</div>
|
||||
<div className="flex items-center justify-end gap-2">
|
||||
{provider.provider_type !== "onyx_web_crawler" &&
|
||||
@@ -1244,8 +1226,8 @@ export default function Page() {
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
</SettingsLayouts.Body>
|
||||
</SettingsLayouts.Root>
|
||||
|
||||
<WebProviderSetupModal
|
||||
isOpen={selectedProviderType !== null}
|
||||
|
||||
@@ -1,124 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import useSWR from "swr";
|
||||
import { useContext, useState } from "react";
|
||||
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { Card } from "@/refresh-components/cards";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import * as GeneralLayouts from "@/layouts/general-layouts";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
|
||||
/**
 * Card for viewing, editing, and copying the anonymous-user access path.
 *
 * The current path is loaded from `/api/tenants/anonymous-user-path` with SWR.
 * Submitting a new path POSTs it to the same endpoint as the
 * `anonymous_user_path` query parameter and then revalidates the SWR cache.
 */
export function AnonymousUserPath() {
  const settings = useContext(SettingsContext);
  // `null` means the user has not typed anything yet; the input then falls
  // back to the path fetched from the server.
  const [customPath, setCustomPath] = useState<string | null>(null);

  const {
    data: anonymousUserPath,
    error,
    mutate,
    isLoading,
  } = useSWR("/api/tenants/anonymous-user-path", async (url: string) => {
    const res = await fetch(url);
    // Surface HTTP failures through SWR's `error` instead of treating the
    // error payload as data.
    if (!res.ok) {
      throw new Error(
        `Failed to fetch anonymous user path (status ${res.status})`
      );
    }
    const data = await res.json();
    return data.anonymous_user_path;
  });

  if (error) {
    console.error("Failed to fetch anonymous user path:", error);
  }

  /**
   * Validates the typed path (non-empty; letters, digits, and hyphens only),
   * sends it to the server, and reports the outcome with a toast.
   */
  async function handleCustomPathUpdate() {
    try {
      // Validate custom path
      if (!customPath || !customPath.trim()) {
        toast.error("Custom path cannot be empty");
        return;
      }

      if (!/^[a-zA-Z0-9-]+$/.test(customPath)) {
        toast.error(
          "Custom path can only contain letters, numbers, and hyphens"
        );
        return;
      }

      const response = await fetch(
        `/api/tenants/anonymous-user-path?anonymous_user_path=${encodeURIComponent(
          customPath
        )}`,
        {
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
        }
      );

      if (!response.ok) {
        // Error bodies are not guaranteed to be JSON; fall back to the
        // generic message rather than leaking a parse error to the user.
        let detailMessage: string | undefined;
        try {
          const detail = await response.json();
          detailMessage = detail?.detail;
        } catch {
          detailMessage = undefined;
        }
        toast.error(detailMessage || "Failed to update anonymous user path");
        return;
      }

      // Revalidate the SWR cache before announcing success so the copy
      // button reflects the new path.
      await mutate();
      toast.success("Anonymous user path updated successfully!");
    } catch (error) {
      toast.error(`Failed to update anonymous user path: ${error}`);
      console.error("Error updating anonymous user path:", error);
    }
  }

  return (
    <div className="max-w-xl">
      <Card gap={0}>
        <GeneralLayouts.Section alignItems="start" gap={0.5}>
          <Text headingH3>Anonymous User Access</Text>
          <Text secondaryBody text03>
            Enable this to allow anonymous users to access all public connectors
            in your workspace. Anonymous users will not be able to access
            private or restricted content.
          </Text>
        </GeneralLayouts.Section>

        {isLoading ? (
          <SimpleLoader className="self-center animate-spin mt-4" />
        ) : (
          <>
            <GeneralLayouts.Section flexDirection="row" gap={0.5}>
              <Text mainContentBody text03>
                {settings?.webDomain}/anonymous/
              </Text>
              <InputTypeIn
                placeholder="your-custom-path"
                value={customPath ?? anonymousUserPath ?? ""}
                onChange={(e) => setCustomPath(e.target.value)}
                showClearButton={false}
              />
            </GeneralLayouts.Section>

            <GeneralLayouts.Section
              flexDirection="row"
              gap={0.5}
              justifyContent="start"
            >
              <Button onClick={handleCustomPathUpdate}>Update Path</Button>
              <CopyIconButton
                getCopyText={() =>
                  `${settings?.webDomain}/anonymous/${anonymousUserPath ?? ""}`
                }
                tooltip="Copy invite link"
                prominence="secondary"
              />
            </GeneralLayouts.Section>
          </>
        )}
      </Card>
    </div>
  );
}
|
||||
@@ -1,453 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { Label, SubLabel } from "@/components/Field";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import Title from "@/components/ui/title";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Settings } from "./interfaces";
|
||||
import { useRouter } from "next/navigation";
|
||||
import React, { useContext, useState, useEffect } from "react";
|
||||
import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { AnonymousUserPath } from "./AnonymousUserPath";
|
||||
import LLMSelector from "@/components/llm/LLMSelector";
|
||||
import { useVisionProviders } from "./hooks/useVisionProviders";
|
||||
import InputTextArea from "@/refresh-components/inputs/InputTextArea";
|
||||
import { SvgAlertTriangle } from "@opal/icons";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
|
||||
/**
 * Labelled checkbox row: a native checkbox followed by a bold label and an
 * optional sublabel rendered underneath it.
 *
 * The component is controlled; `checked` and `onChange` are forwarded to the
 * underlying `<input type="checkbox">`.
 */
export function Checkbox(props: {
  label: string;
  sublabel?: string;
  checked: boolean;
  onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
}) {
  const { label, sublabel, checked, onChange } = props;

  // Only rendered when a sublabel was supplied.
  const description = sublabel && <SubLabel>{sublabel}</SubLabel>;

  return (
    <label className="flex text-xs cursor-pointer">
      <input
        type="checkbox"
        className="mr-2 w-3.5 h-3.5 my-auto"
        checked={checked}
        onChange={onChange}
      />
      <div>
        <span className="block font-medium text-text-700 dark:text-neutral-100 text-sm">
          {label}
        </span>
        {description}
      </div>
    </label>
  );
}
|
||||
|
||||
/**
 * Labelled numeric input with a sublabel.
 *
 * A `null` value is shown as an empty field. The input is constrained to
 * whole-number steps with a minimum of 1 via its `min`/`step` attributes.
 */
function IntegerInput(props: {
  label: string;
  sublabel: string;
  value: number | null;
  onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
  id?: string;
  placeholder?: string;
}) {
  const {
    label,
    sublabel,
    value,
    onChange,
    id,
    // Default placeholder if none is provided
    placeholder = "Enter a number",
  } = props;

  // An empty string keeps the input controlled when there is no value.
  const displayedValue = value ?? "";

  return (
    <label className="flex flex-col text-sm mb-4">
      <Label>{label}</Label>
      <SubLabel>{sublabel}</SubLabel>
      <input
        id={id}
        type="number"
        min="1"
        step="1"
        className="mt-1 p-2 border rounded w-full max-w-xs"
        placeholder={placeholder}
        value={displayedValue}
        onChange={onChange}
      />
    </label>
  );
}
|
||||
|
||||
/**
 * Workspace settings form.
 *
 * Local state is seeded once from `SettingsContext`; each change is applied
 * optimistically and then persisted by PUTting the full settings object to
 * `/api/admin/settings`. On failure the optimistic change is reverted.
 *
 * Sections rendered:
 * - company name / description and general feature toggles,
 * - the invite-only toggle (basic and Google OAuth auth types only),
 * - the anonymous user path card (cloud builds with anonymous users enabled),
 * - chat retention (only when paid enterprise features are enabled),
 * - image processing options and the default vision LLM selector.
 */
export function SettingsForm() {
  const router = useRouter();
  const { authTypeMetadata } = useUser();
  const [showConfirmModal, setShowConfirmModal] = useState(false);
  const [settings, setSettings] = useState<Settings | null>(null);
  const [chatRetention, setChatRetention] = useState("");
  const [companyName, setCompanyName] = useState("");
  const [companyDescription, setCompanyDescription] = useState("");
  const isEnterpriseEnabled = usePaidEnterpriseFeaturesEnabled();

  const {
    visionProviders,
    visionLLM,
    setVisionLLM,
    updateDefaultVisionProvider,
  } = useVisionProviders();

  const combinedSettings = useContext(SettingsContext);

  // Seed local form state from the context once on mount. Later context
  // changes are intentionally not re-synced here (empty dependency list).
  useEffect(() => {
    if (combinedSettings) {
      setSettings(combinedSettings.settings);
      setChatRetention(
        combinedSettings.settings.maximum_chat_retention_days?.toString() || ""
      );
      setCompanyName(combinedSettings.settings.company_name || "");
      setCompanyDescription(
        combinedSettings.settings.company_description || ""
      );
    }
    // We don't need to fetch vision providers here anymore as the hook handles it
  }, []);

  if (!settings) {
    return null;
  }

  const showInviteOnlyModeToggle =
    authTypeMetadata.authType === AuthType.BASIC ||
    authTypeMetadata.authType === AuthType.GOOGLE_OAUTH;

  /**
   * Applies the given field updates optimistically, persists the resulting
   * settings object, and reverts local state if the request fails.
   *
   * A `newValue` of `undefined` leaves the field unchanged; an explicit
   * `null` is written through so nullable settings can be cleared.
   */
  async function updateSettingField(
    updateRequests: { fieldName: keyof Settings; newValue: any }[]
  ) {
    // Optimistically update the local state
    const newSettings: Settings | null = settings
      ? {
          ...settings,
          ...updateRequests.reduce((acc, { fieldName, newValue }) => {
            // `??` would also swallow `null`, making it impossible to clear
            // a value (e.g. "Retain All"); only fall back on `undefined`.
            acc[fieldName] =
              newValue !== undefined ? newValue : settings[fieldName];
            return acc;
          }, {} as Partial<Settings>),
        }
      : null;
    setSettings(newSettings);

    try {
      const response = await fetch("/api/admin/settings", {
        method: "PUT",
        headers: {
          "Content-Type": "application/json",
        },
        body: JSON.stringify(newSettings),
      });

      if (!response.ok) {
        const errorMsg = (await response.json()).detail;
        throw new Error(errorMsg);
      }

      router.refresh();
      toast.success("Settings updated successfully!");
    } catch (error) {
      // Revert the optimistic update
      setSettings(settings);
      console.error("Error updating settings:", error);
      toast.error("Failed to update settings");
    }
  }

  /**
   * Handles a checkbox toggle. Enabling anonymous users first asks for
   * confirmation; every other change is persisted immediately.
   */
  function handleToggleSettingsField(
    fieldName: keyof Settings,
    checked: boolean
  ) {
    if (fieldName === "anonymous_user_enabled" && checked) {
      setShowConfirmModal(true);
    } else {
      const updates: { fieldName: keyof Settings; newValue: any }[] = [
        { fieldName, newValue: checked },
      ];
      updateSettingField(updates);
    }
  }

  /** Persists `anonymous_user_enabled = true` and closes the confirm modal. */
  function handleConfirmAnonymousUsers() {
    const updates: { fieldName: keyof Settings; newValue: any }[] = [
      { fieldName: "anonymous_user_enabled", newValue: true },
    ];
    updateSettingField(updates);
    setShowConfirmModal(false);
  }

  /** Persists the typed retention limit; an empty field is sent as `null`. */
  function handleSetChatRetention() {
    const newValue = chatRetention === "" ? null : parseInt(chatRetention, 10);
    updateSettingField([
      { fieldName: "maximum_chat_retention_days", newValue },
    ]);
  }

  /** Clears the retention input and persists `null` for the limit. */
  function handleClearChatRetention() {
    setChatRetention("");
    updateSettingField([
      { fieldName: "maximum_chat_retention_days", newValue: null },
    ]);
  }

  /** Persists the company name on blur, but only if it actually changed. */
  function handleCompanyNameBlur() {
    const originalValue = settings?.company_name || "";
    if (companyName !== originalValue) {
      updateSettingField([
        { fieldName: "company_name", newValue: companyName || null },
      ]);
    }
  }

  /** Persists the company description on blur, but only if it changed. */
  function handleCompanyDescriptionBlur() {
    const originalValue = settings?.company_description || "";
    if (companyDescription !== originalValue) {
      updateSettingField([
        {
          fieldName: "company_description",
          newValue: companyDescription || null,
        },
      ]);
    }
  }

  return (
    <>
      <Title className="mb-4">Workspace Settings</Title>
      <label className="flex flex-col text-sm mb-4">
        <Label>Company Name</Label>
        <SubLabel>
          Set the company name used for search and chat context.
        </SubLabel>
        <input
          type="text"
          className="mt-1 p-2 border rounded w-full max-w-xl"
          value={companyName}
          onChange={(e) => setCompanyName(e.target.value)}
          onBlur={handleCompanyNameBlur}
          placeholder="Enter company name"
        />
      </label>

      <label className="flex flex-col text-sm mb-4">
        <Label>Company Description</Label>
        <SubLabel>
          Provide a short description of the company for search and chat
          context.
        </SubLabel>
        <InputTextArea
          className="mt-1 w-full max-w-xl"
          value={companyDescription}
          onChange={(event) => setCompanyDescription(event.target.value)}
          onBlur={handleCompanyDescriptionBlur}
          placeholder="Enter company description"
          rows={4}
        />
      </label>

      <Checkbox
        label="Auto-scroll"
        sublabel="If set, the chat window will automatically scroll to the bottom as new lines of text are generated by the AI model. This can be overridden by individual user settings."
        checked={settings.auto_scroll}
        onChange={(e) =>
          handleToggleSettingsField("auto_scroll", e.target.checked)
        }
      />
      <Checkbox
        label="Override default temperature"
        sublabel="If set, users will be able to override the default temperature for each assistant."
        checked={settings.temperature_override_enabled}
        onChange={(e) =>
          handleToggleSettingsField(
            "temperature_override_enabled",
            e.target.checked
          )
        }
      />
      <Checkbox
        label="Anonymous Users"
        sublabel="If set, users will not be required to sign in to use Onyx."
        checked={settings.anonymous_user_enabled}
        onChange={(e) =>
          handleToggleSettingsField("anonymous_user_enabled", e.target.checked)
        }
      />
      {showInviteOnlyModeToggle && (
        <Checkbox
          label="Whitelist / Invite-only"
          sublabel="If set, only users on the invite list can join this workspace. If unset, users from your normal sign-up domain flow can still join even if invites exist."
          checked={settings.invite_only_enabled}
          onChange={(e) =>
            handleToggleSettingsField("invite_only_enabled", e.target.checked)
          }
        />
      )}

      <Checkbox
        label="Deep Research"
        sublabel="Enables a button to run deep research - a more complex and time intensive flow. Note: this costs >10x more in tokens to normal questions."
        checked={settings.deep_research_enabled ?? true}
        onChange={(e) =>
          handleToggleSettingsField("deep_research_enabled", e.target.checked)
        }
      />

      <Checkbox
        label="Disable Default Assistant"
        sublabel="When enabled, the 'New Session' button will start a new chat with the current agent instead of the default assistant. The default assistant will be hidden from all users."
        checked={settings.disable_default_assistant ?? false}
        onChange={(e) =>
          handleToggleSettingsField(
            "disable_default_assistant",
            e.target.checked
          )
        }
      />

      {NEXT_PUBLIC_CLOUD_ENABLED && settings.anonymous_user_enabled && (
        <AnonymousUserPath />
      )}
      {showConfirmModal && (
        <Modal open onOpenChange={() => setShowConfirmModal(false)}>
          <Modal.Content>
            <Modal.Header
              icon={SvgAlertTriangle}
              title="Enable Anonymous Users"
              onClose={() => setShowConfirmModal(false)}
            />
            <Modal.Body>
              <p>
                Are you sure you want to enable anonymous users? This will allow
                anyone to use Onyx without signing in.
              </p>
            </Modal.Body>
            <Modal.Footer>
              <Button secondary onClick={() => setShowConfirmModal(false)}>
                Cancel
              </Button>
              <Button onClick={handleConfirmAnonymousUsers}>Confirm</Button>
            </Modal.Footer>
          </Modal.Content>
        </Modal>
      )}
      {isEnterpriseEnabled && (
        <>
          <Title className="mt-8 mb-4">Chat Settings</Title>
          <IntegerInput
            label="Chat Retention"
            sublabel="Enter the maximum number of days you would like Onyx to retain chat messages. Leaving this field empty will cause Onyx to never delete chat messages."
            value={chatRetention === "" ? null : Number(chatRetention)}
            onChange={(e) => {
              // Accept only positive integers or an empty field.
              const numValue = parseInt(e.target.value, 10);
              if (numValue >= 1 || e.target.value === "") {
                setChatRetention(e.target.value);
              }
            }}
            id="chatRetentionInput"
            placeholder="Infinite Retention"
          />
          <div className="mr-auto flex gap-2">
            <Button onClick={handleSetChatRetention} className="mr-auto">
              Set Retention Limit
            </Button>
            <Button onClick={handleClearChatRetention} className="mr-auto">
              Retain All
            </Button>
          </div>
        </>
      )}

      {/* Image Processing Settings */}
      <Title className="mt-8 mb-4">Image Processing</Title>

      <div className="flex flex-col gap-2">
        <Checkbox
          label="Enable Image Extraction and Analysis"
          sublabel="Extract and analyze images from documents during indexing. This allows the system to process images and create searchable descriptions of them."
          checked={settings.image_extraction_and_analysis_enabled ?? false}
          onChange={(e) =>
            handleToggleSettingsField(
              "image_extraction_and_analysis_enabled",
              e.target.checked
            )
          }
        />

        <Checkbox
          label="Enable Search-time Image Analysis"
          sublabel="Analyze images at search time when a user asks about images. This provides more detailed and query-specific image analysis but may increase search-time latency."
          checked={settings.search_time_image_analysis_enabled ?? false}
          onChange={(e) =>
            handleToggleSettingsField(
              "search_time_image_analysis_enabled",
              e.target.checked
            )
          }
        />

        <IntegerInput
          label="Maximum Image Size for Analysis (MB)"
          sublabel="Images larger than this size will not be analyzed to prevent excessive resource usage."
          value={settings.image_analysis_max_size_mb ?? null}
          onChange={(e) => {
            // Only positive integers are persisted; an empty or invalid
            // field is ignored.
            const value = e.target.value ? parseInt(e.target.value, 10) : null;
            if (value !== null && !isNaN(value) && value > 0) {
              updateSettingField([
                { fieldName: "image_analysis_max_size_mb", newValue: value },
              ]);
            }
          }}
          id="image-analysis-max-size"
          placeholder="Enter maximum size in MB"
        />
        {/* Default Vision LLM Section */}
        <div className="mt-4">
          <Label>Default Vision LLM</Label>
          <SubLabel>
            Select the default LLM to use for image analysis. This model will be
            utilized during image indexing and at query time for search results,
            if the above settings are enabled.
          </SubLabel>

          <div className="mt-2 max-w-xs">
            {!visionProviders || visionProviders.length === 0 ? (
              <div className="text-sm text-gray-500">
                No vision providers found. Please add a vision provider.
              </div>
            ) : (
              <>
                <LLMSelector
                  userSettings={false}
                  llmProviders={visionProviders.map((provider) => ({
                    ...provider,
                    model_names: provider.vision_models,
                    display_model_names: provider.vision_models,
                  }))}
                  currentLlm={visionLLM}
                  onSelect={(value) => setVisionLLM(value)}
                />
                <Button
                  onClick={() => updateDefaultVisionProvider(visionLLM)}
                  className="mt-2"
                >
                  Set Default Vision LLM
                </Button>
              </>
            )}
          </div>
        </div>
      </div>
    </>
  );
}
|
||||
@@ -1,109 +0,0 @@
|
||||
import { useState, useEffect, useCallback } from "react";
|
||||
import { VisionProvider } from "@/app/admin/configuration/llm/interfaces";
|
||||
import {
|
||||
fetchVisionProviders,
|
||||
setDefaultVisionProvider,
|
||||
} from "@/lib/llm/visionLLM";
|
||||
import { parseLlmDescriptor, structureValue } from "@/lib/llm/utils";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
|
||||
export function useVisionProviders() {
|
||||
const [visionProviders, setVisionProviders] = useState<VisionProvider[]>([]);
|
||||
const [visionLLM, setVisionLLM] = useState<string | null>(null);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const [error, setError] = useState<string | null>(null);
|
||||
|
||||
const loadVisionProviders = useCallback(async () => {
|
||||
setIsLoading(true);
|
||||
setError(null);
|
||||
try {
|
||||
const data = await fetchVisionProviders();
|
||||
setVisionProviders(data);
|
||||
|
||||
// Find the default vision provider and set it
|
||||
const defaultProvider = data.find(
|
||||
(provider) => provider.is_default_vision_provider
|
||||
);
|
||||
|
||||
if (defaultProvider) {
|
||||
const modelToUse =
|
||||
defaultProvider.default_vision_model ||
|
||||
defaultProvider.default_model_name;
|
||||
|
||||
if (modelToUse && defaultProvider.vision_models.includes(modelToUse)) {
|
||||
setVisionLLM(
|
||||
structureValue(
|
||||
defaultProvider.name,
|
||||
defaultProvider.provider,
|
||||
modelToUse
|
||||
)
|
||||
);
|
||||
}
|
||||
}
|
||||
} catch (error) {
|
||||
console.error("Error fetching vision providers:", error);
|
||||
setError(
|
||||
error instanceof Error ? error.message : "Unknown error occurred"
|
||||
);
|
||||
toast.error(
|
||||
`Failed to load vision providers: ${
|
||||
error instanceof Error ? error.message : "Unknown error"
|
||||
}`
|
||||
);
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}, []);
|
||||
|
||||
const updateDefaultVisionProvider = useCallback(
|
||||
async (llmValue: string | null) => {
|
||||
if (!llmValue) {
|
||||
toast.error("Please select a valid vision model");
|
||||
return false;
|
||||
}
|
||||
|
||||
try {
|
||||
const { name, modelName } = parseLlmDescriptor(llmValue);
|
||||
|
||||
// Find the provider ID
|
||||
const providerObj = visionProviders.find((p) => p.name === name);
|
||||
if (!providerObj) {
|
||||
throw new Error("Provider not found");
|
||||
}
|
||||
|
||||
await setDefaultVisionProvider(providerObj.id, modelName);
|
||||
|
||||
toast.success("Default vision provider updated successfully!");
|
||||
setVisionLLM(llmValue);
|
||||
|
||||
// Refresh the list to reflect the change
|
||||
await loadVisionProviders();
|
||||
return true;
|
||||
} catch (error: unknown) {
|
||||
console.error("Error setting default vision provider:", error);
|
||||
const errorMessage =
|
||||
error instanceof Error ? error.message : "Unknown error occurred";
|
||||
toast.error(
|
||||
`Failed to update default vision provider: ${errorMessage}`
|
||||
);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
[visionProviders, loadVisionProviders]
|
||||
);
|
||||
|
||||
// Load providers on mount
|
||||
useEffect(() => {
|
||||
loadVisionProviders();
|
||||
}, [loadVisionProviders]);
|
||||
|
||||
return {
|
||||
visionProviders,
|
||||
visionLLM,
|
||||
isLoading,
|
||||
error,
|
||||
setVisionLLM,
|
||||
updateDefaultVisionProvider,
|
||||
refreshVisionProviders: loadVisionProviders,
|
||||
};
|
||||
}
|
||||
@@ -1,20 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { AdminPageTitle } from "@/components/admin/Title";
|
||||
import { SettingsForm } from "@/app/admin/settings/SettingsForm";
|
||||
import Text from "@/components/ui/text";
|
||||
import { SvgSettings } from "@opal/icons";
|
||||
|
||||
export default function Page() {
|
||||
return (
|
||||
<>
|
||||
<AdminPageTitle title="Workspace Settings" icon={SvgSettings} />
|
||||
|
||||
<Text className="mb-8">
|
||||
Manage general Onyx settings applicable to all users in the workspace.
|
||||
</Text>
|
||||
|
||||
<SettingsForm />
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
import React, { JSX, memo } from "react";
|
||||
import {
|
||||
ChatPacket,
|
||||
ImageGenerationToolPacket,
|
||||
Packet,
|
||||
PacketType,
|
||||
ReasoningPacket,
|
||||
@@ -28,7 +29,7 @@ import { InternalSearchToolRenderer } from "./timeline/renderers/search/Internal
|
||||
import { SearchToolStart } from "../../services/streamingModels";
|
||||
|
||||
// Different types of chat packets using discriminated unions
|
||||
export interface GroupedPackets {
|
||||
interface GroupedPackets {
|
||||
packets: Packet[];
|
||||
}
|
||||
|
||||
@@ -153,6 +154,53 @@ export function findRenderer(
|
||||
return null;
|
||||
}
|
||||
|
||||
// Handles display groups containing both chat text and image generation packets
|
||||
function MixedContentHandler({
|
||||
chatPackets,
|
||||
imagePackets,
|
||||
chatState,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
children,
|
||||
}: {
|
||||
chatPackets: Packet[];
|
||||
imagePackets: Packet[];
|
||||
chatState: FullChatState;
|
||||
onComplete: () => void;
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
children: (result: RendererOutput) => JSX.Element;
|
||||
}) {
|
||||
return (
|
||||
<MessageTextRenderer
|
||||
packets={chatPackets as ChatPacket[]}
|
||||
state={chatState}
|
||||
onComplete={() => {}}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
{(textResults) => (
|
||||
<ImageToolRenderer
|
||||
packets={imagePackets as ImageGenerationToolPacket[]}
|
||||
state={chatState}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={RenderType.FULL}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
{(imageResults) => children([...textResults, ...imageResults])}
|
||||
</ImageToolRenderer>
|
||||
)}
|
||||
</MessageTextRenderer>
|
||||
);
|
||||
}
|
||||
|
||||
// Props interface for RendererComponent
|
||||
interface RendererComponentProps {
|
||||
packets: Packet[];
|
||||
@@ -161,7 +209,6 @@ interface RendererComponentProps {
|
||||
animate: boolean;
|
||||
stopPacketSeen: boolean;
|
||||
stopReason?: StopReason;
|
||||
useShortRenderer?: boolean;
|
||||
children: (result: RendererOutput) => JSX.Element;
|
||||
}
|
||||
|
||||
@@ -175,7 +222,6 @@ function areRendererPropsEqual(
|
||||
prev.stopPacketSeen === next.stopPacketSeen &&
|
||||
prev.stopReason === next.stopReason &&
|
||||
prev.animate === next.animate &&
|
||||
prev.useShortRenderer === next.useShortRenderer &&
|
||||
prev.chatState.assistant?.id === next.chatState.assistant?.id
|
||||
// Skip: onComplete, children (function refs), chatState (memoized upstream)
|
||||
);
|
||||
@@ -189,11 +235,47 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
useShortRenderer = false,
|
||||
children,
|
||||
}: RendererComponentProps) {
|
||||
// Detect mixed display groups (both chat text and image generation)
|
||||
const hasChatPackets = packets.some((p) => isChatPacket(p));
|
||||
const hasImagePackets = packets.some((p) => isImageToolPacket(p));
|
||||
|
||||
if (hasChatPackets && hasImagePackets) {
|
||||
const sharedTypes = new Set<string>([
|
||||
PacketType.SECTION_END,
|
||||
PacketType.ERROR,
|
||||
]);
|
||||
|
||||
const chatPackets = packets.filter(
|
||||
(p) =>
|
||||
isChatPacket(p) ||
|
||||
p.obj.type === PacketType.CITATION_INFO ||
|
||||
sharedTypes.has(p.obj.type as string)
|
||||
);
|
||||
const imagePackets = packets.filter(
|
||||
(p) =>
|
||||
isImageToolPacket(p) ||
|
||||
p.obj.type === PacketType.IMAGE_GENERATION_TOOL_DELTA ||
|
||||
sharedTypes.has(p.obj.type as string)
|
||||
);
|
||||
|
||||
return (
|
||||
<MixedContentHandler
|
||||
chatPackets={chatPackets}
|
||||
imagePackets={imagePackets}
|
||||
chatState={chatState}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
{children}
|
||||
</MixedContentHandler>
|
||||
);
|
||||
}
|
||||
|
||||
const RendererFn = findRenderer({ packets });
|
||||
const renderType = useShortRenderer ? RenderType.HIGHLIGHT : RenderType.FULL;
|
||||
|
||||
if (!RendererFn) {
|
||||
return children([{ icon: null, status: null, content: <></> }]);
|
||||
@@ -205,7 +287,7 @@ export const RendererComponent = memo(function RendererComponent({
|
||||
state={chatState}
|
||||
onComplete={onComplete}
|
||||
animate={animate}
|
||||
renderType={renderType}
|
||||
renderType={RenderType.FULL}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
>
|
||||
|
||||
@@ -25,7 +25,6 @@ export const CollapsedStreamingContent = React.memo(
|
||||
stopReason,
|
||||
renderTypeOverride,
|
||||
}: CollapsedStreamingContentProps) {
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
const renderContentOnly = useCallback(
|
||||
(results: TimelineRendererOutput) => (
|
||||
<>
|
||||
@@ -44,7 +43,6 @@ export const CollapsedStreamingContent = React.memo(
|
||||
key={`${step.key}-compact`}
|
||||
packets={step.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={true}
|
||||
stopPacketSeen={false}
|
||||
stopReason={stopReason}
|
||||
|
||||
@@ -37,8 +37,6 @@ interface TimelineStepProps {
|
||||
isStreaming?: boolean;
|
||||
}
|
||||
|
||||
const noopCallback = () => {};
|
||||
|
||||
const TimelineStep = React.memo(function TimelineStep({
|
||||
step,
|
||||
chatState,
|
||||
@@ -104,7 +102,6 @@ const TimelineStep = React.memo(function TimelineStep({
|
||||
<TimelineRendererComponent
|
||||
packets={step.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopCallback}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
|
||||
@@ -51,8 +51,6 @@ export function ParallelTimelineTabs({
|
||||
const handleToggle = useCallback(() => setIsExpanded((prev) => !prev), []);
|
||||
const handleHeaderEnter = useCallback(() => setIsHover(true), []);
|
||||
const handleHeaderLeave = useCallback(() => setIsHover(false), []);
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
|
||||
const topSpacerVariant = isFirstTurnGroup ? "first" : "none";
|
||||
const shouldShowResults = !(!isExpanded && stopPacketSeen);
|
||||
|
||||
@@ -165,7 +163,6 @@ export function ParallelTimelineTabs({
|
||||
key={`${activeTab}-${isExpanded}`}
|
||||
packets={activeStep.packets}
|
||||
chatState={chatState}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
stopReason={stopReason}
|
||||
|
||||
@@ -34,8 +34,6 @@ export interface TimelineRendererComponentProps {
|
||||
packets: Packet[];
|
||||
/** Chat state for rendering */
|
||||
chatState: FullChatState;
|
||||
/** Completion callback */
|
||||
onComplete: () => void;
|
||||
/** Whether to animate streaming */
|
||||
animate: boolean;
|
||||
/** Whether stop packet has been seen */
|
||||
@@ -77,7 +75,6 @@ export const TimelineRendererComponent = React.memo(
|
||||
function TimelineRendererComponent({
|
||||
packets,
|
||||
chatState,
|
||||
onComplete,
|
||||
animate,
|
||||
stopPacketSeen,
|
||||
stopReason,
|
||||
@@ -125,7 +122,7 @@ export const TimelineRendererComponent = React.memo(
|
||||
<RendererFn
|
||||
packets={packets as any}
|
||||
state={chatState}
|
||||
onComplete={onComplete}
|
||||
onComplete={() => {}}
|
||||
animate={animate}
|
||||
renderType={renderType}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
|
||||
@@ -349,11 +349,11 @@ function processPacket(state: ProcessorState, packet: Packet): void {
|
||||
if (isDisplayPacket(packet)) {
|
||||
state.displayGroupKeys.add(groupKey);
|
||||
}
|
||||
}
|
||||
|
||||
// Track image generation for header display
|
||||
if (packet.obj.type === PacketType.IMAGE_GENERATION_TOOL_START) {
|
||||
state.isGeneratingImage = true;
|
||||
}
|
||||
// Track image generation for header display (regardless of group position)
|
||||
if (packet.obj.type === PacketType.IMAGE_GENERATION_TOOL_START) {
|
||||
state.isGeneratingImage = true;
|
||||
}
|
||||
|
||||
// Count generated images from DELTA packets
|
||||
|
||||
@@ -169,8 +169,6 @@ export const ResearchAgentRenderer: MessageRenderer<
|
||||
);
|
||||
|
||||
// Stable callbacks to avoid creating new functions on every render
|
||||
const noopComplete = useCallback(() => {}, []);
|
||||
|
||||
// renderReport renders the processed content
|
||||
// Uses pre-computed processedReportContent since ExpandableTextDisplay
|
||||
// passes the same fullReportContent that we processed above
|
||||
@@ -221,7 +219,6 @@ export const ResearchAgentRenderer: MessageRenderer<
|
||||
key={latestGroup.sub_turn_index}
|
||||
packets={latestGroup.packets}
|
||||
chatState={state}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen && !latestGroup.isComplete}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
defaultExpanded={false}
|
||||
@@ -327,7 +324,6 @@ export const ResearchAgentRenderer: MessageRenderer<
|
||||
key={group.sub_turn_index}
|
||||
packets={group.packets}
|
||||
chatState={state}
|
||||
onComplete={noopComplete}
|
||||
animate={!stopPacketSeen && !group.isComplete}
|
||||
stopPacketSeen={stopPacketSeen}
|
||||
defaultExpanded={true}
|
||||
|
||||
@@ -382,9 +382,8 @@ export default function NRFPage({ isSidePanel = false }: NRFPageProps) {
|
||||
<IconButton
|
||||
icon={SvgMenu}
|
||||
onClick={toggleSettings}
|
||||
tertiary
|
||||
secondary
|
||||
tooltip="Open settings"
|
||||
className="bg-mask-02 backdrop-blur-[12px] rounded-full shadow-01 hover:bg-mask-03"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
|
||||
@@ -12,7 +12,7 @@ import { SettingsContext } from "@/providers/SettingsProvider";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { Formik, Form } from "formik";
|
||||
import * as Yup from "yup";
|
||||
import { EnterpriseSettings } from "@/app/admin/settings/interfaces";
|
||||
import { EnterpriseSettings } from "@/interfaces/settings";
|
||||
import { useRouter } from "next/navigation";
|
||||
|
||||
const CHAR_LIMITS = {
|
||||
|
||||
@@ -14,10 +14,7 @@ import {
|
||||
import { Metadata } from "next";
|
||||
import { buildClientUrl } from "@/lib/utilsSS";
|
||||
import { Inter } from "next/font/google";
|
||||
import {
|
||||
EnterpriseSettings,
|
||||
ApplicationStatus,
|
||||
} from "./admin/settings/interfaces";
|
||||
import { EnterpriseSettings, ApplicationStatus } from "@/interfaces/settings";
|
||||
import AppProvider from "@/providers/AppProvider";
|
||||
import { PHProvider } from "./providers";
|
||||
import { getAuthTypeMetadataSS, getCurrentUserSS } from "@/lib/userSS";
|
||||
|
||||
26
web/src/app/nrf/(main)/layout.tsx
Normal file
26
web/src/app/nrf/(main)/layout.tsx
Normal file
@@ -0,0 +1,26 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import AppSidebar from "@/sections/sidebar/AppSidebar";
|
||||
import { getCurrentUserSS } from "@/lib/userSS";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* NRF Main (New Tab) Layout
|
||||
*
|
||||
* Shows the app sidebar when the user is authenticated.
|
||||
* This layout is NOT used by the side-panel route.
|
||||
*/
|
||||
export default async function Layout({ children }: LayoutProps) {
|
||||
noStore();
|
||||
|
||||
const user = await getCurrentUserSS();
|
||||
|
||||
return (
|
||||
<div className="flex flex-row w-full h-full">
|
||||
{user && <AppSidebar />}
|
||||
{children}
|
||||
</div>
|
||||
);
|
||||
}
|
||||
31
web/src/app/nrf/(main)/page.tsx
Normal file
31
web/src/app/nrf/(main)/page.tsx
Normal file
@@ -0,0 +1,31 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import NRFPage from "@/app/app/nrf/NRFPage";
|
||||
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
|
||||
import NRFChrome from "../NRFChrome";
|
||||
|
||||
/**
|
||||
* NRF (New Tab Page) Route - No Auth Required
|
||||
*
|
||||
* This route is placed outside /app/app/ to bypass the authentication
|
||||
* requirement in /app/app/layout.tsx. The NRFPage component handles
|
||||
* unauthenticated users gracefully by showing a login modal instead of
|
||||
* redirecting, which is better UX for the Chrome extension.
|
||||
*
|
||||
* Instead of AppLayouts.Root (which pulls in heavy Header state management),
|
||||
* we use NRFChrome — a lightweight overlay that renders only the search/chat
|
||||
* mode toggle and footer, floating transparently over NRFPage's background.
|
||||
*/
|
||||
export default async function Page() {
|
||||
noStore();
|
||||
|
||||
return (
|
||||
<div className="relative w-full h-full">
|
||||
<InstantSSRAutoRefresh />
|
||||
<NRFPreferencesProvider>
|
||||
<NRFPage />
|
||||
</NRFPreferencesProvider>
|
||||
<NRFChrome />
|
||||
</div>
|
||||
);
|
||||
}
|
||||
148
web/src/app/nrf/NRFChrome.tsx
Normal file
148
web/src/app/nrf/NRFChrome.tsx
Normal file
@@ -0,0 +1,148 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { cn, ensureHrefProtocol, noProp } from "@/lib/utils";
|
||||
import type { Components } from "react-markdown";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Popover from "@/refresh-components/Popover";
|
||||
import { OpenButton } from "@opal/components";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import { SvgBubbleText, SvgSearchMenu, SvgSidebar } from "@opal/icons";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { AppMode, useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { useQueryController } from "@/providers/QueryControllerProvider";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useAppSidebarContext } from "@/providers/AppSidebarProvider";
|
||||
import useScreenSize from "@/hooks/useScreenSize";
|
||||
|
||||
const footerMarkdownComponents = {
|
||||
p: ({ children }: { children?: React.ReactNode }) => (
|
||||
<Text as="p" text03 secondaryAction className="!my-0 text-center">
|
||||
{children}
|
||||
</Text>
|
||||
),
|
||||
a: ({
|
||||
href,
|
||||
className,
|
||||
children,
|
||||
...rest
|
||||
}: React.AnchorHTMLAttributes<HTMLAnchorElement>) => {
|
||||
const fullHref = ensureHrefProtocol(href);
|
||||
return (
|
||||
<a
|
||||
href={fullHref}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
{...rest}
|
||||
className={cn(className, "underline underline-offset-2")}
|
||||
>
|
||||
<Text text03 secondaryAction>
|
||||
{children}
|
||||
</Text>
|
||||
</a>
|
||||
);
|
||||
},
|
||||
} satisfies Partial<Components>;
|
||||
|
||||
/**
|
||||
* Lightweight chrome overlay for the NRF page.
|
||||
*
|
||||
* Renders only the search/chat mode toggle (top-left) and footer (bottom),
|
||||
* absolutely positioned so they float transparently over NRFPage's own
|
||||
* background. This avoids pulling in the full AppLayouts.Root Header which
|
||||
* carries heavy state management (share/delete/move modals) that the
|
||||
* extension doesn't need.
|
||||
*/
|
||||
export default function NRFChrome() {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const settings = useSettingsContext();
|
||||
const { isMobile } = useScreenSize();
|
||||
const { setFolded } = useAppSidebarContext();
|
||||
const appFocus = useAppFocus();
|
||||
const { classification } = useQueryController();
|
||||
const [modePopoverOpen, setModePopoverOpen] = useState(false);
|
||||
|
||||
const effectiveMode: AppMode = appFocus.isNewSession() ? appMode : "chat";
|
||||
|
||||
const customFooterContent =
|
||||
settings?.enterpriseSettings?.custom_lower_disclaimer_content ||
|
||||
`[Onyx ${
|
||||
settings?.webVersion || "dev"
|
||||
}](https://www.onyx.app/) - Open Source AI Platform`;
|
||||
|
||||
const showModeToggle =
|
||||
isPaidEnterpriseFeaturesEnabled &&
|
||||
appFocus.isNewSession() &&
|
||||
!classification;
|
||||
|
||||
const showHeader = isMobile || showModeToggle;
|
||||
|
||||
return (
|
||||
<>
|
||||
{/* Header chrome — top-left, mirrors position of settings button at top-right */}
|
||||
{showHeader && (
|
||||
<div className="absolute top-0 left-0 p-4 z-10 flex flex-row items-center gap-2">
|
||||
{isMobile && (
|
||||
<IconButton
|
||||
icon={SvgSidebar}
|
||||
onClick={() => setFolded(false)}
|
||||
internal
|
||||
/>
|
||||
)}
|
||||
{showModeToggle && (
|
||||
<Popover open={modePopoverOpen} onOpenChange={setModePopoverOpen}>
|
||||
<Popover.Trigger asChild>
|
||||
<OpenButton
|
||||
icon={
|
||||
effectiveMode === "search" ? SvgSearchMenu : SvgBubbleText
|
||||
}
|
||||
>
|
||||
{effectiveMode === "search" ? "Search" : "Chat"}
|
||||
</OpenButton>
|
||||
</Popover.Trigger>
|
||||
<Popover.Content align="start" width="lg">
|
||||
<Popover.Menu>
|
||||
<LineItem
|
||||
icon={SvgSearchMenu}
|
||||
selected={effectiveMode === "search"}
|
||||
description="Quick search for documents"
|
||||
onClick={noProp(() => {
|
||||
setAppMode("search");
|
||||
setModePopoverOpen(false);
|
||||
})}
|
||||
>
|
||||
Search
|
||||
</LineItem>
|
||||
<LineItem
|
||||
icon={SvgBubbleText}
|
||||
selected={effectiveMode === "chat"}
|
||||
description="Conversation and research"
|
||||
onClick={noProp(() => {
|
||||
setAppMode("chat");
|
||||
setModePopoverOpen(false);
|
||||
})}
|
||||
>
|
||||
Chat
|
||||
</LineItem>
|
||||
</Popover.Menu>
|
||||
</Popover.Content>
|
||||
</Popover>
|
||||
)}
|
||||
</div>
|
||||
)}
|
||||
|
||||
{/* Footer — bottom-center, transparent background */}
|
||||
<footer className="absolute bottom-0 left-0 w-full z-10 flex flex-row justify-center items-center gap-2 px-2 pb-2 pointer-events-auto">
|
||||
<MinimalMarkdown
|
||||
content={customFooterContent}
|
||||
className="max-w-full text-center"
|
||||
components={footerMarkdownComponents}
|
||||
/>
|
||||
</footer>
|
||||
</>
|
||||
);
|
||||
}
|
||||
15
web/src/app/nrf/layout.tsx
Normal file
15
web/src/app/nrf/layout.tsx
Normal file
@@ -0,0 +1,15 @@
|
||||
import { ProjectsProvider } from "@/providers/ProjectsContext";
|
||||
|
||||
export interface LayoutProps {
|
||||
children: React.ReactNode;
|
||||
}
|
||||
|
||||
/**
|
||||
* NRF Root Layout - Shared by all NRF routes
|
||||
*
|
||||
* Provides ProjectsProvider (needed by NRFPage) without auth redirect.
|
||||
* Sidebar and chrome are handled by sub-layouts / individual pages.
|
||||
*/
|
||||
export default function Layout({ children }: LayoutProps) {
|
||||
return <ProjectsProvider>{children}</ProjectsProvider>;
|
||||
}
|
||||
24
web/src/app/nrf/side-panel/page.tsx
Normal file
24
web/src/app/nrf/side-panel/page.tsx
Normal file
@@ -0,0 +1,24 @@
|
||||
import { unstable_noStore as noStore } from "next/cache";
|
||||
import { InstantSSRAutoRefresh } from "@/components/SSRAutoRefresh";
|
||||
import NRFPage from "@/app/app/nrf/NRFPage";
|
||||
import { NRFPreferencesProvider } from "@/components/context/NRFPreferencesContext";
|
||||
|
||||
/**
|
||||
* NRF Side Panel Route - No Auth Required
|
||||
*
|
||||
* Side panel variant — no NRFChrome overlay needed since the side panel
|
||||
* has its own header (logo + "Open in Onyx" button) and doesn't show
|
||||
* the mode toggle or footer.
|
||||
*/
|
||||
export default async function Page() {
|
||||
noStore();
|
||||
|
||||
return (
|
||||
<>
|
||||
<InstantSSRAutoRefresh />
|
||||
<NRFPreferencesProvider>
|
||||
<NRFPage isSidePanel />
|
||||
</NRFPreferencesProvider>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -3,8 +3,9 @@
|
||||
import AdminSidebar from "@/sections/sidebar/AdminSidebar";
|
||||
import { usePathname } from "next/navigation";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import { ApplicationStatus } from "@/app/admin/settings/interfaces";
|
||||
import { ApplicationStatus } from "@/interfaces/settings";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export interface ClientLayoutProps {
|
||||
children: React.ReactNode;
|
||||
@@ -12,6 +13,22 @@ export interface ClientLayoutProps {
|
||||
enableCloud: boolean;
|
||||
}
|
||||
|
||||
// TODO (@raunakab): Migrate ALL admin pages to use SettingsLayouts from
|
||||
// `@/layouts/settings-layouts`. Once every page manages its own layout,
|
||||
// the `py-10 px-4 md:px-12` padding below can be removed entirely and
|
||||
// this prefix list can be deleted.
|
||||
const SETTINGS_LAYOUT_PREFIXES = [
|
||||
"/admin/configuration/chat-preferences",
|
||||
"/admin/configuration/image-generation",
|
||||
"/admin/configuration/web-search",
|
||||
"/admin/actions/mcp",
|
||||
"/admin/actions/open-api",
|
||||
"/admin/billing",
|
||||
"/admin/document-index-migration",
|
||||
"/admin/discord-bot",
|
||||
"/admin/theme",
|
||||
];
|
||||
|
||||
export function ClientLayout({
|
||||
children,
|
||||
enableEnterprise,
|
||||
@@ -26,6 +43,11 @@ export function ClientLayout({
|
||||
pathname.startsWith("/admin/connectors") ||
|
||||
pathname.startsWith("/admin/embeddings");
|
||||
|
||||
// Pages using SettingsLayouts handle their own padding/centering.
|
||||
const hasOwnLayout = SETTINGS_LAYOUT_PREFIXES.some((prefix) =>
|
||||
pathname.startsWith(prefix)
|
||||
);
|
||||
|
||||
return (
|
||||
<div className="h-screen w-screen flex overflow-hidden">
|
||||
{settings.settings.application_status ===
|
||||
@@ -49,7 +71,12 @@ export function ClientLayout({
|
||||
enableCloudSS={enableCloud}
|
||||
enableEnterpriseSS={enableEnterprise}
|
||||
/>
|
||||
<div className="flex flex-1 flex-col min-w-0 min-h-0 overflow-y-auto py-10 px-4 md:px-12">
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-1 flex-col min-w-0 min-h-0 overflow-y-auto",
|
||||
!hasOwnLayout && "py-10 px-4 md:px-12"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -1,228 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { memo } from "react";
|
||||
import { FastField, useFormikContext } from "formik";
|
||||
import { TextFormField } from "@/components/Field";
|
||||
import { FiChevronRight, FiChevronDown } from "react-icons/fi";
|
||||
import Checkbox from "@/refresh-components/inputs/Checkbox";
|
||||
import { DOCS_ADMINS_PATH } from "@/lib/constants";
|
||||
|
||||
const MAX_DESCRIPTION_LENGTH = 600;
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
// Isolated Name Field that only re-renders when its value changes
|
||||
export const NameField = memo(function NameField() {
|
||||
return (
|
||||
<FastField name="name">
|
||||
{({ field }: any) => (
|
||||
<TextFormField
|
||||
{...field}
|
||||
maxWidth="max-w-lg"
|
||||
name="name"
|
||||
label="Name"
|
||||
placeholder="Email Assistant"
|
||||
aria-label="assistant-name-input"
|
||||
className="[&_input]:placeholder:text-text-muted/50"
|
||||
/>
|
||||
)}
|
||||
</FastField>
|
||||
);
|
||||
});
|
||||
|
||||
// Isolated Description Field
|
||||
export const DescriptionField = memo(function DescriptionField() {
|
||||
return (
|
||||
<FastField name="description">
|
||||
{({ field }: any) => (
|
||||
<TextFormField
|
||||
{...field}
|
||||
maxWidth="max-w-lg"
|
||||
name="description"
|
||||
label="Description"
|
||||
placeholder="Use this Assistant to help draft professional emails"
|
||||
className="[&_input]:placeholder:text-text-muted/50"
|
||||
/>
|
||||
)}
|
||||
</FastField>
|
||||
);
|
||||
});
|
||||
|
||||
// Isolated System Prompt Field
|
||||
export const SystemPromptField = memo(function SystemPromptField() {
|
||||
return (
|
||||
<FastField name="system_prompt">
|
||||
{({ field }: any) => (
|
||||
<TextFormField
|
||||
{...field}
|
||||
maxWidth="max-w-4xl"
|
||||
name="system_prompt"
|
||||
label="Instructions"
|
||||
isTextArea={true}
|
||||
placeholder="You are a professional email writing assistant that always uses a polite enthusiastic tone, emphasizes action items, and leaves blanks for the human to fill in when you have unknowns"
|
||||
data-testid="assistant-instructions-input"
|
||||
className="[&_textarea]:placeholder:text-text-muted/50"
|
||||
/>
|
||||
)}
|
||||
</FastField>
|
||||
);
|
||||
});
|
||||
|
||||
// Isolated Task Prompt Field
|
||||
export const TaskPromptField = memo(function TaskPromptField() {
|
||||
return (
|
||||
<FastField name="task_prompt">
|
||||
{({ field, form }: any) => (
|
||||
<TextFormField
|
||||
{...field}
|
||||
maxWidth="max-w-4xl"
|
||||
name="task_prompt"
|
||||
label="[Optional] Reminders"
|
||||
isTextArea={true}
|
||||
placeholder="Remember to reference all of the points mentioned in my message to you and focus on identifying action items that can move things forward"
|
||||
onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => {
|
||||
form.setFieldValue("task_prompt", e.target.value);
|
||||
}}
|
||||
explanationText="Learn about prompting in our docs!"
|
||||
explanationLink={`${DOCS_ADMINS_PATH}/agents/overview`}
|
||||
className="[&_textarea]:placeholder:text-text-muted/50"
|
||||
/>
|
||||
)}
|
||||
</FastField>
|
||||
);
|
||||
});
|
||||
|
||||
/**
 * Memoized, collapsible card listing the tools of one MCP server.
 *
 * The header row holds a collapse toggle, a select-all checkbox and the
 * server's name, URL and tool count. The select-all state is derived here
 * from `enabled_tools_map` in the Formik values: unchecked when no tool is
 * enabled, checked when all are, indeterminate otherwise. When the section
 * is expanded, each tool is rendered as its own checkbox bound to
 * `enabled_tools_map.<tool.id>`. Descriptions longer than
 * `MAX_DESCRIPTION_LENGTH` are truncated behind an "Expand" / "Show less"
 * toggle.
 */
export const MCPServerSection = memo(function MCPServerSection({
  serverId,
  serverTools,
  serverName,
  serverUrl,
  isCollapsed,
  onToggleCollapse,
  onToggleServerTools,
}: {
  serverId: number;
  serverTools: any[];
  serverName: string;
  serverUrl: string;
  isCollapsed: boolean;
  onToggleCollapse: (serverId: number) => void;
  onToggleServerTools: () => void;
}) {
  const { values } = useFormikContext<any>();
  const [expandedToolDescriptions, setExpandedToolDescriptions] = useState<
    Record<number, boolean>
  >({});

  // Count this server's enabled tools to derive the select-all checkbox state.
  const totalTools = serverTools.length;
  const enabledCount = serverTools.reduce(
    (count, tool) => (values.enabled_tools_map[tool.id] ? count + 1 : count),
    0
  );
  const allEnabled = enabledCount !== 0 && enabledCount === totalTools;
  const someEnabled = enabledCount !== 0 && enabledCount !== totalTools;

  // Flip the expanded/collapsed flag of a single tool's description.
  const toggleDescription = (toolId: number) =>
    setExpandedToolDescriptions((prev: Record<number, boolean>) => ({
      ...prev,
      [toolId]: !prev[toolId],
    }));

  const chevronClassName = "w-4 h-4 text-gray-600 dark:text-gray-400";

  return (
    <div
      className="border rounded-lg p-4 space-y-3 dark:border-gray-700"
      data-testid={`mcp-server-section-${serverId}`}
    >
      <div className="flex items-center space-x-3">
        <button
          type="button"
          onClick={() => onToggleCollapse(serverId)}
          className="flex-shrink-0 p-1 hover:bg-gray-100 dark:hover:bg-gray-700 rounded transition-colors"
          data-testid={`mcp-server-toggle-${serverId}`}
          aria-expanded={!isCollapsed}
        >
          {isCollapsed ? (
            <FiChevronRight className={chevronClassName} />
          ) : (
            <FiChevronDown className={chevronClassName} />
          )}
        </button>
        <Checkbox
          checked={allEnabled}
          indeterminate={someEnabled}
          onCheckedChange={onToggleServerTools}
          aria-label="mcp-server-select-all-tools-checkbox"
        />
        <div className="flex-grow">
          <div className="font-medium text-sm text-gray-900 dark:text-gray-100">
            {serverName}
          </div>
          <div className="text-xs text-gray-600 dark:text-gray-400">
            {serverUrl} ({serverTools.length} tools)
          </div>
        </div>
      </div>
      {!isCollapsed && (
        <div className="ml-7 space-y-2">
          {serverTools.map((tool) => {
            const isExpanded = expandedToolDescriptions[tool.id];
            const isLongDescription =
              tool.description &&
              tool.description.length > MAX_DESCRIPTION_LENGTH;

            return (
              // The key embeds the expanded state, so toggling a description
              // gives the FastField a new key.
              <FastField
                key={`${tool.id}-${isExpanded ? "expanded" : "collapsed"}`}
                name={`enabled_tools_map.${tool.id}`}
              >
                {({ field, form }: any) => (
                  <label className="flex items-center space-x-2">
                    <div className="pt-0.5">
                      <Checkbox
                        checked={field.value || false}
                        onCheckedChange={(checked) => {
                          form.setFieldValue(field.name, checked);
                        }}
                        aria-label={`mcp-server-tool-checkbox-${tool.display_name}`}
                      />
                    </div>
                    <div>
                      <div className="text-sm font-medium">
                        {tool.display_name}
                      </div>
                      <div className="text-xs text-gray-600">
                        {isLongDescription ? (
                          <>
                            {isExpanded
                              ? tool.description
                              : `${tool.description.slice(
                                  0,
                                  MAX_DESCRIPTION_LENGTH
                                )}... `}
                            <button
                              type="button"
                              className="ml-1 text-blue-500 underline text-xs focus:outline-none"
                              onClick={() => toggleDescription(tool.id)}
                              tabIndex={0}
                            >
                              {isExpanded ? "Show less" : "Expand"}
                            </button>
                          </>
                        ) : (
                          tool.description
                        )}
                      </div>
                    </div>
                  </label>
                )}
              </FastField>
            );
          })}
        </div>
      )}
    </div>
  );
});
|
||||
@@ -1,76 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { memo } from "react";
|
||||
import { BooleanFormField } from "@/components/Field";
|
||||
import { ToolSnapshot } from "@/lib/tools/interfaces";
|
||||
import { FastField } from "formik";
|
||||
const MAX_DESCRIPTION_LENGTH = 300;
|
||||
|
||||
/**
 * Memoized checkbox row for a single tool.
 *
 * Renders a `BooleanFormField` inside a `FastField`, both registered under
 * `enabled_tools_map.<toolId>`, with the tool's display name as the label
 * and its description as the subtext.
 */
const MemoizedToolCheckbox = memo(function MemoizedToolCheckbox({
  toolId,
  displayName,
  description,
}: {
  toolId: number;
  displayName: string;
  description: string;
}) {
  // FastField and the inner form field share the same form path.
  const fieldName = `enabled_tools_map.${toolId}`;

  return (
    <FastField name={fieldName}>
      {() => (
        <BooleanFormField
          name={fieldName}
          label={displayName}
          subtext={description}
        />
      )}
    </FastField>
  );
});
|
||||
|
||||
/**
 * Memoized list of tool checkboxes.
 *
 * Renders one `MemoizedToolCheckbox` per tool. Descriptions longer than
 * `MAX_DESCRIPTION_LENGTH` are cut to that length and suffixed with an
 * ellipsis; shorter or missing descriptions are passed through unchanged.
 */
export const MemoizedToolList = memo(function MemoizedToolList({
  tools,
}: {
  tools: ToolSnapshot[];
}) {
  const checkboxes = tools.map((tool) => {
    const { description } = tool;
    const shownDescription =
      description && description.length > MAX_DESCRIPTION_LENGTH
        ? description.slice(0, MAX_DESCRIPTION_LENGTH) + "…"
        : description;

    return (
      <MemoizedToolCheckbox
        key={tool.id}
        toolId={tool.id}
        displayName={tool.display_name}
        description={shownDescription}
      />
    );
  });

  return <>{checkboxes}</>;
});
|
||||
|
||||
/**
 * Memoized, indented list of checkboxes for the tools of one MCP server.
 *
 * Each tool is rendered as a `MemoizedToolCheckbox` with its description
 * passed through untruncated. `serverId` is part of the props but is not
 * read by this component.
 */
export const MemoizedMCPServerTools = memo(function MemoizedMCPServerTools({
  serverId,
  serverTools,
}: {
  serverId: number;
  serverTools: ToolSnapshot[];
}) {
  const toolCheckboxes = serverTools.map(
    ({ id, display_name, description }) => (
      <MemoizedToolCheckbox
        key={id}
        toolId={id}
        displayName={display_name}
        description={description}
      />
    )
  );

  return <div className="ml-7 space-y-2">{toolCheckboxes}</div>;
});
|
||||
@@ -1,304 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, {
|
||||
useMemo,
|
||||
useCallback,
|
||||
useState,
|
||||
useRef,
|
||||
useEffect,
|
||||
} from "react";
|
||||
import { BooleanFormField } from "@/components/Field";
|
||||
import { ToolSnapshot, MCPServer } from "@/lib/tools/interfaces";
|
||||
import { MCPServerSection } from "./FormSections";
|
||||
import { MemoizedToolList } from "./MemoizedToolCheckboxes";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SEARCH_TOOL_ID,
|
||||
WEB_SEARCH_TOOL_ID,
|
||||
IMAGE_GENERATION_TOOL_ID,
|
||||
PYTHON_TOOL_ID,
|
||||
OPEN_URL_TOOL_ID,
|
||||
FILE_READER_TOOL_ID,
|
||||
} from "@/app/app/components/tools/constants";
|
||||
import { HoverPopup } from "@/components/HoverPopup";
|
||||
import { Info } from "lucide-react";
|
||||
|
||||
/** Props accepted by `ToolSelector`. */
interface ToolSelectorProps {
  /** All tools to offer; built-in, OpenAPI and MCP tools are separated inside the component. */
  tools: ToolSnapshot[];
  /** Known MCP servers, used to look up a display name and URL by server id. */
  mcpServers?: MCPServer[];
  /** Current enabled state keyed by tool id. */
  enabledToolsMap: { [key: number]: boolean };
  /**
   * Form setter used for the Web Search → Open URL coupling and the per-server
   * select-all toggle. When omitted, the Open URL row is not rendered and the
   * select-all toggle does nothing.
   */
  setFieldValue?: (field: string, value: any) => void;
  /** Disables the Image Generation checkbox. */
  imageGenerationDisabled?: boolean;
  /** Tooltip passed to the Image Generation checkbox as its disabled tooltip. */
  imageGenerationDisabledTooltip?: string;
  /** Disables the Internal Search checkbox. */
  searchToolDisabled?: boolean;
  /** Tooltip passed to the Internal Search checkbox as its disabled tooltip. */
  searchToolDisabledTooltip?: string;
  /** Hides the Internal Search checkbox entirely. */
  hideSearchTool?: boolean;
}

/**
 * Form section for choosing which tools ("actions") are enabled.
 *
 * Renders three groups:
 * - Built-in Actions: one checkbox per built-in tool found in `tools`.
 * - OpenAPI Actions: remaining tools without an `mcp_server_id`.
 * - MCP Actions: remaining tools grouped by `mcp_server_id`, one collapsible
 *   `MCPServerSection` per server.
 *
 * Every checkbox is bound to `enabled_tools_map.<tool.id>`.
 */
export function ToolSelector({
  tools,
  mcpServers = [],
  enabledToolsMap,
  setFieldValue,
  imageGenerationDisabled = false,
  imageGenerationDisabledTooltip,
  searchToolDisabled = false,
  searchToolDisabledTooltip,
  hideSearchTool = false,
}: ToolSelectorProps) {
  const searchTool = tools.find((t) => t.in_code_tool_id === SEARCH_TOOL_ID);
  const webSearchTool = tools.find(
    (t) => t.in_code_tool_id === WEB_SEARCH_TOOL_ID
  );
  const imageGenerationTool = tools.find(
    (t) => t.in_code_tool_id === IMAGE_GENERATION_TOOL_ID
  );
  const pythonTool = tools.find((t) => t.in_code_tool_id === PYTHON_TOOL_ID);
  const openUrlTool = tools.find((t) => t.in_code_tool_id === OPEN_URL_TOOL_ID);
  const fileReaderTool = tools.find(
    (t) => t.in_code_tool_id === FILE_READER_TOOL_ID
  );

  // Check if Web Search is enabled - if so, OpenURL must be enabled
  const isWebSearchEnabled = webSearchTool && enabledToolsMap[webSearchTool.id];
  const isOpenUrlForced = isWebSearchEnabled;

  // Split the non-built-in tools into OpenAPI tools and MCP tools grouped by server id.
  const { customTools, mcpToolsByServer } = useMemo(() => {
    const allCustom = tools.filter(
      (tool) =>
        tool.in_code_tool_id !== SEARCH_TOOL_ID &&
        tool.in_code_tool_id !== IMAGE_GENERATION_TOOL_ID &&
        tool.in_code_tool_id !== WEB_SEARCH_TOOL_ID &&
        tool.in_code_tool_id !== PYTHON_TOOL_ID &&
        tool.in_code_tool_id !== OPEN_URL_TOOL_ID &&
        tool.in_code_tool_id !== FILE_READER_TOOL_ID
    );

    const custom = allCustom.filter((tool) => !tool.mcp_server_id);

    const groups: { [serverId: number]: ToolSnapshot[] } = {};
    allCustom.forEach((tool) => {
      if (tool.mcp_server_id) {
        if (!groups[tool.mcp_server_id]) {
          groups[tool.mcp_server_id] = [];
        }
        groups[tool.mcp_server_id]!.push(tool);
      }
    });

    return { customTools: custom, mcpToolsByServer: groups };
  }, [tools]);

  // Every server present on first render starts collapsed.
  const [collapsedServers, setCollapsedServers] = useState<Set<number>>(
    () => new Set(Object.keys(mcpToolsByServer).map((id) => parseInt(id, 10)))
  );

  // Server ids that have already been given their initial collapsed state.
  const seenServerIdsRef = useRef<Set<number>>(
    new Set(Object.keys(mcpToolsByServer).map((id) => parseInt(id, 10)))
  );

  // Servers that appear later are collapsed once, when first seen; servers
  // already seen keep whatever state they currently have.
  useEffect(() => {
    const serverIds = Object.keys(mcpToolsByServer).map((id) =>
      parseInt(id, 10)
    );
    const unseenIds = serverIds.filter(
      (id) => !seenServerIdsRef.current.has(id)
    );

    if (unseenIds.length === 0) return;

    const updatedSeen = new Set(seenServerIdsRef.current);
    unseenIds.forEach((id) => updatedSeen.add(id));
    seenServerIdsRef.current = updatedSeen;

    setCollapsedServers((prev) => {
      const next = new Set(prev);
      unseenIds.forEach((id) => next.add(id));
      return next;
    });
  }, [mcpToolsByServer]);

  // Flip one server between collapsed and expanded.
  const toggleServerCollapse = useCallback((serverId: number) => {
    setCollapsedServers((prev) => {
      const next = new Set(prev);
      if (next.has(serverId)) {
        next.delete(serverId);
      } else {
        next.add(serverId);
      }
      return next;
    });
  }, []);

  // Select-all for one server: if every tool is already enabled, disable them
  // all; otherwise enable them all.
  const toggleMCPServerTools = useCallback(
    (serverId: number) => {
      if (!setFieldValue) return;

      const serverTools = mcpToolsByServer[serverId] || [];
      const enabledCount = serverTools.filter(
        (tool) => enabledToolsMap[tool.id]
      ).length;
      const shouldEnable = enabledCount !== serverTools.length;

      const updatedMap = { ...enabledToolsMap };
      serverTools.forEach((tool) => {
        updatedMap[tool.id] = shouldEnable;
      });

      setFieldValue("enabled_tools_map", updatedMap);
    },
    [mcpToolsByServer, enabledToolsMap, setFieldValue]
  );

  return (
    <div className="space-y-2">
      <div className="flex items-center gap-1.5 mb-2">
        <Text as="p" mainUiBody text04>
          Built-in Actions
        </Text>
        <HoverPopup
          mainContent={
            <Info className="h-3.5 w-3.5 text-text-400 cursor-help" />
          }
          popupContent={
            <div className="text-xs space-y-2 max-w-xs bg-background-neutral-dark-03 text-text-light-05">
              <div>
                <span className="font-semibold">Internal Search:</span> Requires
                at least one connector to be configured to search your
                organization's knowledge base.
              </div>
              <div>
                <span className="font-semibold">Web Search:</span> Configure a
                provider on the Web Search admin page to enable this tool.
              </div>
              <div>
                <span className="font-semibold">Image Generation:</span> Add an
                OpenAI LLM provider with an API key under Admin → Configuration
                → LLM.
              </div>
              <div>
                <span className="font-semibold">Code Interpreter:</span>{" "}
                Requires the Code Interpreter service to be configured with a
                valid base URL.
              </div>
              <div>
                <span className="font-semibold">Open URL:</span> Open and read
                the content of URLs provided in the conversation.
              </div>
            </div>
          }
          direction="bottom"
        />
      </div>
      {!hideSearchTool && searchTool && (
        <BooleanFormField
          name={`enabled_tools_map.${searchTool.id}`}
          label={searchTool.display_name}
          subtext="Search through your organization's knowledge base and documents"
          disabled={searchToolDisabled}
          disabledTooltip={searchToolDisabledTooltip}
          disabledTooltipSide="bottom"
        />
      )}

      {webSearchTool && (
        <BooleanFormField
          name={`enabled_tools_map.${webSearchTool.id}`}
          label={webSearchTool.display_name}
          subtext="Access real-time information and search the web for up-to-date results"
          onChange={(checked) => {
            // When enabling Web Search, also enable OpenURL
            if (checked && openUrlTool && setFieldValue) {
              setFieldValue(`enabled_tools_map.${openUrlTool.id}`, true);
            }
          }}
        />
      )}

      {openUrlTool && setFieldValue && (
        <BooleanFormField
          name={`enabled_tools_map.${openUrlTool.id}`}
          label="Open URL"
          subtext="Open and read the content of URLs provided in the conversation"
          disabled={isOpenUrlForced}
          disabledTooltip="Required for Web Search"
          disabledTooltipSide="bottom"
        />
      )}

      {imageGenerationTool && (
        <BooleanFormField
          name={`enabled_tools_map.${imageGenerationTool.id}`}
          label={imageGenerationTool.display_name}
          subtext="Generate and manipulate images using AI-powered tools."
          disabled={imageGenerationDisabled}
          disabledTooltip={imageGenerationDisabledTooltip}
          disabledTooltipSide="bottom"
        />
      )}

      {pythonTool && (
        <BooleanFormField
          name={`enabled_tools_map.${pythonTool.id}`}
          label={pythonTool.display_name}
          subtext={
            "Execute Python code in a secure, isolated environment to " +
            "analyze data, create visualizations, and perform computations"
          }
        />
      )}

      {fileReaderTool && (
        <BooleanFormField
          name={`enabled_tools_map.${fileReaderTool.id}`}
          label={fileReaderTool.display_name}
          subtext="Read sections of uploaded files. Required for files that exceed the context window."
        />
      )}

      {customTools.length > 0 && (
        <>
          <Text as="p" mainUiBody text04 className="mb-2">
            OpenAPI Actions
          </Text>
          <MemoizedToolList tools={customTools} />
        </>
      )}

      {Object.keys(mcpToolsByServer).length > 0 && (
        <>
          <Text as="p" mainUiBody text04 className="mb-2">
            MCP Actions
          </Text>
          {Object.entries(mcpToolsByServer).map(([serverId, serverTools]) => {
            // Explicit radix, matching the other server-id parses above.
            const serverIdNum = parseInt(serverId, 10);
            const serverInfo =
              mcpServers.find((server) => server.id === serverIdNum) || null;
            const isCollapsed = collapsedServers.has(serverIdNum);

            // Name fallback: the first tool's name minus its last
            // underscore-separated segment, then a generic label.
            const firstTool = serverTools[0];
            const serverName =
              serverInfo?.name ||
              firstTool?.name?.split("_").slice(0, -1).join("_") ||
              `MCP Server ${serverId}`;
            const serverUrl = serverInfo?.server_url || "Unknown URL";

            return (
              <MCPServerSection
                key={`mcp-server-${serverId}`}
                serverId={serverIdNum}
                serverTools={serverTools}
                serverName={serverName}
                serverUrl={serverUrl}
                isCollapsed={isCollapsed}
                onToggleCollapse={toggleServerCollapse}
                onToggleServerTools={() => toggleMCPServerTools(serverIdNum)}
              />
            );
          })}
        </>
      )}
    </div>
  );
}
|
||||
@@ -4,7 +4,7 @@ import {
|
||||
ApplicationStatus,
|
||||
Settings,
|
||||
QueryHistoryType,
|
||||
} from "@/app/admin/settings/interfaces";
|
||||
} from "@/interfaces/settings";
|
||||
import {
|
||||
CUSTOM_ANALYTICS_ENABLED,
|
||||
HOST_URL,
|
||||
@@ -126,6 +126,9 @@ export async function fetchSettingsSS(): Promise<CombinedSettings | null> {
|
||||
customAnalyticsScript,
|
||||
webVersion,
|
||||
webDomain: HOST_URL,
|
||||
// Server-side default; the real value is computed client-side in
|
||||
// SettingsProvider where connector data is available via useCCPairs.
|
||||
isSearchModeAvailable: settings.search_ui_enabled !== false,
|
||||
};
|
||||
|
||||
return combinedSettings;
|
||||
|
||||
@@ -49,7 +49,7 @@ export async function searchDocuments(
|
||||
const request: SendSearchQueryRequest = {
|
||||
search_query: query,
|
||||
filters: options?.filters,
|
||||
num_hits: options?.numHits ?? 50,
|
||||
num_hits: options?.numHits ?? 30,
|
||||
include_content: options?.includeContent ?? false,
|
||||
stream: false,
|
||||
};
|
||||
|
||||
@@ -4,6 +4,7 @@ import React, { useState, useCallback } from "react";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { AppModeContext, AppMode } from "@/providers/AppModeProvider";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
|
||||
export interface AppModeProviderProps {
|
||||
children: React.ReactNode;
|
||||
@@ -17,14 +18,17 @@ export interface AppModeProviderProps {
|
||||
* - **chat**: Forces chat mode - conversation with follow-up questions
|
||||
*
|
||||
* The initial mode is read from the user's persisted `default_app_mode` preference.
|
||||
* When search mode is unavailable (admin setting or no connectors), the mode is locked to "chat".
|
||||
*/
|
||||
export function AppModeProvider({ children }: AppModeProviderProps) {
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const { user } = useUser();
|
||||
const settings = useSettingsContext();
|
||||
const { isSearchModeAvailable } = settings;
|
||||
|
||||
const persistedMode = user?.preferences?.default_app_mode;
|
||||
const initialMode: AppMode =
|
||||
isPaidEnterpriseFeaturesEnabled && persistedMode
|
||||
isPaidEnterpriseFeaturesEnabled && isSearchModeAvailable && persistedMode
|
||||
? (persistedMode.toLowerCase() as AppMode)
|
||||
: "chat";
|
||||
|
||||
@@ -32,10 +36,10 @@ export function AppModeProvider({ children }: AppModeProviderProps) {
|
||||
|
||||
const setAppMode = useCallback(
|
||||
(mode: AppMode) => {
|
||||
if (!isPaidEnterpriseFeaturesEnabled) return;
|
||||
if (!isPaidEnterpriseFeaturesEnabled || !isSearchModeAvailable) return;
|
||||
setAppModeState(mode);
|
||||
},
|
||||
[isPaidEnterpriseFeaturesEnabled]
|
||||
[isPaidEnterpriseFeaturesEnabled, isSearchModeAvailable]
|
||||
);
|
||||
|
||||
return (
|
||||
|
||||
@@ -11,6 +11,7 @@ import { classifyQuery, searchDocuments } from "@/ee/lib/search/svc";
|
||||
import { useAppMode } from "@/providers/AppModeProvider";
|
||||
import useAppFocus from "@/hooks/useAppFocus";
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import { useSettingsContext } from "@/providers/SettingsProvider";
|
||||
import {
|
||||
QueryControllerContext,
|
||||
QueryClassification,
|
||||
@@ -27,6 +28,8 @@ export function QueryControllerProvider({
|
||||
const { appMode, setAppMode } = useAppMode();
|
||||
const appFocus = useAppFocus();
|
||||
const isPaidEnterpriseFeaturesEnabled = usePaidEnterpriseFeaturesEnabled();
|
||||
const settings = useSettingsContext();
|
||||
const { isSearchModeAvailable: searchUiEnabled } = settings;
|
||||
|
||||
// Query state
|
||||
const [query, setQuery] = useState<string | null>(null);
|
||||
@@ -64,7 +67,7 @@ export function QueryControllerProvider({
|
||||
searchQuery,
|
||||
{
|
||||
filters,
|
||||
numHits: 50,
|
||||
numHits: 30,
|
||||
includeContent: false,
|
||||
signal: controller.signal,
|
||||
}
|
||||
@@ -149,13 +152,17 @@ export function QueryControllerProvider({
|
||||
// We always route through chat if we're not Enterprise Enabled.
|
||||
//
|
||||
// 2.
|
||||
// We always route through chat if the admin has disabled the Search UI.
|
||||
//
|
||||
// 3.
|
||||
// We only go down the classification route if we're in the "New Session" tab.
|
||||
// Everywhere else, we always use the chat-flow.
|
||||
//
|
||||
// 3.
|
||||
// 4.
|
||||
// If we're in the "New Session" tab and the app-mode is "Chat", we continue with the chat-flow anyways.
|
||||
if (
|
||||
!isPaidEnterpriseFeaturesEnabled ||
|
||||
!searchUiEnabled ||
|
||||
!appFocus.isNewSession() ||
|
||||
appMode === "chat"
|
||||
) {
|
||||
@@ -218,6 +225,7 @@ export function QueryControllerProvider({
|
||||
performClassification,
|
||||
performSearch,
|
||||
isPaidEnterpriseFeaturesEnabled,
|
||||
searchUiEnabled,
|
||||
]
|
||||
);
|
||||
|
||||
|
||||
25
web/src/hooks/useCloudSubscription.ts
Normal file
25
web/src/hooks/useCloudSubscription.ts
Normal file
@@ -0,0 +1,25 @@
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
import { hasPaidSubscription } from "@/lib/billing/interfaces";
|
||||
import { useBillingInformation } from "@/hooks/useBillingInformation";
|
||||
|
||||
/**
 * Reports whether the current tenant should be treated as having a paid
 * subscription.
 *
 * - When `NEXT_PUBLIC_CLOUD_ENABLED` is falsy, the result is always `true`
 *   (no billing gate).
 * - While billing information is loading, or when none is available, the
 *   result is `true` so the upgrade prompt does not flash.
 * - Otherwise the result is `hasPaidSubscription(billingData)`.
 */
export function useCloudSubscription(): boolean {
  // The hook is called unconditionally, before any early return.
  const { data: billingData, isLoading } = useBillingInformation();

  const billingGateInactive = !NEXT_PUBLIC_CLOUD_ENABLED;
  const billingNotReady = isLoading || billingData == null;

  // Both "no gate" and "not loaded yet" count as subscribed.
  if (billingGateInactive || billingNotReady) {
    return true;
  }

  return hasPaidSubscription(billingData);
}
|
||||
@@ -3,7 +3,7 @@
|
||||
import { useMemo } from "react";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { CombinedSettings } from "@/app/admin/settings/interfaces";
|
||||
import { CombinedSettings } from "@/interfaces/settings";
|
||||
import { ChatSession } from "@/app/app/interfaces";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import { DEFAULT_ASSISTANT_ID } from "@/lib/constants";
|
||||
|
||||
@@ -26,6 +26,7 @@ export interface Settings {
|
||||
query_history_type: QueryHistoryType;
|
||||
|
||||
deep_research_enabled?: boolean;
|
||||
search_ui_enabled?: boolean;
|
||||
|
||||
// Image processing settings
|
||||
image_extraction_and_analysis_enabled?: boolean;
|
||||
@@ -117,4 +118,16 @@ export interface CombinedSettings {
|
||||
isMobile?: boolean;
|
||||
webVersion: string | null;
|
||||
webDomain: string | null;
|
||||
|
||||
/**
|
||||
* NOTE (@raunakab):
|
||||
* Whether search mode is actually available to users.
|
||||
*
|
||||
* Prefer this over reading `settings.search_ui_enabled` directly.
|
||||
* `search_ui_enabled` only reflects the admin's *preference* — it does not
|
||||
* account for prerequisites like connectors being configured. This derived
|
||||
* flag combines the admin setting with runtime checks (e.g. connectors
|
||||
* exist) so consumers get a single, accurate boolean.
|
||||
*/
|
||||
isSearchModeAvailable: boolean;
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user