Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 16:25:45 +00:00)
Compare commits: ci_artifac...main — 11 commits

- 5150ffc3e0
- 858c1dbe4a
- a8e7353227
- 343cda35cb
- 1cbe47d85e
- 221658132a
- fe8fb9eb75
- f7925584b8
- 00b0e15ed7
- c2968e3bfe
- 978f0a9d35
@@ -84,18 +84,20 @@ test.describe("Feature Name", () => {
});
```

**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should use `loginAsRandomUser` to get a fresh user per test. This prevents side effects from leaking into other parallel tests' screenshots and assertions:
**User isolation** — tests that modify visible app state (creating assistants, sending chat messages, pinning items) should run as the pre-provisioned `"user"` account and clean up resources in `afterAll`. This keeps usernames deterministic for screenshots and avoids cross-contamination with admin data from other parallel tests:

```typescript
import { loginAsRandomUser } from "@tests/e2e/utils/auth";
import { loginAs } from "@tests/e2e/utils/auth";

test.beforeEach(async ({ page }) => {
  await page.context().clearCookies();
  await loginAsRandomUser(page);
  await loginAs(page, "user");
});
```

Switch to admin only when privileged setup is needed (creating providers, configuring tools), then back to the isolated user for the actual test. See `chat/default_assistant.spec.ts` for a full example.
If the test requires admin privileges *and* modifies visible state, use `"admin2"` instead — it's a pre-provisioned admin account that keeps the primary `"admin"` clean for other parallel tests. Switch to `"admin"` only for privileged setup (creating providers, configuring tools), then back to `"user"` or `"admin2"` for the actual test. See `chat/default_assistant.spec.ts` for a full example.
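For instance, a minimal sketch of that switch-to-admin-then-back pattern — the `"admin"` and `"user"` account names come from the paragraph above, and the setup step in the middle is only a placeholder:

```typescript
import { loginAs } from "@tests/e2e/utils/auth";

test.beforeEach(async ({ page }) => {
  // Privileged setup as the shared admin account (e.g. creating providers).
  await page.context().clearCookies();
  await loginAs(page, "admin");
  // ... placeholder: perform admin-only setup here ...

  // Switch back to the isolated "user" account for the actual test.
  await page.context().clearCookies();
  await loginAs(page, "user");
});
```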

`loginAsRandomUser` exists for the rare case where the test requires a brand-new user (e.g. onboarding flows). Avoid it elsewhere — it produces non-deterministic usernames that complicate screenshots.

**API resource setup** — only when tests need to create backend resources (image gen configs, web search providers, MCP servers). Use `beforeAll`/`afterAll` with `OnyxApiClient` to create and clean up. See `chat/default_assistant.spec.ts` or `mcp/mcp_oauth_flow.spec.ts` for examples. This is uncommon (~4 of 37 test files).
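A rough sketch of that `beforeAll`/`afterAll` shape — the import path and the `createResource`/`deleteResource` method names below are placeholders, not the client's real API; see the referenced spec files for the actual calls:

```typescript
// Import path and method names are assumptions for illustration only.
import { OnyxApiClient } from "@tests/e2e/utils/api";

let apiClient: OnyxApiClient;
let createdResourceId: string;

test.beforeAll(async () => {
  apiClient = new OnyxApiClient();
  // Hypothetical call — substitute the real OnyxApiClient method for the
  // resource under test (image gen config, web search provider, MCP server, ...).
  createdResourceId = await apiClient.createResource({ name: "E2E Test Resource" });
});

test.afterAll(async () => {
  // Clean up by ID so parallel workers and later runs are unaffected.
  await apiClient.deleteResource(createdResourceId);
});
```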

@@ -126,6 +128,30 @@ Backend API client for test setup/teardown. Key methods:
- `expectElementScreenshot(locator, { name, mask?, hide? })`
- Controlled by `VISUAL_REGRESSION=true` env var
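As a quick illustration of the helper above — the import path and the `data-testid` values are assumptions, and the assertion only runs when `VISUAL_REGRESSION=true`:

```typescript
// Import path is an assumption — adjust to wherever the screenshot helpers live.
import { expectElementScreenshot } from "@tests/e2e/utils/screenshot";

test("sidebar renders consistently", async ({ page }) => {
  await page.goto("/app");
  const sidebar = page.getByTestId("app-sidebar"); // hypothetical test id
  await expectElementScreenshot(sidebar, {
    name: "app-sidebar",
    // Mask regions with dynamic content so the diff stays stable.
    mask: [page.getByTestId("user-avatar")], // hypothetical test id
  });
});
```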

### `theme` (`@tests/e2e/utils/theme`)

- `THEMES` — `["light", "dark"] as const` array for iterating over both themes
- `setThemeBeforeNavigation(page, theme)` — sets `next-themes` theme via `localStorage` before navigation

When tests need light/dark screenshots, loop over `THEMES` at the `test.describe` level and call `setThemeBeforeNavigation` in `beforeEach` **before** any `page.goto()`. Include the theme in screenshot names. See `admin/admin_pages.spec.ts` or `chat/chat_message_rendering.spec.ts` for examples:

```typescript
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";

for (const theme of THEMES) {
  test.describe(`Feature (${theme} mode)`, () => {
    test.beforeEach(async ({ page }) => {
      await setThemeBeforeNavigation(page, theme);
    });

    test("renders correctly", async ({ page }) => {
      await page.goto("/app");
      await expectScreenshot(page, { name: `feature-${theme}` });
    });
  });
}
```

### `tools` (`@tests/e2e/utils/tools`)

- `TOOL_IDS` — centralized `data-testid` selectors for tool options
@@ -206,10 +232,10 @@ await page.waitForResponse(resp => resp.url().includes("/api/chat") && resp.stat

1. **Descriptive test names** — clearly state expected behavior: `"should display greeting message when opening new chat"`
2. **API-first setup** — use `OnyxApiClient` for backend state; reserve UI interactions for the behavior under test
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should use `loginAsRandomUser` for a fresh user per test, avoiding cross-test contamination. Always cleanup API-created resources in `afterAll`
3. **User isolation** — tests that modify visible app state (sidebar, chat history) should run as `"user"` (not admin) and clean up resources in `afterAll`. This keeps usernames deterministic for screenshots. Reserve `loginAsRandomUser` for flows that require a brand-new user (e.g. onboarding)
4. **DRY helpers** — extract reusable logic into `utils/` with JSDoc comments
5. **No hardcoded waits** — use `waitFor`, `waitForLoadState`, or web-first assertions
6. **Parallel-safe** — no shared mutable state between tests; use unique names with timestamps (`\`test-${Date.now()}\``)
6. **Parallel-safe** — no shared mutable state between tests. Prefer static, human-readable names (e.g. `"E2E-CMD Chat 1"`) and clean up resources by ID in `afterAll`. This keeps screenshots deterministic and avoids needing to mask/hide dynamic text. Only fall back to timestamps (`\`test-${Date.now()}\``) when resources cannot be reliably cleaned up or when name collisions across parallel workers would cause functional failures
7. **Error context** — catch and re-throw with useful debug info (page text, URL, etc.)
8. **Tag slow tests** — mark serial/slow tests with `@exclusive` in the test title
9. **Visual regression** — use `expectScreenshot()` for UI consistency checks
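An illustrative sketch of point 6 — the chat name mirrors the `"E2E-CMD Chat 1"` style quoted above, while the cleanup call is a placeholder rather than a confirmed `OnyxApiClient` method:

```typescript
const CHAT_NAME = "E2E-CMD Chat 1"; // static, human-readable, deterministic in screenshots

let chatId: string;

test("pins a chat to the sidebar", async ({ page }) => {
  // ... create and pin a chat named CHAT_NAME, then assert on the sidebar ...
});

test.afterAll(async () => {
  // Placeholder cleanup — delete by ID using the real OnyxApiClient method.
  await apiClient.deleteResource(chatId);
});
```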

5  .github/workflows/pr-playwright-tests.yml  vendored
@@ -593,7 +593,10 @@ jobs:
  # Post a single combined visual regression comment after all matrix jobs finish
  visual-regression-comment:
    needs: [playwright-tests]
    if: always() && github.event_name == 'pull_request'
    if: >-
      always() &&
      github.event_name == 'pull_request' &&
      needs.playwright-tests.result != 'cancelled'
    runs-on: ubuntu-slim
    timeout-minutes: 5
    permissions:
@@ -31,6 +31,7 @@ from ee.onyx.server.query_and_chat.query_backend import (
from ee.onyx.server.query_and_chat.search_backend import router as search_router
from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.scim.api import scim_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
@@ -162,6 +163,10 @@ def get_application() -> FastAPI:
    # Tenant management
    include_router_with_global_prefix_prepended(application, tenants_router)

    # SCIM 2.0 — service discovery endpoints (unauthenticated).
    # Not behind APP_API_PREFIX because IdPs expect /scim/v2/... directly.
    application.include_router(scim_router)

    # Ensure all routes have auth enabled or are explicitly marked as public
    check_ee_router_auth(application)

@@ -5,6 +5,10 @@ from onyx.server.auth_check import PUBLIC_ENDPOINT_SPECS


EE_PUBLIC_ENDPOINT_SPECS = PUBLIC_ENDPOINT_SPECS + [
    # SCIM 2.0 service discovery — IdPs probe these before auth setup
    ("/scim/v2/ServiceProviderConfig", {"GET"}),
    ("/scim/v2/ResourceTypes", {"GET"}),
    ("/scim/v2/Schemas", {"GET"}),
    # needs to be accessible prior to user login
    ("/enterprise-settings", {"GET"}),
    ("/enterprise-settings/logo", {"GET"}),

53  backend/ee/onyx/server/scim/api.py  Normal file
@@ -0,0 +1,53 @@
"""SCIM 2.0 API endpoints (RFC 7644).

This module provides the FastAPI router for SCIM service discovery,
User CRUD, and Group CRUD. Identity providers (Okta, Azure AD) call
these endpoints to provision and manage users and groups.

Service discovery endpoints are unauthenticated — IdPs may probe them
before bearer token configuration is complete. All other endpoints
require a valid SCIM bearer token.
"""

from __future__ import annotations

from fastapi import APIRouter

from ee.onyx.server.scim.models import ScimResourceType
from ee.onyx.server.scim.models import ScimSchemaDefinition
from ee.onyx.server.scim.models import ScimServiceProviderConfig
from ee.onyx.server.scim.schema_definitions import GROUP_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import GROUP_SCHEMA_DEF
from ee.onyx.server.scim.schema_definitions import SERVICE_PROVIDER_CONFIG
from ee.onyx.server.scim.schema_definitions import USER_RESOURCE_TYPE
from ee.onyx.server.scim.schema_definitions import USER_SCHEMA_DEF


# NOTE: All URL paths in this router (/ServiceProviderConfig, /ResourceTypes,
# /Schemas, /Users, /Groups) are mandated by the SCIM spec (RFC 7643/7644).
# IdPs like Okta and Azure AD hardcode these exact paths, so they cannot be
# changed to kebab-case.
scim_router = APIRouter(prefix="/scim/v2", tags=["SCIM"])


# ---------------------------------------------------------------------------
# Service Discovery Endpoints (unauthenticated)
# ---------------------------------------------------------------------------


@scim_router.get("/ServiceProviderConfig")
def get_service_provider_config() -> ScimServiceProviderConfig:
    """Advertise supported SCIM features (RFC 7643 §5)."""
    return SERVICE_PROVIDER_CONFIG


@scim_router.get("/ResourceTypes")
def get_resource_types() -> list[ScimResourceType]:
    """List available SCIM resource types (RFC 7643 §6)."""
    return [USER_RESOURCE_TYPE, GROUP_RESOURCE_TYPE]


@scim_router.get("/Schemas")
def get_schemas() -> list[ScimSchemaDefinition]:
    """Return SCIM schema definitions (RFC 7643 §7)."""
    return [USER_SCHEMA_DEF, GROUP_SCHEMA_DEF]

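To illustrate the unauthenticated discovery flow the docstring above describes, a minimal probe as an IdP might issue it before bearer-token setup — the base URL is a placeholder, while the paths are the ones registered on `scim_router`:

```typescript
// Base URL is a placeholder — substitute the actual deployment host.
const BASE_URL = "https://onyx.example.com";

async function probeScimDiscovery(): Promise<void> {
  // These three paths are listed in EE_PUBLIC_ENDPOINT_SPECS, so no bearer token is sent.
  const paths = [
    "/scim/v2/ServiceProviderConfig",
    "/scim/v2/ResourceTypes",
    "/scim/v2/Schemas",
  ];
  for (const path of paths) {
    const resp = await fetch(`${BASE_URL}${path}`);
    console.log(path, resp.status, await resp.json());
  }
}

probeScimDiscovery();
```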
@@ -30,6 +30,7 @@ SCIM_SERVICE_PROVIDER_CONFIG_SCHEMA = (
|
||||
"urn:ietf:params:scim:schemas:core:2.0:ServiceProviderConfig"
|
||||
)
|
||||
SCIM_RESOURCE_TYPE_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:ResourceType"
|
||||
SCIM_SCHEMA_SCHEMA = "urn:ietf:params:scim:schemas:core:2.0:Schema"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -195,10 +196,39 @@ class ScimServiceProviderConfig(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class ScimSchemaAttribute(BaseModel):
|
||||
"""Attribute definition within a SCIM Schema (RFC 7643 §7)."""
|
||||
|
||||
name: str
|
||||
type: str
|
||||
multiValued: bool = False
|
||||
required: bool = False
|
||||
description: str = ""
|
||||
caseExact: bool = False
|
||||
mutability: str = "readWrite"
|
||||
returned: str = "default"
|
||||
uniqueness: str = "none"
|
||||
subAttributes: list["ScimSchemaAttribute"] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaDefinition(BaseModel):
|
||||
"""SCIM Schema definition (RFC 7643 §7).
|
||||
|
||||
Served at GET /scim/v2/Schemas. Describes the attributes available
|
||||
on each resource type so IdPs know which fields they can provision.
|
||||
"""
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_SCHEMA_SCHEMA])
|
||||
id: str
|
||||
name: str
|
||||
description: str
|
||||
attributes: list[ScimSchemaAttribute] = Field(default_factory=list)
|
||||
|
||||
|
||||
class ScimSchemaExtension(BaseModel):
|
||||
"""Schema extension reference within ResourceType (RFC 7643 §6)."""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schema_: str = Field(alias="schema")
|
||||
required: bool
|
||||
@@ -211,7 +241,7 @@ class ScimResourceType(BaseModel):
|
||||
types are available (Users, Groups) and their respective endpoints.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(populate_by_name=True)
|
||||
model_config = ConfigDict(populate_by_name=True, serialize_by_alias=True)
|
||||
|
||||
schemas: list[str] = Field(default_factory=lambda: [SCIM_RESOURCE_TYPE_SCHEMA])
|
||||
id: str
|
||||
|
||||
144  backend/ee/onyx/server/scim/schema_definitions.py  Normal file
@@ -0,0 +1,144 @@
|
||||
"""Static SCIM service discovery responses (RFC 7643 §5, §6, §7).
|
||||
|
||||
Pre-built at import time — these never change at runtime. Separated from
|
||||
api.py to keep the endpoint module focused on request handling.
|
||||
"""
|
||||
|
||||
from ee.onyx.server.scim.models import SCIM_GROUP_SCHEMA
|
||||
from ee.onyx.server.scim.models import SCIM_USER_SCHEMA
|
||||
from ee.onyx.server.scim.models import ScimResourceType
|
||||
from ee.onyx.server.scim.models import ScimSchemaAttribute
|
||||
from ee.onyx.server.scim.models import ScimSchemaDefinition
|
||||
from ee.onyx.server.scim.models import ScimServiceProviderConfig
|
||||
|
||||
SERVICE_PROVIDER_CONFIG = ScimServiceProviderConfig()
|
||||
|
||||
USER_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "User",
|
||||
"name": "User",
|
||||
"endpoint": "/scim/v2/Users",
|
||||
"description": "SCIM User resource",
|
||||
"schema": SCIM_USER_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
GROUP_RESOURCE_TYPE = ScimResourceType.model_validate(
|
||||
{
|
||||
"id": "Group",
|
||||
"name": "Group",
|
||||
"endpoint": "/scim/v2/Groups",
|
||||
"description": "SCIM Group resource",
|
||||
"schema": SCIM_GROUP_SCHEMA,
|
||||
}
|
||||
)
|
||||
|
||||
USER_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_USER_SCHEMA,
|
||||
name="User",
|
||||
description="SCIM core User schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="userName",
|
||||
type="string",
|
||||
required=True,
|
||||
uniqueness="server",
|
||||
description="Unique identifier for the user, typically an email address.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="name",
|
||||
type="complex",
|
||||
description="The components of the user's name.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="givenName",
|
||||
type="string",
|
||||
description="The user's first name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="familyName",
|
||||
type="string",
|
||||
description="The user's last name.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="formatted",
|
||||
type="string",
|
||||
description="The full name, including all middle names and titles.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="emails",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Email addresses for the user.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="Email address value.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="type",
|
||||
type="string",
|
||||
description="Label for this email (e.g. 'work').",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="primary",
|
||||
type="boolean",
|
||||
description="Whether this is the primary email.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="active",
|
||||
type="boolean",
|
||||
description="Whether the user account is active.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
GROUP_SCHEMA_DEF = ScimSchemaDefinition(
|
||||
id=SCIM_GROUP_SCHEMA,
|
||||
name="Group",
|
||||
description="SCIM core Group schema",
|
||||
attributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="displayName",
|
||||
type="string",
|
||||
required=True,
|
||||
description="Human-readable name for the group.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="members",
|
||||
type="complex",
|
||||
multiValued=True,
|
||||
description="Members of the group.",
|
||||
subAttributes=[
|
||||
ScimSchemaAttribute(
|
||||
name="value",
|
||||
type="string",
|
||||
description="User ID of the group member.",
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="display",
|
||||
type="string",
|
||||
mutability="readOnly",
|
||||
description="Display name of the group member.",
|
||||
),
|
||||
],
|
||||
),
|
||||
ScimSchemaAttribute(
|
||||
name="externalId",
|
||||
type="string",
|
||||
description="Identifier from the provisioning client (IdP).",
|
||||
caseExact=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
@@ -68,6 +68,18 @@ from shared_configs.contextvars import get_current_tenant_id
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def _looks_like_xml_tool_call_payload(text: str | None) -> bool:
|
||||
"""Detect XML-style marshaled tool calls emitted as plain text."""
|
||||
if not text:
|
||||
return False
|
||||
lowered = text.lower()
|
||||
return (
|
||||
"<function_calls" in lowered
|
||||
and "<invoke" in lowered
|
||||
and "<parameter" in lowered
|
||||
)
|
||||
|
||||
|
||||
def _should_keep_bedrock_tool_definitions(
|
||||
llm: object, simple_chat_history: list[ChatMessageSimple]
|
||||
) -> bool:
|
||||
@@ -122,38 +134,56 @@ def _try_fallback_tool_extraction(
|
||||
reasoning_but_no_answer_or_tools = (
|
||||
llm_step_result.reasoning and not llm_step_result.answer and no_tool_calls
|
||||
)
|
||||
xml_tool_call_text_detected = no_tool_calls and (
|
||||
_looks_like_xml_tool_call_payload(llm_step_result.answer)
|
||||
or _looks_like_xml_tool_call_payload(llm_step_result.raw_answer)
|
||||
or _looks_like_xml_tool_call_payload(llm_step_result.reasoning)
|
||||
)
|
||||
should_try_fallback = (
|
||||
tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls
|
||||
) or reasoning_but_no_answer_or_tools
|
||||
(tool_choice == ToolChoiceOptions.REQUIRED and no_tool_calls)
|
||||
or reasoning_but_no_answer_or_tools
|
||||
or xml_tool_call_text_detected
|
||||
)
|
||||
|
||||
if not should_try_fallback:
|
||||
return llm_step_result, False
|
||||
|
||||
# Try to extract from answer first, then fall back to reasoning
|
||||
extracted_tool_calls: list[ToolCallKickoff] = []
|
||||
|
||||
if llm_step_result.answer:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if (
|
||||
not extracted_tool_calls
|
||||
and llm_step_result.raw_answer
|
||||
and llm_step_result.raw_answer != llm_step_result.answer
|
||||
):
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.raw_answer,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
if not extracted_tool_calls and llm_step_result.reasoning:
|
||||
extracted_tool_calls = extract_tool_calls_from_response_text(
|
||||
response_text=llm_step_result.reasoning,
|
||||
tool_definitions=tool_defs,
|
||||
placement=Placement(turn_index=turn_index),
|
||||
)
|
||||
|
||||
if extracted_tool_calls:
|
||||
logger.info(
|
||||
f"Extracted {len(extracted_tool_calls)} tool call(s) from response text "
|
||||
f"as fallback (tool_choice was REQUIRED but no tool calls returned)"
|
||||
"as fallback"
|
||||
)
|
||||
return (
|
||||
LlmStepResult(
|
||||
reasoning=llm_step_result.reasoning,
|
||||
answer=llm_step_result.answer,
|
||||
tool_calls=extracted_tool_calls,
|
||||
raw_answer=llm_step_result.raw_answer,
|
||||
),
|
||||
True,
|
||||
)
|
||||
|
||||
@@ -1,10 +1,12 @@
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Mapping
|
||||
from collections.abc import Sequence
|
||||
from html import unescape
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
@@ -56,6 +58,112 @@ from onyx.utils.text_processing import find_all_json_objects
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_XML_INVOKE_BLOCK_RE = re.compile(
|
||||
r"<invoke\b(?P<attrs>[^>]*)>(?P<body>.*?)</invoke>",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_XML_PARAMETER_RE = re.compile(
|
||||
r"<parameter\b(?P<attrs>[^>]*)>(?P<value>.*?)</parameter>",
|
||||
re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
_FUNCTION_CALLS_OPEN_MARKER = "<function_calls"
|
||||
_FUNCTION_CALLS_CLOSE_MARKER = "</function_calls>"
|
||||
|
||||
|
||||
class _XmlToolCallContentFilter:
|
||||
"""Streaming filter that strips XML-style tool call payload blocks from text."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
|
||||
def process(self, content: str) -> str:
|
||||
if not content:
|
||||
return ""
|
||||
|
||||
self._pending += content
|
||||
output_parts: list[str] = []
|
||||
|
||||
while self._pending:
|
||||
pending_lower = self._pending.lower()
|
||||
|
||||
if self._inside_function_calls_block:
|
||||
end_idx = pending_lower.find(_FUNCTION_CALLS_CLOSE_MARKER)
|
||||
if end_idx == -1:
|
||||
# Keep buffering until we see the close marker.
|
||||
return "".join(output_parts)
|
||||
|
||||
# Drop the whole function_calls block.
|
||||
self._pending = self._pending[
|
||||
end_idx + len(_FUNCTION_CALLS_CLOSE_MARKER) :
|
||||
]
|
||||
self._inside_function_calls_block = False
|
||||
continue
|
||||
|
||||
start_idx = _find_function_calls_open_marker(pending_lower)
|
||||
if start_idx == -1:
|
||||
# Keep only a possible prefix of "<function_calls" in the buffer so
|
||||
# marker splits across chunks are handled correctly.
|
||||
tail_len = _matching_open_marker_prefix_len(self._pending)
|
||||
emit_upto = len(self._pending) - tail_len
|
||||
if emit_upto > 0:
|
||||
output_parts.append(self._pending[:emit_upto])
|
||||
self._pending = self._pending[emit_upto:]
|
||||
return "".join(output_parts)
|
||||
|
||||
if start_idx > 0:
|
||||
output_parts.append(self._pending[:start_idx])
|
||||
|
||||
# Enter block-stripping mode and keep scanning for close marker.
|
||||
self._pending = self._pending[start_idx:]
|
||||
self._inside_function_calls_block = True
|
||||
|
||||
return "".join(output_parts)
|
||||
|
||||
def flush(self) -> str:
|
||||
if self._inside_function_calls_block:
|
||||
# Drop any incomplete block at stream end.
|
||||
self._pending = ""
|
||||
self._inside_function_calls_block = False
|
||||
return ""
|
||||
|
||||
remaining = self._pending
|
||||
self._pending = ""
|
||||
return remaining
|
||||
|
||||
|
||||
def _matching_open_marker_prefix_len(text: str) -> int:
|
||||
"""Return longest suffix of text that matches prefix of "<function_calls"."""
|
||||
max_len = min(len(text), len(_FUNCTION_CALLS_OPEN_MARKER) - 1)
|
||||
text_lower = text.lower()
|
||||
marker_lower = _FUNCTION_CALLS_OPEN_MARKER
|
||||
|
||||
for candidate_len in range(max_len, 0, -1):
|
||||
if text_lower.endswith(marker_lower[:candidate_len]):
|
||||
return candidate_len
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def _is_valid_function_calls_open_follower(char: str | None) -> bool:
|
||||
return char is None or char in {">", " ", "\t", "\n", "\r"}
|
||||
|
||||
|
||||
def _find_function_calls_open_marker(text_lower: str) -> int:
|
||||
"""Find '<function_calls' with a valid tag boundary follower."""
|
||||
search_from = 0
|
||||
while True:
|
||||
idx = text_lower.find(_FUNCTION_CALLS_OPEN_MARKER, search_from)
|
||||
if idx == -1:
|
||||
return -1
|
||||
|
||||
follower_pos = idx + len(_FUNCTION_CALLS_OPEN_MARKER)
|
||||
follower = text_lower[follower_pos] if follower_pos < len(text_lower) else None
|
||||
if _is_valid_function_calls_open_follower(follower):
|
||||
return idx
|
||||
|
||||
search_from = idx + 1
|
||||
|
||||
|
||||
def _sanitize_llm_output(value: str) -> str:
|
||||
"""Remove characters that PostgreSQL's text/JSONB types cannot store.
|
||||
@@ -307,8 +415,9 @@ def extract_tool_calls_from_response_text(
|
||||
"""Extract tool calls from LLM response text by matching JSON against tool definitions.
|
||||
|
||||
This is a fallback mechanism for when the LLM was expected to return tool calls
|
||||
but didn't use the proper tool call format. It searches for JSON objects in the
|
||||
response text that match the structure of available tools.
|
||||
but didn't use the proper tool call format. It searches for tool calls embedded
|
||||
in response text (JSON first, then XML-like invoke blocks) that match available
|
||||
tool definitions.
|
||||
|
||||
Args:
|
||||
response_text: The LLM's text response to search for tool calls
|
||||
@@ -333,10 +442,9 @@ def extract_tool_calls_from_response_text(
|
||||
if not tool_name_to_def:
|
||||
return []
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
# Find all JSON objects in the response text
|
||||
json_objects = find_all_json_objects(response_text)
|
||||
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
prev_json_obj: dict[str, Any] | None = None
|
||||
prev_tool_call: tuple[str, dict[str, Any]] | None = None
|
||||
|
||||
@@ -364,6 +472,14 @@ def extract_tool_calls_from_response_text(
|
||||
prev_json_obj = json_obj
|
||||
prev_tool_call = matched_tool_call
|
||||
|
||||
# Some providers/models emit XML-style function calls instead of JSON objects.
|
||||
# Keep this as a fallback behind JSON extraction to preserve current behavior.
|
||||
if not matched_tool_calls:
|
||||
matched_tool_calls = _extract_xml_tool_calls_from_response_text(
|
||||
response_text=response_text,
|
||||
tool_name_to_def=tool_name_to_def,
|
||||
)
|
||||
|
||||
tool_calls: list[ToolCallKickoff] = []
|
||||
for tab_index, (tool_name, tool_args) in enumerate(matched_tool_calls):
|
||||
tool_calls.append(
|
||||
@@ -386,6 +502,71 @@ def extract_tool_calls_from_response_text(
|
||||
return tool_calls
|
||||
|
||||
|
||||
def _extract_xml_tool_calls_from_response_text(
|
||||
response_text: str,
|
||||
tool_name_to_def: dict[str, dict],
|
||||
) -> list[tuple[str, dict[str, Any]]]:
|
||||
"""Extract XML-style tool calls from response text.
|
||||
|
||||
Supports formats such as:
|
||||
<function_calls>
|
||||
<invoke name="internal_search">
|
||||
<parameter name="queries" string="false">["foo"]</parameter>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
"""
|
||||
matched_tool_calls: list[tuple[str, dict[str, Any]]] = []
|
||||
|
||||
for invoke_match in _XML_INVOKE_BLOCK_RE.finditer(response_text):
|
||||
invoke_attrs = invoke_match.group("attrs")
|
||||
tool_name = _extract_xml_attribute(invoke_attrs, "name")
|
||||
if not tool_name or tool_name not in tool_name_to_def:
|
||||
continue
|
||||
|
||||
tool_args: dict[str, Any] = {}
|
||||
invoke_body = invoke_match.group("body")
|
||||
for parameter_match in _XML_PARAMETER_RE.finditer(invoke_body):
|
||||
parameter_attrs = parameter_match.group("attrs")
|
||||
parameter_name = _extract_xml_attribute(parameter_attrs, "name")
|
||||
if not parameter_name:
|
||||
continue
|
||||
|
||||
string_attr = _extract_xml_attribute(parameter_attrs, "string")
|
||||
tool_args[parameter_name] = _parse_xml_parameter_value(
|
||||
raw_value=parameter_match.group("value"),
|
||||
string_attr=string_attr,
|
||||
)
|
||||
|
||||
matched_tool_calls.append((tool_name, tool_args))
|
||||
|
||||
return matched_tool_calls
|
||||
|
||||
|
||||
def _extract_xml_attribute(attrs: str, attr_name: str) -> str | None:
|
||||
"""Extract a single XML-style attribute value from a tag attribute string."""
|
||||
attr_match = re.search(
|
||||
rf"""\b{re.escape(attr_name)}\s*=\s*(['"])(.*?)\1""",
|
||||
attrs,
|
||||
flags=re.IGNORECASE | re.DOTALL,
|
||||
)
|
||||
if not attr_match:
|
||||
return None
|
||||
return _sanitize_llm_output(unescape(attr_match.group(2).strip()))
|
||||
|
||||
|
||||
def _parse_xml_parameter_value(raw_value: str, string_attr: str | None) -> Any:
|
||||
"""Parse a parameter value from XML-style tool call payloads."""
|
||||
value = _sanitize_llm_output(unescape(raw_value).strip())
|
||||
|
||||
if string_attr and string_attr.lower() == "true":
|
||||
return value
|
||||
|
||||
try:
|
||||
return json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
return value
|
||||
|
||||
|
||||
def _try_match_json_to_tool(
|
||||
json_obj: dict[str, Any],
|
||||
tool_name_to_def: dict[str, dict],
|
||||
@@ -749,6 +930,8 @@ def run_llm_step_pkt_generator(
|
||||
answer_start = False
|
||||
accumulated_reasoning = ""
|
||||
accumulated_answer = ""
|
||||
accumulated_raw_answer = ""
|
||||
xml_tool_call_content_filter = _XmlToolCallContentFilter()
|
||||
|
||||
processor_state: Any = None
|
||||
|
||||
@@ -764,6 +947,120 @@ def run_llm_step_pkt_generator(
|
||||
)
|
||||
stream_start_time = time.monotonic()
|
||||
first_action_recorded = False
|
||||
|
||||
def _emit_content_chunk(content_chunk: str) -> Generator[Packet, None, None]:
|
||||
nonlocal accumulated_answer
|
||||
nonlocal accumulated_reasoning
|
||||
nonlocal answer_start
|
||||
nonlocal reasoning_start
|
||||
nonlocal has_reasoned
|
||||
nonlocal turn_index
|
||||
nonlocal sub_turn_index
|
||||
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
accumulated_reasoning += content_chunk
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=content_chunk),
|
||||
)
|
||||
reasoning_start = True
|
||||
return
|
||||
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(content_chunk):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(result.citation_number)
|
||||
else:
|
||||
accumulated_answer += content_chunk
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=content_chunk),
|
||||
)
|
||||
|
||||
for packet in llm.stream(
|
||||
prompt=llm_msg_history,
|
||||
tools=tool_definitions,
|
||||
@@ -840,114 +1137,12 @@ def run_llm_step_pkt_generator(
|
||||
reasoning_start = True
|
||||
|
||||
if delta.content:
|
||||
# When tool_choice is REQUIRED, content before tool calls is reasoning/thinking
|
||||
# about which tool to call, not an actual answer to the user.
|
||||
# Treat this content as reasoning instead of answer.
|
||||
if is_deep_research and tool_choice == ToolChoiceOptions.REQUIRED:
|
||||
# Treat content as reasoning when we know tool calls are coming
|
||||
accumulated_reasoning += delta.content
|
||||
if state_container:
|
||||
state_container.set_reasoning_tokens(accumulated_reasoning)
|
||||
if not reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningStart(),
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDelta(reasoning=delta.content),
|
||||
)
|
||||
reasoning_start = True
|
||||
else:
|
||||
# Normal flow for AUTO or NONE tool choice
|
||||
if reasoning_start:
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=ReasoningDone(),
|
||||
)
|
||||
has_reasoned = 1
|
||||
turn_index, sub_turn_index = _increment_turns(
|
||||
turn_index, sub_turn_index
|
||||
)
|
||||
reasoning_start = False
|
||||
|
||||
if not answer_start:
|
||||
# Store pre-answer processing time in state container for save_chat
|
||||
if state_container and pre_answer_processing_time is not None:
|
||||
state_container.set_pre_answer_processing_time(
|
||||
pre_answer_processing_time
|
||||
)
|
||||
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseStart(
|
||||
final_documents=final_documents,
|
||||
pre_answer_processing_seconds=pre_answer_processing_time,
|
||||
),
|
||||
)
|
||||
answer_start = True
|
||||
|
||||
if citation_processor:
|
||||
for result in citation_processor.process_token(delta.content):
|
||||
if isinstance(result, str):
|
||||
accumulated_answer += result
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(
|
||||
accumulated_answer
|
||||
)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=result),
|
||||
)
|
||||
elif isinstance(result, CitationInfo):
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=result,
|
||||
)
|
||||
# Track emitted citation for saving
|
||||
if state_container:
|
||||
state_container.add_emitted_citation(
|
||||
result.citation_number
|
||||
)
|
||||
else:
|
||||
# When citation_processor is None, use delta.content directly without modification
|
||||
accumulated_answer += delta.content
|
||||
# Save answer incrementally to state container
|
||||
if state_container:
|
||||
state_container.set_answer_tokens(accumulated_answer)
|
||||
yield Packet(
|
||||
placement=Placement(
|
||||
turn_index=turn_index,
|
||||
tab_index=tab_index,
|
||||
sub_turn_index=sub_turn_index,
|
||||
),
|
||||
obj=AgentResponseDelta(content=delta.content),
|
||||
)
|
||||
# Keep raw content for fallback extraction. Display content can be
|
||||
# filtered and, in deep-research REQUIRED mode, routed as reasoning.
|
||||
accumulated_raw_answer += delta.content
|
||||
filtered_content = xml_tool_call_content_filter.process(delta.content)
|
||||
if filtered_content:
|
||||
yield from _emit_content_chunk(filtered_content)
|
||||
|
||||
if delta.tool_calls:
|
||||
if reasoning_start:
|
||||
@@ -968,6 +1163,11 @@ def run_llm_step_pkt_generator(
|
||||
for tool_call_delta in delta.tool_calls:
|
||||
_update_tool_call_with_delta(id_to_tool_call_map, tool_call_delta)
|
||||
|
||||
# Flush any tail text buffered while checking for split "<function_calls" markers.
|
||||
filtered_content_tail = xml_tool_call_content_filter.flush()
|
||||
if filtered_content_tail:
|
||||
yield from _emit_content_chunk(filtered_content_tail)
|
||||
|
||||
# Flush custom token processor to get any final tool calls
|
||||
if custom_token_processor:
|
||||
flush_delta, processor_state = custom_token_processor(None, processor_state)
|
||||
@@ -1088,6 +1288,7 @@ def run_llm_step_pkt_generator(
|
||||
reasoning=accumulated_reasoning if accumulated_reasoning else None,
|
||||
answer=accumulated_answer if accumulated_answer else None,
|
||||
tool_calls=tool_calls if tool_calls else None,
|
||||
raw_answer=accumulated_raw_answer if accumulated_raw_answer else None,
|
||||
),
|
||||
bool(has_reasoned),
|
||||
)
|
||||
|
||||
@@ -185,3 +185,6 @@ class LlmStepResult(BaseModel):
|
||||
reasoning: str | None
|
||||
answer: str | None
|
||||
tool_calls: list[ToolCallKickoff] | None
|
||||
# Raw LLM text before any display-oriented filtering/sanitization.
|
||||
# Used for fallback tool-call extraction when providers emit calls as text.
|
||||
raw_answer: str | None = None
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.connectors.google_drive.file_retrieval import get_external_access_for_
|
||||
from onyx.connectors.google_drive.file_retrieval import get_files_in_shared_drive
|
||||
from onyx.connectors.google_drive.file_retrieval import get_folder_metadata
|
||||
from onyx.connectors.google_drive.file_retrieval import get_root_folder_id
|
||||
from onyx.connectors.google_drive.file_retrieval import get_shared_drive_name
|
||||
from onyx.connectors.google_drive.file_retrieval import has_link_only_permission
|
||||
from onyx.connectors.google_drive.models import DriveRetrievalStage
|
||||
from onyx.connectors.google_drive.models import GoogleDriveCheckpoint
|
||||
@@ -156,10 +157,7 @@ def _is_shared_drive_root(folder: GoogleDriveFileType) -> bool:
|
||||
return False
|
||||
|
||||
# For shared drive content, the root has id == driveId
|
||||
if drive_id and folder_id == drive_id:
|
||||
return True
|
||||
|
||||
return False
|
||||
return bool(drive_id and folder_id == drive_id)
|
||||
|
||||
|
||||
def _public_access() -> ExternalAccess:
|
||||
@@ -616,6 +614,16 @@ class GoogleDriveConnector(
|
||||
# empty parents due to permission limitations)
|
||||
# Check shared drive root first (simple ID comparison)
|
||||
if _is_shared_drive_root(folder):
|
||||
# files().get() returns 'Drive' for shared drive roots;
|
||||
# fetch the real name via drives().get().
|
||||
# Try both the retriever and admin since the admin may
|
||||
# not have access to private shared drives.
|
||||
drive_name = self._get_shared_drive_name(
|
||||
current_id, file.user_email
|
||||
)
|
||||
if drive_name:
|
||||
node.display_name = drive_name
|
||||
node.node_type = HierarchyNodeType.SHARED_DRIVE
|
||||
reached_terminal = True
|
||||
break
|
||||
|
||||
@@ -691,6 +699,15 @@ class GoogleDriveConnector(
|
||||
)
|
||||
return None
|
||||
|
||||
def _get_shared_drive_name(self, drive_id: str, retriever_email: str) -> str | None:
|
||||
"""Fetch the name of a shared drive, trying both the retriever and admin."""
|
||||
for email in {retriever_email, self.primary_admin_email}:
|
||||
svc = get_drive_service(self.creds, email)
|
||||
name = get_shared_drive_name(svc, drive_id)
|
||||
if name:
|
||||
return name
|
||||
return None
|
||||
|
||||
def get_all_drive_ids(self) -> set[str]:
|
||||
return self._get_all_drives_for_user(self.primary_admin_email)
|
||||
|
||||
|
||||
@@ -154,6 +154,26 @@ def _get_hierarchy_fields_for_file_type(field_type: DriveFileFieldType) -> str:
|
||||
return HIERARCHY_FIELDS
|
||||
|
||||
|
||||
def get_shared_drive_name(
|
||||
service: Resource,
|
||||
drive_id: str,
|
||||
) -> str | None:
|
||||
"""Fetch the actual name of a shared drive via the drives().get() API.
|
||||
|
||||
The files().get() API returns 'Drive' as the name for shared drive root
|
||||
folders. Only drives().get() returns the real user-assigned name.
|
||||
"""
|
||||
try:
|
||||
drive = service.drives().get(driveId=drive_id, fields="name").execute()
|
||||
return drive.get("name")
|
||||
except HttpError as e:
|
||||
if e.resp.status in (403, 404):
|
||||
logger.debug(f"Cannot access drive {drive_id}: {e}")
|
||||
else:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def get_external_access_for_folder(
|
||||
folder: GoogleDriveFileType,
|
||||
google_domain: str,
|
||||
|
||||
@@ -21,7 +21,6 @@ from fastapi.routing import APIRoute
|
||||
from httpx_oauth.clients.google import GoogleOAuth2
|
||||
from httpx_oauth.clients.openid import BASE_SCOPES
|
||||
from httpx_oauth.clients.openid import OpenID
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from sentry_sdk.integrations.fastapi import FastApiIntegration
|
||||
from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from starlette.types import Lifespan
|
||||
@@ -121,6 +120,7 @@ from onyx.server.middleware.rate_limiting import get_auth_rate_limiters
|
||||
from onyx.server.middleware.rate_limiting import setup_auth_limiter
|
||||
from onyx.server.onyx_api.ingestion import router as onyx_api_router
|
||||
from onyx.server.pat.api import router as pat_router
|
||||
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics
|
||||
from onyx.server.query_and_chat.chat_backend import router as chat_router
|
||||
from onyx.server.query_and_chat.query_backend import (
|
||||
admin_router as admin_query_router,
|
||||
@@ -563,8 +563,8 @@ def get_application(lifespan_override: Lifespan | None = None) -> FastAPI:
|
||||
# Ensure all routes have auth enabled or are explicitly marked as public
|
||||
check_router_auth(application)
|
||||
|
||||
# Initialize and instrument the app
|
||||
Instrumentator().instrument(application).expose(application)
|
||||
# Initialize and instrument the app with production Prometheus config
|
||||
setup_prometheus_metrics(application)
|
||||
|
||||
use_route_function_names_as_operation_ids(application)
|
||||
|
||||
|
||||
@@ -8,6 +8,7 @@ import httpx
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx import __version__
|
||||
from onyx.configs.app_configs import INSTANCE_TYPE
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.release_notes import create_release_notifications_for_versions
|
||||
from onyx.redis.redis_pool import get_shared_redis_client
|
||||
@@ -56,7 +57,7 @@ def is_version_gte(v1: str, v2: str) -> bool:
|
||||
|
||||
|
||||
def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry]:
|
||||
"""Parse MDX content into ReleaseNoteEntry objects for versions >= __version__."""
|
||||
"""Parse MDX content into ReleaseNoteEntry objects."""
|
||||
all_entries = []
|
||||
|
||||
update_pattern = (
|
||||
@@ -82,6 +83,12 @@ def parse_mdx_to_release_note_entries(mdx_content: str) -> list[ReleaseNoteEntry
|
||||
if not all_entries:
|
||||
raise ValueError("Could not parse any release note entries from MDX.")
|
||||
|
||||
if INSTANCE_TYPE == "cloud":
|
||||
# Cloud often runs ahead of docs release tags; always notify on latest release.
|
||||
return sorted(
|
||||
all_entries, key=lambda x: parse_version_tuple(x.version), reverse=True
|
||||
)[:1]
|
||||
|
||||
# Filter to valid versions >= __version__
|
||||
if __version__ and is_valid_version(__version__):
|
||||
entries = [
|
||||
|
||||
63  backend/onyx/server/prometheus_instrumentation.py  Normal file
@@ -0,0 +1,63 @@
|
||||
"""Prometheus instrumentation for the Onyx API server.
|
||||
|
||||
Provides a production-grade metrics configuration with:
|
||||
|
||||
- Exact HTTP status codes (no grouping into 2xx/3xx)
|
||||
- In-progress request gauge broken down by handler and method
|
||||
- Custom latency histogram buckets tuned for API workloads
|
||||
- Request/response size tracking
|
||||
- Slow request counter with configurable threshold
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
from prometheus_client import Counter
|
||||
from prometheus_fastapi_instrumentator import Instrumentator
|
||||
from prometheus_fastapi_instrumentator.metrics import Info
|
||||
from starlette.applications import Starlette
|
||||
|
||||
SLOW_REQUEST_THRESHOLD_SECONDS: float = float(
|
||||
os.environ.get("SLOW_REQUEST_THRESHOLD_SECONDS", "1.0")
|
||||
)
|
||||
|
||||
_EXCLUDED_HANDLERS = [
|
||||
"/health",
|
||||
"/metrics",
|
||||
"/openapi.json",
|
||||
]
|
||||
|
||||
_slow_requests = Counter(
|
||||
"onyx_api_slow_requests_total",
|
||||
"Total requests exceeding the slow request threshold",
|
||||
["method", "handler", "status"],
|
||||
)
|
||||
|
||||
|
||||
def _slow_request_callback(info: Info) -> None:
|
||||
"""Increment slow request counter when duration exceeds threshold."""
|
||||
if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
|
||||
_slow_requests.labels(
|
||||
method=info.method,
|
||||
handler=info.modified_handler,
|
||||
status=info.modified_status,
|
||||
).inc()
|
||||
|
||||
|
||||
def setup_prometheus_metrics(app: Starlette) -> None:
|
||||
"""Configure and attach Prometheus instrumentation to the FastAPI app.
|
||||
|
||||
Records exact status codes, tracks in-progress requests per handler,
|
||||
and counts slow requests exceeding a configurable threshold.
|
||||
"""
|
||||
instrumentator = Instrumentator(
|
||||
should_group_status_codes=False,
|
||||
should_ignore_untemplated=False,
|
||||
should_group_untemplated=True,
|
||||
should_instrument_requests_inprogress=True,
|
||||
inprogress_labels=True,
|
||||
excluded_handlers=_EXCLUDED_HANDLERS,
|
||||
)
|
||||
|
||||
instrumentator.add(_slow_request_callback)
|
||||
|
||||
instrumentator.instrument(app).expose(app)
|
||||
@@ -349,6 +349,7 @@ def get_chat_session(
|
||||
shared_status=chat_session.shared_status,
|
||||
current_temperature_override=chat_session.temperature_override,
|
||||
deleted=chat_session.deleted,
|
||||
owner_name=chat_session.user.personal_name if chat_session.user else None,
|
||||
# Packets are now directly serialized as Packet Pydantic models
|
||||
packets=replay_packet_lists,
|
||||
)
|
||||
|
||||
@@ -224,6 +224,7 @@ class ChatSessionDetailResponse(BaseModel):
|
||||
current_alternate_model: str | None
|
||||
current_temperature_override: float | None
|
||||
deleted: bool = False
|
||||
owner_name: str | None = None
|
||||
packets: list[list[Packet]]
|
||||
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@ import time
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import field
|
||||
from dataclasses import replace
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from onyx.connectors.google_drive.connector import GoogleDriveConnector
|
||||
@@ -134,25 +135,25 @@ EXPECTED_SHARED_DRIVE_1_HIERARCHY = ExpectedHierarchyNode(
|
||||
children=[
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=RESTRICTED_ACCESS_FOLDER_ID,
|
||||
display_name="restricted_access_folder",
|
||||
display_name="restricted_access",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=SHARED_DRIVE_1_ID,
|
||||
),
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_1_ID,
|
||||
display_name="folder_1",
|
||||
display_name="folder 1",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=SHARED_DRIVE_1_ID,
|
||||
children=[
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_1_1_ID,
|
||||
display_name="folder_1_1",
|
||||
display_name="folder 1-1",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=FOLDER_1_ID,
|
||||
),
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_1_2_ID,
|
||||
display_name="folder_1_2",
|
||||
display_name="folder 1-2",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=FOLDER_1_ID,
|
||||
),
|
||||
@@ -170,25 +171,25 @@ EXPECTED_SHARED_DRIVE_2_HIERARCHY = ExpectedHierarchyNode(
|
||||
children=[
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=SECTIONS_FOLDER_ID,
|
||||
display_name="sections_folder",
|
||||
display_name="sections",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=SHARED_DRIVE_2_ID,
|
||||
),
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_2_ID,
|
||||
display_name="folder_2",
|
||||
display_name="folder 2",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=SHARED_DRIVE_2_ID,
|
||||
children=[
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_2_1_ID,
|
||||
display_name="folder_2_1",
|
||||
display_name="folder 2-1",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=FOLDER_2_ID,
|
||||
),
|
||||
ExpectedHierarchyNode(
|
||||
raw_node_id=FOLDER_2_2_ID,
|
||||
display_name="folder_2_2",
|
||||
display_name="folder 2-2",
|
||||
node_type=HierarchyNodeType.FOLDER,
|
||||
raw_parent_id=FOLDER_2_ID,
|
||||
),
|
||||
@@ -208,27 +209,23 @@ def flatten_hierarchy(
|
||||
return result
|
||||
|
||||
|
||||
def _node(
|
||||
raw_node_id: str,
|
||||
display_name: str,
|
||||
node_type: HierarchyNodeType,
|
||||
raw_parent_id: str | None = None,
|
||||
) -> ExpectedHierarchyNode:
|
||||
return ExpectedHierarchyNode(
|
||||
raw_node_id=raw_node_id,
|
||||
display_name=display_name,
|
||||
node_type=node_type,
|
||||
raw_parent_id=raw_parent_id,
|
||||
)
|
||||
|
||||
|
||||
# Flattened maps for easy lookup
|
||||
EXPECTED_SHARED_DRIVE_1_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_1_HIERARCHY)
|
||||
EXPECTED_SHARED_DRIVE_2_NODES = flatten_hierarchy(EXPECTED_SHARED_DRIVE_2_HIERARCHY)
|
||||
ALL_EXPECTED_SHARED_DRIVE_NODES = {
|
||||
**EXPECTED_SHARED_DRIVE_1_NODES,
|
||||
**EXPECTED_SHARED_DRIVE_2_NODES,
|
||||
}
|
||||
|
||||
# Map of folder ID to its expected parent ID
|
||||
EXPECTED_PARENT_MAPPING: dict[str, str | None] = {
|
||||
SHARED_DRIVE_1_ID: None,
|
||||
RESTRICTED_ACCESS_FOLDER_ID: SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_ID: SHARED_DRIVE_1_ID,
|
||||
FOLDER_1_1_ID: FOLDER_1_ID,
|
||||
FOLDER_1_2_ID: FOLDER_1_ID,
|
||||
SHARED_DRIVE_2_ID: None,
|
||||
SECTIONS_FOLDER_ID: SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_ID: SHARED_DRIVE_2_ID,
|
||||
FOLDER_2_1_ID: FOLDER_2_ID,
|
||||
FOLDER_2_2_ID: FOLDER_2_ID,
|
||||
}
|
||||
|
||||
EXTERNAL_SHARED_FOLDER_URL = (
|
||||
"https://drive.google.com/drive/folders/1sWC7Oi0aQGgifLiMnhTjvkhRWVeDa-XS"
|
||||
@@ -286,7 +283,7 @@ TEST_USER_1_MY_DRIVE_FOLDER_ID = (
|
||||
)
|
||||
|
||||
TEST_USER_1_DRIVE_B_ID = (
|
||||
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuuuuper_private
|
||||
"0AFskk4zfZm86Uk9PVA" # My_super_special_shared_drive_suuuper_private
|
||||
)
|
||||
TEST_USER_1_DRIVE_B_FOLDER_ID = (
|
||||
"1oIj7nigzvP5xI2F8BmibUA8R_J3AbBA-" # Child folder (silliness)
|
||||
@@ -325,6 +322,106 @@ PERM_SYNC_DRIVE_ACCESS_MAPPING: dict[str, set[str]] = {
|
||||
PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: {ADMIN_EMAIL, TEST_USER_1_EMAIL},
|
||||
}
|
||||
|
||||
# ============================================================================
|
||||
# NON-SHARED-DRIVE HIERARCHY NODES
|
||||
# ============================================================================
|
||||
# These cover My Drive roots, perm sync drives, extra shared drives,
|
||||
# and standalone folders that appear in various tests.
|
||||
# Display names must match what the Google Drive API actually returns.
|
||||
# ============================================================================
|
||||
|
||||
EXPECTED_FOLDER_3 = _node(
|
||||
FOLDER_3_ID, "Folder 3", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
|
||||
)
|
||||
|
||||
EXPECTED_ADMIN_MY_DRIVE = _node(ADMIN_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER)
|
||||
EXPECTED_TEST_USER_1_MY_DRIVE = _node(
|
||||
TEST_USER_1_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
|
||||
)
|
||||
EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER = _node(
|
||||
TEST_USER_1_MY_DRIVE_FOLDER_ID,
|
||||
"partial_sharing",
|
||||
HierarchyNodeType.FOLDER,
|
||||
    TEST_USER_1_MY_DRIVE_ID,
)
EXPECTED_TEST_USER_2_MY_DRIVE = _node(
    TEST_USER_2_MY_DRIVE, "My Drive", HierarchyNodeType.FOLDER
)
EXPECTED_TEST_USER_3_MY_DRIVE = _node(
    TEST_USER_3_MY_DRIVE_ID, "My Drive", HierarchyNodeType.FOLDER
)

EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY = _node(
    PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
    "perm_sync_drive_0dc9d8b5-e243-4c2f-8678-2235958f7d7c",
    HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A = _node(
    PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
    "perm_sync_drive_785db121-0823-4ebe-8689-ad7f52405e32",
    HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B = _node(
    PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
    "perm_sync_drive_d8dc3649-3f65-4392-b87f-4b20e0389673",
    HierarchyNodeType.SHARED_DRIVE,
)

EXPECTED_TEST_USER_1_DRIVE_B = _node(
    TEST_USER_1_DRIVE_B_ID,
    "My_super_special_shared_drive_suuuper_private",
    HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_DRIVE_B_FOLDER = _node(
    TEST_USER_1_DRIVE_B_FOLDER_ID,
    "silliness",
    HierarchyNodeType.FOLDER,
    TEST_USER_1_DRIVE_B_ID,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_1 = _node(
    TEST_USER_1_EXTRA_DRIVE_1_ID,
    "Okay_Admin_fine_I_will_share",
    HierarchyNodeType.SHARED_DRIVE,
)
EXPECTED_TEST_USER_1_EXTRA_DRIVE_2 = _node(
    TEST_USER_1_EXTRA_DRIVE_2_ID, "reee test", HierarchyNodeType.SHARED_DRIVE
)
EXPECTED_TEST_USER_1_EXTRA_FOLDER = _node(
    TEST_USER_1_EXTRA_FOLDER_ID,
    "read only no download test",
    HierarchyNodeType.FOLDER,
)

EXPECTED_PILL_FOLDER = _node(
    PILL_FOLDER_ID, "pill_folder", HierarchyNodeType.FOLDER, ADMIN_MY_DRIVE_ID
)
EXPECTED_EXTERNAL_SHARED_FOLDER = _node(
    EXTERNAL_SHARED_FOLDER_ID, "Onyx-test", HierarchyNodeType.FOLDER
)

# Comprehensive mapping of ALL known hierarchy nodes.
# Every retrieved node is checked against this for display_name and node_type.
ALL_EXPECTED_HIERARCHY_NODES: dict[str, ExpectedHierarchyNode] = {
    **EXPECTED_SHARED_DRIVE_1_NODES,
    **EXPECTED_SHARED_DRIVE_2_NODES,
    FOLDER_3_ID: EXPECTED_FOLDER_3,
    ADMIN_MY_DRIVE_ID: EXPECTED_ADMIN_MY_DRIVE,
    TEST_USER_1_MY_DRIVE_ID: EXPECTED_TEST_USER_1_MY_DRIVE,
    TEST_USER_1_MY_DRIVE_FOLDER_ID: EXPECTED_TEST_USER_1_MY_DRIVE_FOLDER,
    TEST_USER_2_MY_DRIVE: EXPECTED_TEST_USER_2_MY_DRIVE,
    TEST_USER_3_MY_DRIVE_ID: EXPECTED_TEST_USER_3_MY_DRIVE,
    PERM_SYNC_DRIVE_ADMIN_ONLY_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_ONLY,
    PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A,
    PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID: EXPECTED_PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B,
    TEST_USER_1_DRIVE_B_ID: EXPECTED_TEST_USER_1_DRIVE_B,
    TEST_USER_1_DRIVE_B_FOLDER_ID: EXPECTED_TEST_USER_1_DRIVE_B_FOLDER,
    TEST_USER_1_EXTRA_DRIVE_1_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_1,
    TEST_USER_1_EXTRA_DRIVE_2_ID: EXPECTED_TEST_USER_1_EXTRA_DRIVE_2,
    TEST_USER_1_EXTRA_FOLDER_ID: EXPECTED_TEST_USER_1_EXTRA_FOLDER,
    PILL_FOLDER_ID: EXPECTED_PILL_FOLDER,
    EXTERNAL_SHARED_FOLDER_ID: EXPECTED_EXTERNAL_SHARED_FOLDER,
}

# Dictionary for access permissions
# All users have access to their own My Drive as well as public files
ACCESS_MAPPING: dict[str, list[int]] = {
@@ -508,28 +605,29 @@ def load_connector_outputs(

def assert_hierarchy_nodes_match_expected(
    retrieved_nodes: list[HierarchyNode],
    expected_node_ids: set[str],
    expected_parent_mapping: dict[str, str | None] | None = None,
    expected_nodes: dict[str, ExpectedHierarchyNode],
    ignorable_node_ids: set[str] | None = None,
) -> None:
    """
    Assert that retrieved hierarchy nodes match expected structure.

    Checks node IDs, display names, node types, and parent relationships
    for EVERY retrieved node (global checks).

    Args:
        retrieved_nodes: List of HierarchyNode objects from the connector
        expected_node_ids: Set of expected raw_node_ids
        expected_parent_mapping: Optional dict mapping node_id -> parent_id for parent verification
        ignorable_node_ids: Optional set of node IDs that can be missing or extra without failing.
            Useful for nodes that are non-deterministically returned by the connector.
        expected_nodes: Dict mapping raw_node_id -> ExpectedHierarchyNode with
            expected display_name, node_type, and raw_parent_id
        ignorable_node_ids: Optional set of node IDs that can be missing or extra
            without failing. Useful for non-deterministically returned nodes.
    """
    expected_node_ids = set(expected_nodes.keys())
    retrieved_node_ids = {node.raw_node_id for node in retrieved_nodes}
    ignorable = ignorable_node_ids or set()

    # Calculate differences, excluding ignorable nodes
    missing = expected_node_ids - retrieved_node_ids - ignorable
    extra = retrieved_node_ids - expected_node_ids - ignorable

    # Print discrepancies for debugging
    if missing or extra:
        print("Expected hierarchy node IDs:")
        print(sorted(expected_node_ids))
@@ -543,181 +641,146 @@ def assert_hierarchy_nodes_match_expected(
        print("Ignorable node IDs:")
        print(sorted(ignorable))

    assert not missing and not extra, (
        f"Hierarchy node mismatch. " f"Missing: {missing}, " f"Extra: {extra}"
    )
    assert (
        not missing and not extra
    ), f"Hierarchy node mismatch. Missing: {missing}, Extra: {extra}"

    # Verify parent relationships if provided
    if expected_parent_mapping is not None:
        for node in retrieved_nodes:
            if node.raw_node_id not in expected_parent_mapping:
                continue
            expected_parent = expected_parent_mapping[node.raw_node_id]
            assert node.raw_parent_id == expected_parent, (
    for node in retrieved_nodes:
        if node.raw_node_id in ignorable and node.raw_node_id not in expected_nodes:
            continue

        assert (
            node.raw_node_id in expected_nodes
        ), f"Node {node.raw_node_id} ({node.display_name}) not found in expected_nodes"
        expected = expected_nodes[node.raw_node_id]

        assert node.display_name == expected.display_name, (
            f"Display name mismatch for node {node.raw_node_id}: "
            f"expected '{expected.display_name}', got '{node.display_name}'"
        )
        assert node.node_type == expected.node_type, (
            f"Node type mismatch for node {node.raw_node_id}: "
            f"expected '{expected.node_type}', got '{node.node_type}'"
        )
        if expected.raw_parent_id is not None:
            assert node.raw_parent_id == expected.raw_parent_id, (
                f"Parent mismatch for node {node.raw_node_id} ({node.display_name}): "
                f"expected parent={expected_parent}, got parent={node.raw_parent_id}"
                f"expected parent={expected.raw_parent_id}, got parent={node.raw_parent_id}"
            )


def _pick(
    *node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
    """Pick nodes from ALL_EXPECTED_HIERARCHY_NODES by their IDs."""
    return {nid: ALL_EXPECTED_HIERARCHY_NODES[nid] for nid in node_ids}


def _clear_parents(
    nodes: dict[str, ExpectedHierarchyNode],
    *node_ids: str,
) -> dict[str, ExpectedHierarchyNode]:
    """Return a shallow copy of nodes with the specified nodes' parents set to None.
    Useful for OAuth tests where the user can't resolve certain parents
    (e.g. a folder in another user's My Drive)."""
    result = dict(nodes)
    for nid in node_ids:
        result[nid] = replace(result[nid], raw_parent_id=None)
    return result
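Taken together, these helpers keep per-test expectations declarative. As a rough illustration of how they compose (a sketch only; the concrete calls live in the test modules below, and `output` stands in for a loaded connector output from `load_connector_outputs`):

```python
# Sketch: pick expected nodes from the master mapping, drop parents an OAuth
# user cannot resolve, then run the global hierarchy assertions.
expected_nodes = _clear_parents(
    _pick(SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_3_ID),
    FOLDER_3_ID,  # parent lives in the admin's My Drive, unresolvable for this user
)
assert_hierarchy_nodes_match_expected(
    retrieved_nodes=output.hierarchy_nodes,  # `output` assumed from load_connector_outputs
    expected_nodes=expected_nodes,
    ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
)
```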


def get_expected_hierarchy_for_shared_drives(
    include_drive_1: bool = True,
    include_drive_2: bool = True,
    include_restricted_folder: bool = True,
) -> tuple[set[str], dict[str, str | None]]:
    """
    Get expected hierarchy node IDs and parent mapping for shared drives.

    Returns:
        Tuple of (expected_node_ids, expected_parent_mapping)
    """
    expected_ids: set[str] = set()
    expected_parents: dict[str, str | None] = {}
) -> dict[str, ExpectedHierarchyNode]:
    """Get expected hierarchy nodes for shared drives."""
    result: dict[str, ExpectedHierarchyNode] = {}

    if include_drive_1:
        expected_ids.add(SHARED_DRIVE_1_ID)
        expected_parents[SHARED_DRIVE_1_ID] = None

        if include_restricted_folder:
            expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
            expected_parents[RESTRICTED_ACCESS_FOLDER_ID] = SHARED_DRIVE_1_ID

        expected_ids.add(FOLDER_1_ID)
        expected_parents[FOLDER_1_ID] = SHARED_DRIVE_1_ID

        expected_ids.add(FOLDER_1_1_ID)
        expected_parents[FOLDER_1_1_ID] = FOLDER_1_ID

        expected_ids.add(FOLDER_1_2_ID)
        expected_parents[FOLDER_1_2_ID] = FOLDER_1_ID
        result.update(EXPECTED_SHARED_DRIVE_1_NODES)
        if not include_restricted_folder:
            result.pop(RESTRICTED_ACCESS_FOLDER_ID, None)

    if include_drive_2:
        expected_ids.add(SHARED_DRIVE_2_ID)
        expected_parents[SHARED_DRIVE_2_ID] = None
        result.update(EXPECTED_SHARED_DRIVE_2_NODES)

        expected_ids.add(SECTIONS_FOLDER_ID)
        expected_parents[SECTIONS_FOLDER_ID] = SHARED_DRIVE_2_ID

        expected_ids.add(FOLDER_2_ID)
        expected_parents[FOLDER_2_ID] = SHARED_DRIVE_2_ID

        expected_ids.add(FOLDER_2_1_ID)
        expected_parents[FOLDER_2_1_ID] = FOLDER_2_ID

        expected_ids.add(FOLDER_2_2_ID)
        expected_parents[FOLDER_2_2_ID] = FOLDER_2_ID

    return expected_ids, expected_parents
    return result


def get_expected_hierarchy_for_folder_1() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_folder_1() -> dict[str, ExpectedHierarchyNode]:
    """Get expected hierarchy for folder_1 and its children only."""
    return (
        {FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID},
        {
            FOLDER_1_ID: SHARED_DRIVE_1_ID,
            FOLDER_1_1_ID: FOLDER_1_ID,
            FOLDER_1_2_ID: FOLDER_1_ID,
        },
    )
    return _pick(FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID)


def get_expected_hierarchy_for_folder_2() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_folder_2() -> dict[str, ExpectedHierarchyNode]:
    """Get expected hierarchy for folder_2 and its children only."""
    return (
        {FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID},
        {
            FOLDER_2_ID: SHARED_DRIVE_2_ID,
            FOLDER_2_1_ID: FOLDER_2_ID,
            FOLDER_2_2_ID: FOLDER_2_ID,
        },
    )
    return _pick(FOLDER_2_ID, FOLDER_2_1_ID, FOLDER_2_2_ID)


def get_expected_hierarchy_for_test_user_1() -> tuple[set[str], dict[str, str | None]]:
def get_expected_hierarchy_for_test_user_1() -> dict[str, ExpectedHierarchyNode]:
    """
    Get expected hierarchy for test_user_1's full access.
    Get expected hierarchy for test_user_1's full access (OAuth).

    test_user_1 has access to:
    - shared_drive_1 and its contents (folder_1, folder_1_1, folder_1_2)
    - folder_3 (shared from admin's My Drive)
    - PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A and PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B
    - Additional drives/folders the user has access to

    NOTE: Folder 3 lives in the admin's My Drive. When running as an OAuth
    connector for test_user_1, the Google Drive API won't return the parent
    for Folder 3 because the user can't access the admin's My Drive root.
    """
    # Start with shared_drive_1 hierarchy
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    result = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=False,
        include_restricted_folder=False,
    )

    # folder_3 is shared from admin's My Drive
    expected_ids.add(FOLDER_3_ID)

    # Perm sync drives that test_user_1 has access to
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID] = None

    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_parents[PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID] = None

    # Additional drives/folders test_user_1 has access to
    expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
    expected_parents[TEST_USER_1_MY_DRIVE_ID] = None

    expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
    expected_parents[TEST_USER_1_MY_DRIVE_FOLDER_ID] = TEST_USER_1_MY_DRIVE_ID

    expected_ids.add(TEST_USER_1_DRIVE_B_ID)
    expected_parents[TEST_USER_1_DRIVE_B_ID] = None

    expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
    expected_parents[TEST_USER_1_DRIVE_B_FOLDER_ID] = TEST_USER_1_DRIVE_B_ID

    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_parents[TEST_USER_1_EXTRA_DRIVE_1_ID] = None

    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_parents[TEST_USER_1_EXTRA_DRIVE_2_ID] = None

    expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
    # Parent unknown, skip adding to expected_parents

    return expected_ids, expected_parents
    result.update(
        _pick(
            FOLDER_3_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_MY_DRIVE_ID,
            TEST_USER_1_MY_DRIVE_FOLDER_ID,
            TEST_USER_1_DRIVE_B_ID,
            TEST_USER_1_DRIVE_B_FOLDER_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            TEST_USER_1_EXTRA_FOLDER_ID,
        )
    )
    return _clear_parents(result, FOLDER_3_ID)


def get_expected_hierarchy_for_test_user_1_shared_drives_only() -> (
    tuple[set[str], dict[str, str | None]]
    dict[str, ExpectedHierarchyNode]
):
    """Expected hierarchy nodes when test_user_1 runs with include_shared_drives=True only."""
    expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()

    # This mode should not include My Drive roots/folders.
    expected_ids.discard(TEST_USER_1_MY_DRIVE_ID)
    expected_ids.discard(TEST_USER_1_MY_DRIVE_FOLDER_ID)

    # don't include shared with me
    expected_ids.discard(FOLDER_3_ID)
    expected_ids.discard(TEST_USER_1_EXTRA_FOLDER_ID)

    return expected_ids, expected_parents
    result = get_expected_hierarchy_for_test_user_1()
    for nid in (
        TEST_USER_1_MY_DRIVE_ID,
        TEST_USER_1_MY_DRIVE_FOLDER_ID,
        FOLDER_3_ID,
        TEST_USER_1_EXTRA_FOLDER_ID,
    ):
        result.pop(nid, None)
    return result


def get_expected_hierarchy_for_test_user_1_shared_with_me_only() -> (
    tuple[set[str], dict[str, str | None]]
    dict[str, ExpectedHierarchyNode]
):
    """Expected hierarchy nodes when test_user_1 runs with include_files_shared_with_me=True only."""
    expected_ids: set[str] = {FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID}
    expected_parents: dict[str, str | None] = {}
    return expected_ids, expected_parents
    return _clear_parents(
        _pick(FOLDER_3_ID, TEST_USER_1_EXTRA_FOLDER_ID),
        FOLDER_3_ID,
    )


def get_expected_hierarchy_for_test_user_1_my_drive_only() -> (
    tuple[set[str], dict[str, str | None]]
    dict[str, ExpectedHierarchyNode]
):
    """Expected hierarchy nodes when test_user_1 runs with include_my_drives=True only."""
    expected_ids: set[str] = {TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID}
    expected_parents: dict[str, str | None] = {
        TEST_USER_1_MY_DRIVE_ID: None,
        TEST_USER_1_MY_DRIVE_FOLDER_ID: TEST_USER_1_MY_DRIVE_ID,
    }
    return expected_ids, expected_parents
    return _pick(TEST_USER_1_MY_DRIVE_ID, TEST_USER_1_MY_DRIVE_FOLDER_ID)

@@ -3,12 +3,11 @@ from unittest.mock import MagicMock
from unittest.mock import patch

from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_expected_docs_in_retrieved_docs,
)
@@ -16,21 +15,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_hierarchy_nodes_match_expected,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
@@ -47,18 +40,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
from tests.daily.connectors.google_drive.consts_and_utils import (
    PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import (
    PILL_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import PILL_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    RESTRICTED_ACCESS_FOLDER_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    TEST_USER_1_EXTRA_DRIVE_1_ID,
)
@@ -90,7 +80,6 @@ def test_include_all(
    )
    output = load_connector_outputs(connector)

    # Should get everything in shared and admin's My Drive with oauth
    expected_file_ids = (
        ADMIN_FILE_IDS
        + ADMIN_FOLDER_3_FILE_IDS
@@ -109,33 +98,28 @@ def test_include_all(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes for shared drives
    # When include_shared_drives=True, we get ALL shared drives the admin has access to
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        # Restricted folder may not always be retrieved due to access limitations
        include_restricted_folder=False,
    )

    # Add additional shared drives that admin has access to
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_ids.add(ADMIN_MY_DRIVE_ID)
    expected_ids.add(PILL_FOLDER_ID)
    expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)

    # My Drive folders
    expected_ids.add(FOLDER_3_ID)

    expected_nodes.update(
        _pick(
            PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            ADMIN_MY_DRIVE_ID,
            PILL_FOLDER_ID,
            RESTRICTED_ACCESS_FOLDER_ID,
            TEST_USER_1_EXTRA_FOLDER_ID,
            FOLDER_3_ID,
        )
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -160,7 +144,6 @@ def test_include_shared_drives_only(
    )
    output = load_connector_outputs(connector)

    # Should only get shared drives
    expected_file_ids = (
        SHARED_DRIVE_1_FILE_IDS
        + FOLDER_1_FILE_IDS
@@ -177,26 +160,24 @@ def test_include_shared_drives_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - should include both shared drives and their folders
    # When include_shared_drives=True, we get ALL shared drives admin has access to
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )

    # Add additional shared drives that admin has access to
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)

    expected_nodes.update(
        _pick(
            PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            RESTRICTED_ACCESS_FOLDER_ID,
        )
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
    )


@@ -220,24 +201,21 @@ def test_include_my_drives_only(
    )
    output = load_connector_outputs(connector)

    # Should only get primary_admins My Drive because we are impersonating them
    expected_file_ids = ADMIN_FILE_IDS + ADMIN_FOLDER_3_FILE_IDS
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=output.documents,
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - My Drive should yield folder_3 as a hierarchy node
    # Also includes admin's My Drive root and folders shared with admin
    expected_ids = {
    expected_nodes = _pick(
        FOLDER_3_ID,
        ADMIN_MY_DRIVE_ID,
        PILL_FOLDER_ID,
        TEST_USER_1_EXTRA_FOLDER_ID,
    }
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=expected_nodes,
    )


@@ -273,17 +251,14 @@ def test_drive_one_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - should only include shared_drive_1 and its folders
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=False,
        include_restricted_folder=False,
    )
    # Restricted folder is non-deterministically returned by the connector
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -324,33 +299,15 @@ def test_folder_and_shared_drive(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - shared_drive_1 and folder_2 with children
    # SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
    expected_ids = {
        SHARED_DRIVE_1_ID,
        FOLDER_1_ID,
        FOLDER_1_1_ID,
        FOLDER_1_2_ID,
        SHARED_DRIVE_2_ID,
        FOLDER_2_ID,
        FOLDER_2_1_ID,
        FOLDER_2_2_ID,
    }
    expected_parents = {
        SHARED_DRIVE_1_ID: None,
        FOLDER_1_ID: SHARED_DRIVE_1_ID,
        FOLDER_1_1_ID: FOLDER_1_ID,
        FOLDER_1_2_ID: FOLDER_1_ID,
        SHARED_DRIVE_2_ID: None,
        FOLDER_2_ID: SHARED_DRIVE_2_ID,
        FOLDER_2_1_ID: FOLDER_2_ID,
        FOLDER_2_2_ID: FOLDER_2_ID,
    }
    # Restricted folder is non-deterministically returned
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )
    expected_nodes.pop(SECTIONS_FOLDER_ID, None)
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -370,7 +327,6 @@ def test_folders_only(
        FOLDER_2_2_URL,
        FOLDER_3_URL,
    ]
    # This should get converted to a drive request and spit out a warning in the logs
    shared_drive_urls = [
        FOLDER_1_1_URL,
    ]
@@ -397,23 +353,16 @@ def test_folders_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - specific folders requested plus their parent nodes
    # The connector walks up the hierarchy to include parent drives/folders
    expected_ids = {
        SHARED_DRIVE_1_ID,
        FOLDER_1_ID,
        FOLDER_1_1_ID,
        FOLDER_1_2_ID,
        SHARED_DRIVE_2_ID,
        FOLDER_2_ID,
        FOLDER_2_1_ID,
        FOLDER_2_2_ID,
        ADMIN_MY_DRIVE_ID,
        FOLDER_3_ID,
    }
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )
    expected_nodes.pop(SECTIONS_FOLDER_ID, None)
    expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=expected_nodes,
    )


@@ -446,9 +395,8 @@ def test_personal_folders_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - folder_3 and its parent (admin's My Drive root)
    expected_ids = {FOLDER_3_ID, ADMIN_MY_DRIVE_ID}
    expected_nodes = _pick(FOLDER_3_ID, ADMIN_MY_DRIVE_ID)
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=expected_nodes,
    )

@@ -14,11 +14,10 @@ from onyx.db.models import ConnectorCredentialPair
from onyx.db.utils import DocumentRow
from onyx.db.utils import SortOrder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ACCESS_MAPPING
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import (
    ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_hierarchy_nodes_match_expected,
)
@@ -262,37 +261,35 @@ def test_gdrive_perm_sync_with_real_data(
    hierarchy_connector = _build_connector(google_drive_service_acct_connector_factory)
    output = load_connector_outputs(hierarchy_connector, include_permissions=True)

    # Verify the expected shared drives hierarchy
    # When include_shared_drives=True and include_my_drives=True, we get ALL drives
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )

    # Add additional shared drives in the organization
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
    expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_ids.add(ADMIN_MY_DRIVE_ID)
    expected_ids.add(TEST_USER_2_MY_DRIVE)
    expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
    expected_ids.add(PILL_FOLDER_ID)
    expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
    expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)
    expected_ids.add(FOLDER_3_ID)

    expected_nodes.update(
        _pick(
            PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_MY_DRIVE_ID,
            TEST_USER_1_MY_DRIVE_FOLDER_ID,
            TEST_USER_1_DRIVE_B_ID,
            TEST_USER_1_DRIVE_B_FOLDER_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            ADMIN_MY_DRIVE_ID,
            TEST_USER_2_MY_DRIVE,
            TEST_USER_3_MY_DRIVE_ID,
            PILL_FOLDER_ID,
            RESTRICTED_ACCESS_FOLDER_ID,
            TEST_USER_1_EXTRA_FOLDER_ID,
            EXTERNAL_SHARED_FOLDER_ID,
            FOLDER_3_ID,
        )
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )


@@ -4,12 +4,11 @@ from unittest.mock import patch
from urllib.parse import urlparse

from onyx.connectors.google_drive.connector import GoogleDriveConnector
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_EMAIL
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    ADMIN_MY_DRIVE_ID,
)
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_MY_DRIVE_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_expected_docs_in_retrieved_docs,
)
@@ -29,21 +28,15 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
    EXTERNAL_SHARED_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_2_URL
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_ID
from tests.daily.connectors.google_drive.consts_and_utils import FOLDER_3_URL
@@ -74,11 +67,10 @@ from tests.daily.connectors.google_drive.consts_and_utils import (
    RESTRICTED_ACCESS_FOLDER_URL,
)
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SECTIONS_FOLDER_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_ID
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_1_URL
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import SHARED_DRIVE_2_ID
from tests.daily.connectors.google_drive.consts_and_utils import (
    TEST_USER_1_DRIVE_B_FOLDER_ID,
)
@@ -156,39 +148,35 @@ def test_include_all(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes for shared drives
    # When include_shared_drives=True, we get ALL shared drives in the organization
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )

    # Add additional shared drives in the organization
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_ids.add(TEST_USER_1_MY_DRIVE_ID)
    expected_ids.add(TEST_USER_1_MY_DRIVE_FOLDER_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_ids.add(ADMIN_MY_DRIVE_ID)
    expected_ids.add(TEST_USER_2_MY_DRIVE)
    expected_ids.add(TEST_USER_3_MY_DRIVE_ID)
    expected_ids.add(PILL_FOLDER_ID)
    expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_FOLDER_ID)
    expected_ids.add(EXTERNAL_SHARED_FOLDER_ID)

    # My Drive folders
    expected_ids.add(FOLDER_3_ID)

    expected_nodes.update(
        _pick(
            PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_MY_DRIVE_ID,
            TEST_USER_1_MY_DRIVE_FOLDER_ID,
            TEST_USER_1_DRIVE_B_ID,
            TEST_USER_1_DRIVE_B_FOLDER_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            ADMIN_MY_DRIVE_ID,
            TEST_USER_2_MY_DRIVE,
            TEST_USER_3_MY_DRIVE_ID,
            PILL_FOLDER_ID,
            RESTRICTED_ACCESS_FOLDER_ID,
            TEST_USER_1_EXTRA_FOLDER_ID,
            EXTERNAL_SHARED_FOLDER_ID,
            FOLDER_3_ID,
        )
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -294,28 +282,26 @@ def test_include_shared_drives_only(
    # TODO: switch to 54 when restricted access issue is resolved
    assert len(output.documents) == 51 or len(output.documents) == 52

    # Verify hierarchy nodes - should include both shared drives and their folders
    # When include_shared_drives=True, we get ALL shared drives in the organization
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )

    # Add additional shared drives in the organization
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_ONLY_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID)
    expected_ids.add(PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_ID)
    expected_ids.add(TEST_USER_1_DRIVE_B_FOLDER_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_1_ID)
    expected_ids.add(TEST_USER_1_EXTRA_DRIVE_2_ID)
    expected_ids.add(RESTRICTED_ACCESS_FOLDER_ID)

    expected_nodes.update(
        _pick(
            PERM_SYNC_DRIVE_ADMIN_ONLY_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_A_ID,
            PERM_SYNC_DRIVE_ADMIN_AND_USER_1_B_ID,
            TEST_USER_1_DRIVE_B_ID,
            TEST_USER_1_DRIVE_B_FOLDER_ID,
            TEST_USER_1_EXTRA_DRIVE_1_ID,
            TEST_USER_1_EXTRA_DRIVE_2_ID,
            RESTRICTED_ACCESS_FOLDER_ID,
        )
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -353,9 +339,7 @@ def test_include_my_drives_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - My Drive roots and folders for all users
    # Service account impersonates all users, so it sees all My Drives
    expected_ids = {
    expected_nodes = _pick(
        FOLDER_3_ID,
        ADMIN_MY_DRIVE_ID,
        TEST_USER_1_MY_DRIVE_ID,
@@ -365,10 +349,10 @@ def test_include_my_drives_only(
        PILL_FOLDER_ID,
        TEST_USER_1_EXTRA_FOLDER_ID,
        EXTERNAL_SHARED_FOLDER_ID,
    }
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=expected_nodes,
    )


@@ -405,17 +389,14 @@ def test_drive_one_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - should only include shared_drive_1 and its folders
    expected_ids, expected_parents = get_expected_hierarchy_for_shared_drives(
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=False,
        include_restricted_folder=False,
    )
    # Restricted folder is non-deterministically returned
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -457,33 +438,15 @@ def test_folder_and_shared_drive(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - shared_drive_1 and folder_2 with children
    # SHARED_DRIVE_2_ID is included because folder_2's parent is shared_drive_2
    expected_ids = {
        SHARED_DRIVE_1_ID,
        FOLDER_1_ID,
        FOLDER_1_1_ID,
        FOLDER_1_2_ID,
        SHARED_DRIVE_2_ID,
        FOLDER_2_ID,
        FOLDER_2_1_ID,
        FOLDER_2_2_ID,
    }
    expected_parents = {
        SHARED_DRIVE_1_ID: None,
        FOLDER_1_ID: SHARED_DRIVE_1_ID,
        FOLDER_1_1_ID: FOLDER_1_ID,
        FOLDER_1_2_ID: FOLDER_1_ID,
        SHARED_DRIVE_2_ID: None,
        FOLDER_2_ID: SHARED_DRIVE_2_ID,
        FOLDER_2_1_ID: FOLDER_2_ID,
        FOLDER_2_2_ID: FOLDER_2_ID,
    }
    # Restricted folder is non-deterministically returned
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )
    expected_nodes.pop(SECTIONS_FOLDER_ID, None)
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=expected_nodes,
        ignorable_node_ids={RESTRICTED_ACCESS_FOLDER_ID},
    )

@@ -530,23 +493,16 @@ def test_folders_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - specific folders requested plus their parent nodes
    # The connector walks up the hierarchy to include parent drives/folders
    expected_ids = {
        SHARED_DRIVE_1_ID,
        FOLDER_1_ID,
        FOLDER_1_1_ID,
        FOLDER_1_2_ID,
        SHARED_DRIVE_2_ID,
        FOLDER_2_ID,
        FOLDER_2_1_ID,
        FOLDER_2_2_ID,
        ADMIN_MY_DRIVE_ID,
        FOLDER_3_ID,
    }
    expected_nodes = get_expected_hierarchy_for_shared_drives(
        include_drive_1=True,
        include_drive_2=True,
        include_restricted_folder=False,
    )
    expected_nodes.pop(SECTIONS_FOLDER_ID, None)
    expected_nodes.update(_pick(ADMIN_MY_DRIVE_ID, FOLDER_3_ID))
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=expected_nodes,
    )



@@ -4,6 +4,8 @@ from unittest.mock import patch

from onyx.connectors.google_drive.connector import GoogleDriveConnector
from onyx.connectors.models import Document
from tests.daily.connectors.google_drive.consts_and_utils import _clear_parents
from tests.daily.connectors.google_drive.consts_and_utils import _pick
from tests.daily.connectors.google_drive.consts_and_utils import ADMIN_FOLDER_3_FILE_IDS
from tests.daily.connectors.google_drive.consts_and_utils import (
    assert_expected_docs_in_retrieved_docs,
@@ -51,8 +53,6 @@ def _check_for_error(
    retrieved_failures = output.failures
    assert len(retrieved_failures) <= 1

    # current behavior is to fail silently for 403s; leaving this here for when we revert
    # if all 403s get fixed
    if len(retrieved_failures) == 1:
        fail_msg = retrieved_failures[0].failure_message
        assert "HttpError 403" in fail_msg
@@ -83,14 +83,11 @@ def test_all(
    output = load_connector_outputs(connector)

    expected_file_ids = (
        # These are the files from my drive
        TEST_USER_1_FILE_IDS
        # These are the files from shared drives
        + SHARED_DRIVE_1_FILE_IDS
        + FOLDER_1_FILE_IDS
        + FOLDER_1_1_FILE_IDS
        + FOLDER_1_2_FILE_IDS
        # These are the files shared with me from admin
        + ADMIN_FOLDER_3_FILE_IDS
        + list(range(0, 2))
    )
@@ -102,13 +99,9 @@ def test_all(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - test_user_1 has access to shared_drive_1, folder_3,
    # perm sync drives, and additional drives/folders
    expected_ids, expected_parents = get_expected_hierarchy_for_test_user_1()
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=get_expected_hierarchy_for_test_user_1(),
    )


@@ -133,7 +126,6 @@ def test_shared_drives_only(
    output = load_connector_outputs(connector)

    expected_file_ids = (
        # These are the files from shared drives
        SHARED_DRIVE_1_FILE_IDS
        + FOLDER_1_FILE_IDS
        + FOLDER_1_1_FILE_IDS
@@ -146,14 +138,9 @@ def test_shared_drives_only(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - test_user_1 sees multiple shared drives/folders
    expected_ids, expected_parents = (
        get_expected_hierarchy_for_test_user_1_shared_drives_only()
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=get_expected_hierarchy_for_test_user_1_shared_drives_only(),
    )


@@ -177,24 +164,15 @@ def test_shared_with_me_only(
    )
    output = load_connector_outputs(connector)

    expected_file_ids = (
        # These are the files shared with me from admin
        ADMIN_FOLDER_3_FILE_IDS
        + list(range(0, 2))
    )
    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS + list(range(0, 2))
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=output.documents,
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - shared-with-me folders
    expected_ids, expected_parents = (
        get_expected_hierarchy_for_test_user_1_shared_with_me_only()
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=get_expected_hierarchy_for_test_user_1_shared_with_me_only(),
    )


@@ -218,21 +196,15 @@ def test_my_drive_only(
    )
    output = load_connector_outputs(connector)

    # These are the files from my drive
    expected_file_ids = TEST_USER_1_FILE_IDS
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=output.documents,
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - My Drive root + its folder(s)
    expected_ids, expected_parents = (
        get_expected_hierarchy_for_test_user_1_my_drive_only()
    )
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=get_expected_hierarchy_for_test_user_1_my_drive_only(),
    )


@@ -256,20 +228,15 @@ def test_shared_my_drive_folder(
    )
    output = load_connector_outputs(connector)

    expected_file_ids = (
        # this is a folder from admin's drive that is shared with me
        ADMIN_FOLDER_3_FILE_IDS
    )
    expected_file_ids = ADMIN_FOLDER_3_FILE_IDS
    assert_expected_docs_in_retrieved_docs(
        retrieved_docs=output.documents,
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - only folder_3
    expected_ids = {FOLDER_3_ID}
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_nodes=_clear_parents(_pick(FOLDER_3_ID), FOLDER_3_ID),
    )


@@ -299,16 +266,9 @@ def test_shared_drive_folder(
        expected_file_ids=expected_file_ids,
    )

    # Verify hierarchy nodes - includes shared drive root + folder_1 subtree
    expected_ids = {SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID}
    expected_parents: dict[str, str | None] = {
        SHARED_DRIVE_1_ID: None,
        FOLDER_1_ID: SHARED_DRIVE_1_ID,
        FOLDER_1_1_ID: FOLDER_1_ID,
        FOLDER_1_2_ID: FOLDER_1_ID,
    }
    assert_hierarchy_nodes_match_expected(
        retrieved_nodes=output.hierarchy_nodes,
        expected_node_ids=expected_ids,
        expected_parent_mapping=expected_parents,
        expected_nodes=_pick(
            SHARED_DRIVE_1_ID, FOLDER_1_ID, FOLDER_1_1_ID, FOLDER_1_2_ID
        ),
    )

@@ -996,6 +996,114 @@ class TestFallbackToolExtraction:
        assert result.tool_calls[0].tool_args == {"queries": ["beta"]}
        assert result.tool_calls[0].placement == Placement(turn_index=5)

    def test_extracts_xml_style_invoke_from_answer_when_required(self) -> None:
        llm_step_result = LlmStepResult(
            reasoning=None,
            answer=(
                '<function_calls><invoke name="internal_search">'
                '<parameter name="queries" string="false">'
                '["Onyx documentation", "Onyx docs", "Onyx platform"]'
                "</parameter></invoke></function_calls>"
            ),
            tool_calls=None,
        )

        result, attempted = _try_fallback_tool_extraction(
            llm_step_result=llm_step_result,
            tool_choice=ToolChoiceOptions.REQUIRED,
            fallback_extraction_attempted=False,
            tool_defs=self._tool_defs(),
            turn_index=7,
        )

        assert attempted is True
        assert result.tool_calls is not None
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].tool_name == "internal_search"
        assert result.tool_calls[0].tool_args == {
            "queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
        }
        assert result.tool_calls[0].placement == Placement(turn_index=7)

    def test_extracts_xml_style_invoke_from_answer_when_auto(self) -> None:
        llm_step_result = LlmStepResult(
            reasoning=None,
            # Runtime-faithful shape: filtered answer is empty, raw answer has XML payload.
            answer=None,
            raw_answer=(
                '<function_calls><invoke name="internal_search">'
                '<parameter name="queries" string="false">'
                '["Onyx documentation", "Onyx docs", "Onyx internal docs"]'
                "</parameter></invoke></function_calls>"
            ),
            tool_calls=None,
        )

        result, attempted = _try_fallback_tool_extraction(
            llm_step_result=llm_step_result,
            tool_choice=ToolChoiceOptions.AUTO,
            fallback_extraction_attempted=False,
            tool_defs=self._tool_defs(),
            turn_index=9,
        )

        assert attempted is True
        assert result.tool_calls is not None
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].tool_name == "internal_search"
        assert result.tool_calls[0].tool_args == {
            "queries": ["Onyx documentation", "Onyx docs", "Onyx internal docs"]
        }
        assert result.tool_calls[0].placement == Placement(turn_index=9)

    def test_extracts_from_raw_answer_when_filtered_answer_has_no_xml(self) -> None:
        llm_step_result = LlmStepResult(
            reasoning=None,
            answer="",
            raw_answer=(
                '<function_calls><invoke name="internal_search">'
                '<parameter name="queries" string="false">'
                '["Onyx documentation", "Onyx docs"]'
                "</parameter></invoke></function_calls>"
            ),
            tool_calls=None,
        )

        result, attempted = _try_fallback_tool_extraction(
            llm_step_result=llm_step_result,
            tool_choice=ToolChoiceOptions.AUTO,
            fallback_extraction_attempted=False,
            tool_defs=self._tool_defs(),
            turn_index=10,
        )

        assert attempted is True
        assert result.tool_calls is not None
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].tool_name == "internal_search"
        assert result.tool_calls[0].tool_args == {
            "queries": ["Onyx documentation", "Onyx docs"]
        }
        assert result.tool_calls[0].placement == Placement(turn_index=10)

    def test_does_not_attempt_fallback_for_auto_without_tool_call_hints(self) -> None:
        llm_step_result = LlmStepResult(
            reasoning=None,
            answer="Here is a normal answer with no tool call payload.",
            tool_calls=None,
        )

        result, attempted = _try_fallback_tool_extraction(
            llm_step_result=llm_step_result,
            tool_choice=ToolChoiceOptions.AUTO,
            fallback_extraction_attempted=False,
            tool_defs=self._tool_defs(),
            turn_index=2,
        )

        assert result is llm_step_result
        assert attempted is False

    def test_returns_unchanged_when_required_but_nothing_extractable(self) -> None:
        llm_step_result = LlmStepResult(
            reasoning="Need more info.",

@@ -2,6 +2,7 @@

from onyx.chat.llm_step import _parse_tool_args_to_dict
from onyx.chat.llm_step import _sanitize_llm_output
from onyx.chat.llm_step import _XmlToolCallContentFilter
from onyx.chat.llm_step import extract_tool_calls_from_response_text
from onyx.server.query_and_chat.placement import Placement

@@ -211,3 +212,79 @@ class TestExtractToolCallsFromResponseText:
            {"queries": ["alpha"]},
            {"queries": ["alpha"]},
        ]

    def test_extracts_xml_style_invoke_tool_call(self) -> None:
        response_text = """
<function_calls>
<invoke name="internal_search">
<parameter name="queries" string="false">["Onyx documentation", "Onyx docs", "Onyx platform"]</parameter>
</invoke>
</function_calls>
"""
        tool_calls = extract_tool_calls_from_response_text(
            response_text=response_text,
            tool_definitions=self._tool_defs(),
            placement=self._placement(),
        )
        assert len(tool_calls) == 1
        assert tool_calls[0].tool_name == "internal_search"
        assert tool_calls[0].tool_args == {
            "queries": ["Onyx documentation", "Onyx docs", "Onyx platform"]
        }

    def test_ignores_unknown_tool_in_xml_style_invoke(self) -> None:
        response_text = """
<function_calls>
<invoke name="unknown_tool">
<parameter name="queries" string="false">["Onyx docs"]</parameter>
</invoke>
</function_calls>
"""
        tool_calls = extract_tool_calls_from_response_text(
            response_text=response_text,
            tool_definitions=self._tool_defs(),
            placement=self._placement(),
        )
        assert len(tool_calls) == 0


class TestXmlToolCallContentFilter:
    def test_strips_function_calls_block_single_chunk(self) -> None:
        f = _XmlToolCallContentFilter()
        output = f.process(
            "prefix "
            '<function_calls><invoke name="internal_search">'
            '<parameter name="queries" string="false">["Onyx docs"]</parameter>'
            "</invoke></function_calls> suffix"
        )
        output += f.flush()
        assert output == "prefix suffix"

    def test_strips_function_calls_block_split_across_chunks(self) -> None:
        f = _XmlToolCallContentFilter()
        chunks = [
            "Start ",
            "<function_",
            'calls><invoke name="internal_search">',
            '<parameter name="queries" string="false">["Onyx docs"]',
            "</parameter></invoke></function_calls>",
            " End",
        ]
        output = "".join(f.process(chunk) for chunk in chunks) + f.flush()
        assert output == "Start End"

    def test_preserves_non_tool_call_xml(self) -> None:
        f = _XmlToolCallContentFilter()
        output = f.process("A <tag>value</tag> B")
        output += f.flush()
        assert output == "A <tag>value</tag> B"

    def test_does_not_strip_similar_tag_names(self) -> None:
        f = _XmlToolCallContentFilter()
        output = f.process(
            "A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
        )
        output += f.flush()
        assert (
            output == "A <function_calls_v2><invoke>noop</invoke></function_calls_v2> B"
        )

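For context, a rough sketch of how the two pieces exercised above are meant to fit together at runtime (assumed wiring; the real flow lives in `onyx.chat.llm_step`): the streaming filter hides the XML tool-call payload from user-visible text while the extractor recovers the structured call from the raw response.

```python
# Illustrative sketch only. `tool_defs` and `placement` are hypothetical
# stand-ins for the fixtures used in the tests above
# (self._tool_defs() / self._placement()).
raw_chunks = [
    'Searching now. <function_calls><invoke name="internal_search">',
    '<parameter name="queries" string="false">["Onyx docs"]</parameter>',
    "</invoke></function_calls>",
]

# Stream side: the filter strips the <function_calls> block from visible text.
f = _XmlToolCallContentFilter()
visible_text = "".join(f.process(chunk) for chunk in raw_chunks) + f.flush()
# visible_text keeps only the natural-language prefix; the XML payload is gone.

# Extraction side: the same raw text yields a structured tool call.
tool_calls = extract_tool_calls_from_response_text(
    response_text="".join(raw_chunks),
    tool_definitions=tool_defs,
    placement=placement,
)
# Expected to yield a single internal_search call, as in the tests above.
```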
@@ -0,0 +1,171 @@
"""Unit tests for Prometheus instrumentation module."""

import threading
from typing import Any
from unittest.mock import MagicMock
from unittest.mock import patch

from fastapi import FastAPI
from fastapi.testclient import TestClient
from prometheus_client import CollectorRegistry
from prometheus_client import Gauge

from onyx.server.prometheus_instrumentation import _slow_request_callback
from onyx.server.prometheus_instrumentation import setup_prometheus_metrics


def _make_info(
    duration: float,
    method: str = "GET",
    handler: str = "/api/test",
    status: str = "200",
) -> Any:
    """Build a fake metrics Info object matching the instrumentator's Info shape."""
    return MagicMock(
        modified_duration=duration,
        method=method,
        modified_handler=handler,
        modified_status=status,
    )


def test_slow_request_callback_increments_above_threshold() -> None:
    with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
        mock_labels = MagicMock()
        mock_counter.labels.return_value = mock_labels

        info = _make_info(
            duration=2.0, method="POST", handler="/api/chat", status="200"
        )
        _slow_request_callback(info)

        mock_counter.labels.assert_called_once_with(
            method="POST", handler="/api/chat", status="200"
        )
        mock_labels.inc.assert_called_once()


def test_slow_request_callback_skips_below_threshold() -> None:
    with patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter:
        info = _make_info(duration=0.5)
        _slow_request_callback(info)

        mock_counter.labels.assert_not_called()


def test_slow_request_callback_skips_at_exact_threshold() -> None:
    with (
        patch(
            "onyx.server.prometheus_instrumentation.SLOW_REQUEST_THRESHOLD_SECONDS", 1.0
        ),
        patch("onyx.server.prometheus_instrumentation._slow_requests") as mock_counter,
    ):
        info = _make_info(duration=1.0)
        _slow_request_callback(info)

        mock_counter.labels.assert_not_called()


def test_setup_attaches_instrumentator_to_app() -> None:
    with patch("onyx.server.prometheus_instrumentation.Instrumentator") as mock_cls:
        mock_instance = MagicMock()
        mock_instance.instrument.return_value = mock_instance
        mock_cls.return_value = mock_instance

        app = FastAPI()
        setup_prometheus_metrics(app)

        mock_cls.assert_called_once_with(
            should_group_status_codes=False,
            should_ignore_untemplated=False,
            should_group_untemplated=True,
            should_instrument_requests_inprogress=True,
            inprogress_labels=True,
            excluded_handlers=["/health", "/metrics", "/openapi.json"],
        )
        mock_instance.add.assert_called_once()
        mock_instance.instrument.assert_called_once_with(app)
        mock_instance.expose.assert_called_once_with(app)


def test_inprogress_gauge_increments_during_request() -> None:
    """Verify the in-progress gauge goes up while a request is in flight."""
    registry = CollectorRegistry()
    gauge = Gauge(
        "http_requests_inprogress_test",
        "In-progress requests",
        ["method", "handler"],
        registry=registry,
    )

    request_started = threading.Event()
    request_release = threading.Event()

    app = FastAPI()

    @app.get("/slow")
    def slow_endpoint() -> dict:
        gauge.labels(method="GET", handler="/slow").inc()
        request_started.set()
        request_release.wait(timeout=5)
        gauge.labels(method="GET", handler="/slow").dec()
        return {"status": "done"}

    client = TestClient(app, raise_server_exceptions=False)

    def make_request() -> None:
        client.get("/slow")

    thread = threading.Thread(target=make_request)
    thread.start()

    request_started.wait(timeout=5)
    assert gauge.labels(method="GET", handler="/slow")._value.get() == 1.0

    request_release.set()
    thread.join(timeout=5)
    assert gauge.labels(method="GET", handler="/slow")._value.get() == 0.0


def test_inprogress_gauge_tracks_concurrent_requests() -> None:
    """Verify the gauge correctly counts multiple concurrent in-flight requests."""
    registry = CollectorRegistry()
    gauge = Gauge(
        "http_requests_inprogress_concurrent_test",
        "In-progress requests",
        ["method", "handler"],
        registry=registry,
    )

    # 3 parties: 2 request threads + main thread
    barrier = threading.Barrier(3)
    release = threading.Event()

    app = FastAPI()

    @app.get("/concurrent")
    def concurrent_endpoint() -> dict:
        gauge.labels(method="GET", handler="/concurrent").inc()
        barrier.wait(timeout=5)
        release.wait(timeout=5)
        gauge.labels(method="GET", handler="/concurrent").dec()
        return {"status": "done"}

    client = TestClient(app, raise_server_exceptions=False)

    def make_request() -> None:
        client.get("/concurrent")

    t1 = threading.Thread(target=make_request)
    t2 = threading.Thread(target=make_request)
    t1.start()
    t2.start()

    # All 3 threads meet here — both requests are in-flight
    barrier.wait(timeout=5)
    assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 2.0

    release.set()
    t1.join(timeout=5)
    t2.join(timeout=5)
    assert gauge.labels(method="GET", handler="/concurrent")._value.get() == 0.0

88
docs/METRICS.md
Normal file
@@ -0,0 +1,88 @@
# Onyx Prometheus Metrics Reference

## API Server Metrics

These metrics are exposed at `GET /metrics` on the API server.

### Built-in (via `prometheus-fastapi-instrumentator`)

| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `http_requests_total` | Counter | `method`, `status`, `handler` | Total request count |
| `http_request_duration_highr_seconds` | Histogram | _(none)_ | High-resolution latency (many buckets, no labels) |
| `http_request_duration_seconds` | Histogram | `method`, `handler` | Latency by handler (few buckets for aggregation) |
| `http_request_size_bytes` | Summary | `handler` | Incoming request content length |
| `http_response_size_bytes` | Summary | `handler` | Outgoing response content length |
| `http_requests_inprogress` | Gauge | `method`, `handler` | Currently in-flight requests |

### Custom (via `onyx.server.prometheus_instrumentation`)

| Metric | Type | Labels | Description |
|--------|------|--------|-------------|
| `onyx_api_slow_requests_total` | Counter | `method`, `handler`, `status` | Requests exceeding `SLOW_REQUEST_THRESHOLD_SECONDS` (default 1s) |

### Configuration

| Env Var | Default | Description |
|---------|---------|-------------|
| `SLOW_REQUEST_THRESHOLD_SECONDS` | `1.0` | Duration threshold for slow request counting |

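A request is counted as slow only when its duration is strictly greater than the threshold; a duration exactly at the threshold is not counted (this is the behavior verified by the unit tests). The sketch below illustrates that logic only; it is not the literal module source, and the counter/threshold definitions are shown purely for completeness:

```python
# Sketch of the slow-request counting behavior, assuming the Info object
# exposed by prometheus-fastapi-instrumentator (method, modified_handler,
# modified_status, modified_duration). Not the actual module implementation.
from typing import Any

from prometheus_client import Counter

SLOW_REQUEST_THRESHOLD_SECONDS = 1.0  # overridable via the env var above

# Exposed as onyx_api_slow_requests_total (prometheus_client appends _total)
_slow_requests = Counter(
    "onyx_api_slow_requests",
    "Requests exceeding the slow-request threshold",
    ["method", "handler", "status"],
)


def _slow_request_callback(info: Any) -> None:
    # Count only requests strictly slower than the threshold.
    if info.modified_duration > SLOW_REQUEST_THRESHOLD_SECONDS:
        _slow_requests.labels(
            method=info.method,
            handler=info.modified_handler,
            status=info.modified_status,
        ).inc()
```
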
### Instrumentator Settings

- `should_group_status_codes=False` — Reports exact HTTP status codes (e.g. 401, 403, 500)
- `should_instrument_requests_inprogress=True` — Enables the in-progress request gauge
- `inprogress_labels=True` — Breaks down the in-progress gauge by `method` and `handler`
- `excluded_handlers=["/health", "/metrics", "/openapi.json"]` — Excludes noisy endpoints from metrics

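The instrumentation is attached once at app startup. A minimal sketch of the wiring, calling `setup_prometheus_metrics` exactly as the unit tests do (the import path comes from those tests; everything else is standard FastAPI):

```python
from fastapi import FastAPI

from onyx.server.prometheus_instrumentation import setup_prometheus_metrics

app = FastAPI()

# Instruments all routes (except the excluded handlers above) and exposes
# the Prometheus text format at GET /metrics on the same app.
setup_prometheus_metrics(app)
```
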
## Example PromQL Queries

### Which endpoints are saturated right now?

```promql
# Top 10 endpoints by in-progress requests
topk(10, http_requests_inprogress)
```

### What's the P99 latency per endpoint?

```promql
# P99 latency by handler over the last 5 minutes
histogram_quantile(0.99, sum by (handler, le) (rate(http_request_duration_seconds_bucket[5m])))
```

### Which endpoints have the highest request rate?

```promql
# Requests per second by handler, top 10
topk(10, sum by (handler) (rate(http_requests_total[5m])))
```

### Which endpoints are returning errors?

```promql
# 5xx error rate by handler
sum by (handler) (rate(http_requests_total{status=~"5.."}[5m]))
```

### Slow request hotspots

```promql
# Slow requests per minute by handler
sum by (handler) (rate(onyx_api_slow_requests_total[5m])) * 60
```

### Latency trending up?

```promql
# Compare P50 latency now vs 1 hour ago
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m])))
-
histogram_quantile(0.5, sum by (le) (rate(http_request_duration_highr_seconds_bucket[5m] offset 1h)))
```

### Overall request throughput

```promql
# Total requests per second across all endpoints
sum(rate(http_requests_total[5m]))
```

@@ -1,242 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { Callout } from "@/components/ui/callout";
|
||||
import Text from "@/components/ui/text";
|
||||
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
|
||||
import { SEARCH_PARAM_NAMES } from "@/app/app/services/searchParams";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { structureValue } from "@/lib/llm/utils";
|
||||
import { LlmDescriptor, useLlmManager } from "@/lib/hooks";
|
||||
import Separator from "@/refresh-components/Separator";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { useCurrentAgent } from "@/hooks/useAgents";
|
||||
import { useSearchParams } from "next/navigation";
|
||||
import { useChatSessionStore } from "@/app/app/stores/useChatSessionStore";
|
||||
import ConfirmationModalLayout from "@/refresh-components/layouts/ConfirmationModalLayout";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import { copyAll } from "@/app/app/message/copyingUtils";
|
||||
import { SvgCopy, SvgShare } from "@opal/icons";
|
||||
|
||||
function buildShareLink(chatSessionId: string) {
|
||||
const baseUrl = `${window.location.protocol}//${window.location.host}`;
|
||||
return `${baseUrl}/app/shared/${chatSessionId}`;
|
||||
}
|
||||
|
||||
async function generateShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ sharing_status: "public" }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return buildShareLink(chatSessionId);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async function generateSeedLink(
|
||||
message?: string,
|
||||
assistantId?: number,
|
||||
modelOverride?: LlmDescriptor
|
||||
) {
|
||||
const baseUrl = `${window.location.protocol}//${window.location.host}`;
|
||||
const model = modelOverride
|
||||
? structureValue(
|
||||
modelOverride.name,
|
||||
modelOverride.provider,
|
||||
modelOverride.modelName
|
||||
)
|
||||
: null;
|
||||
return `${baseUrl}/app${
|
||||
message
|
||||
? `?${SEARCH_PARAM_NAMES.USER_PROMPT}=${encodeURIComponent(message)}`
|
||||
: ""
|
||||
}${
|
||||
assistantId
|
||||
? `${message ? "&" : "?"}${SEARCH_PARAM_NAMES.PERSONA_ID}=${assistantId}`
|
||||
: ""
|
||||
}${
|
||||
model
|
||||
? `${message || assistantId ? "&" : "?"}${
|
||||
SEARCH_PARAM_NAMES.STRUCTURED_MODEL
|
||||
}=${encodeURIComponent(model)}`
|
||||
: ""
|
||||
}${message ? `&${SEARCH_PARAM_NAMES.SEND_ON_LOAD}=true` : ""}`;
|
||||
}
|
||||
|
||||
async function deleteShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
body: JSON.stringify({ sharing_status: "private" }),
|
||||
});
|
||||
|
||||
return response.ok;
|
||||
}
|
||||
|
||||
interface ShareChatSessionModalProps {
|
||||
chatSession: ChatSession;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function ShareChatSessionModal({
|
||||
chatSession,
|
||||
onClose,
|
||||
}: ShareChatSessionModalProps) {
|
||||
const [shareLink, setShareLink] = useState<string>(
|
||||
chatSession.shared_status === ChatSessionSharedStatus.Public
|
||||
? buildShareLink(chatSession.id)
|
||||
: ""
|
||||
);
|
||||
const [showAdvancedOptions, setShowAdvancedOptions] = useState(false);
|
||||
const currentAgent = useCurrentAgent();
|
||||
const searchParams = useSearchParams();
|
||||
const message = searchParams?.get(SEARCH_PARAM_NAMES.USER_PROMPT) || "";
|
||||
const llmManager = useLlmManager(chatSession, currentAgent || undefined);
|
||||
const updateCurrentChatSessionSharedStatus = useChatSessionStore(
|
||||
(state) => state.updateCurrentChatSessionSharedStatus
|
||||
);
|
||||
|
||||
return (
|
||||
<>
|
||||
<ConfirmationModalLayout
|
||||
icon={SvgShare}
|
||||
title="Share Chat"
|
||||
onClose={onClose}
|
||||
submit={<Button onClick={onClose}>Share</Button>}
|
||||
>
|
||||
{shareLink ? (
|
||||
<div>
|
||||
<Text>
|
||||
This chat session is currently shared. Anyone in your team can
|
||||
view the message history using the following link:
|
||||
</Text>
|
||||
|
||||
<div className="flex items-center mt-2">
|
||||
{/* <CopyButton content={shareLink} /> */}
|
||||
<CopyIconButton
|
||||
getCopyText={() => shareLink}
|
||||
prominence="secondary"
|
||||
/>
|
||||
<a
|
||||
href={shareLink}
|
||||
target="_blank"
|
||||
className={cn(
|
||||
"underline mt-1 ml-1 text-sm my-auto",
|
||||
"text-action-link-05"
|
||||
)}
|
||||
rel="noreferrer"
|
||||
>
|
||||
{shareLink}
|
||||
</a>
|
||||
</div>
|
||||
|
||||
<Separator />
|
||||
|
||||
<Text className={cn("mb-4")}>
|
||||
Click the button below to make the chat private again.
|
||||
</Text>
|
||||
|
||||
<Button
|
||||
onClick={async () => {
|
||||
const success = await deleteShareLink(chatSession.id);
|
||||
if (success) {
|
||||
setShareLink("");
|
||||
updateCurrentChatSessionSharedStatus(
|
||||
ChatSessionSharedStatus.Private
|
||||
);
|
||||
} else {
|
||||
alert("Failed to delete share link");
|
||||
}
|
||||
}}
|
||||
danger
|
||||
>
|
||||
Delete Share Link
|
||||
</Button>
|
||||
</div>
|
||||
) : (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Callout type="warning" title="Warning">
|
||||
Please make sure that all content in this chat is safe to share
|
||||
with the whole team.
|
||||
</Callout>
|
||||
<Button
|
||||
leftIcon={SvgCopy}
|
||||
onClick={async () => {
|
||||
// NOTE: for "insecure" non-https setup, the `navigator.clipboard.writeText` may fail
|
||||
// as the browser may not allow the clipboard to be accessed.
|
||||
try {
|
||||
const shareLink = await generateShareLink(chatSession.id);
|
||||
if (!shareLink) {
|
||||
alert("Failed to generate share link");
|
||||
} else {
|
||||
setShareLink(shareLink);
|
||||
updateCurrentChatSessionSharedStatus(
|
||||
ChatSessionSharedStatus.Public
|
||||
);
|
||||
copyAll(shareLink);
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
}
|
||||
}}
|
||||
secondary
|
||||
>
|
||||
Generate and Copy Share Link
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
|
||||
<Separator className={cn("my-4")} />
|
||||
|
||||
<AdvancedOptionsToggle
|
||||
showAdvancedOptions={showAdvancedOptions}
|
||||
setShowAdvancedOptions={setShowAdvancedOptions}
|
||||
title="Advanced Options"
|
||||
/>
|
||||
|
||||
{showAdvancedOptions && (
|
||||
<div className="flex flex-col gap-2">
|
||||
<Callout type="notice" title="Seed New Chat">
|
||||
Generate a link to a new chat session with the same settings as
|
||||
this chat (including the assistant and model).
|
||||
</Callout>
|
||||
<Button
|
||||
leftIcon={SvgCopy}
|
||||
onClick={async () => {
|
||||
try {
|
||||
const seedLink = await generateSeedLink(
|
||||
message,
|
||||
currentAgent?.id,
|
||||
llmManager.currentLlm
|
||||
);
|
||||
if (!seedLink) {
|
||||
toast.error("Failed to generate seed link");
|
||||
} else {
|
||||
navigator.clipboard.writeText(seedLink);
|
||||
copyAll(seedLink);
|
||||
toast.success("Link copied to clipboard!");
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
alert("Failed to generate or copy link.");
|
||||
}
|
||||
}}
|
||||
secondary
|
||||
>
|
||||
Generate and Copy Seed Link
|
||||
</Button>
|
||||
</div>
|
||||
)}
|
||||
</ConfirmationModalLayout>
|
||||
</>
|
||||
);
|
||||
}
|
||||
@@ -186,6 +186,7 @@ export interface BackendChatSession {
|
||||
current_temperature_override: number | null;
|
||||
current_alternate_model?: string;
|
||||
|
||||
owner_name: string | null;
|
||||
packets: Packet[][];
|
||||
}
|
||||
|
||||
|
||||
@@ -20,16 +20,6 @@ import {
|
||||
import { openDocument } from "@/lib/search/utils";
|
||||
import { ensureHrefProtocol } from "@/lib/utils";
|
||||
|
||||
function isSameOriginUrl(url: string): boolean {
|
||||
if (!url.startsWith("http")) return true;
|
||||
try {
|
||||
if (typeof window === "undefined") return false;
|
||||
return new URL(url).origin === window.location.origin;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
|
||||
export const MemoizedAnchor = memo(
|
||||
({
|
||||
docs,
|
||||
@@ -188,25 +178,6 @@ export const MemoizedLink = memo(
|
||||
|
||||
const url = ensureHrefProtocol(href);
|
||||
|
||||
const isChatFile = url?.includes("/api/chat/file/") && isSameOriginUrl(url);
|
||||
if (isChatFile && updatePresentingDocument) {
|
||||
const fileId = url!.split("/api/chat/file/")[1]?.split(/[?#]/)[0] || "";
|
||||
const filename = value?.toString() || "download";
|
||||
return (
|
||||
<a
|
||||
onClick={() =>
|
||||
updatePresentingDocument({
|
||||
document_id: fileId,
|
||||
semantic_identifier: filename,
|
||||
} as OnyxDocument)
|
||||
}
|
||||
className="cursor-pointer text-link hover:text-link-hover"
|
||||
>
|
||||
{rest.children}
|
||||
</a>
|
||||
);
|
||||
}
|
||||
|
||||
return (
|
||||
<a
|
||||
href={url}
|
||||
|
||||
@@ -15,6 +15,7 @@ import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import { UNNAMED_CHAT } from "@/lib/constants";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import useOnMount from "@/hooks/useOnMount";
|
||||
import SharedAppInputBar from "@/sections/input/SharedAppInputBar";
|
||||
|
||||
export interface SharedChatDisplayProps {
|
||||
chatSession: BackendChatSession | null;
|
||||
@@ -69,65 +70,78 @@ export default function SharedChatDisplay({
|
||||
/>
|
||||
)}
|
||||
|
||||
<div className="flex flex-col items-center h-full w-full overflow-hidden overflow-y-scroll">
|
||||
<div className="sticky top-0 z-10 flex flex-col w-full bg-background-tint-01 px-8 py-4">
|
||||
<Text as="p" headingH2>
|
||||
{chatSession.description || UNNAMED_CHAT}
|
||||
</Text>
|
||||
<Text as="p" text03>
|
||||
{humanReadableFormat(chatSession.time_created)}
|
||||
</Text>
|
||||
<div className="flex flex-col h-full w-full overflow-hidden">
|
||||
<div className="flex-1 flex flex-col items-center overflow-y-auto">
|
||||
<div className="sticky top-0 z-10 flex items-center justify-between w-full bg-background-tint-01 px-8 py-4">
|
||||
<Text as="p" text04 headingH2>
|
||||
{chatSession.description || UNNAMED_CHAT}
|
||||
</Text>
|
||||
<div className="flex flex-col items-end">
|
||||
<Text as="p" text03 secondaryBody>
|
||||
Shared on {humanReadableFormat(chatSession.time_created)}
|
||||
</Text>
|
||||
{chatSession.owner_name && (
|
||||
<Text as="p" text03 secondaryBody>
|
||||
by {chatSession.owner_name}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{isMounted ? (
|
||||
<div className="w-[min(50rem,100%)]">
|
||||
{messages.map((message, i) => {
|
||||
if (message.type === "user") {
|
||||
return (
|
||||
<HumanMessage
|
||||
key={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
nodeId={message.nodeId}
|
||||
/>
|
||||
);
|
||||
} else if (message.type === "assistant") {
|
||||
return (
|
||||
<AgentMessage
|
||||
key={message.messageId}
|
||||
rawPackets={message.packets}
|
||||
chatState={{
|
||||
assistant: persona,
|
||||
docs: message.documents,
|
||||
citations: message.citations,
|
||||
setPresentingDocument: setPresentingDocument,
|
||||
overriddenModel: message.overridden_model,
|
||||
}}
|
||||
nodeId={message.nodeId}
|
||||
llmManager={null}
|
||||
otherMessagesCanSwitchTo={undefined}
|
||||
onMessageSelection={undefined}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
// Error message case
|
||||
return (
|
||||
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
|
||||
<div className="mx-auto w-[90%] max-w-message-max">
|
||||
<p className="text-status-text-error-05 text-sm my-auto">
|
||||
{message.message}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-full w-full flex items-center justify-center">
|
||||
<OnyxInitializingLoader />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
{isMounted ? (
|
||||
<div className="w-[min(50rem,100%)]">
|
||||
{messages.map((message, i) => {
|
||||
if (message.type === "user") {
|
||||
return (
|
||||
<HumanMessage
|
||||
key={message.messageId}
|
||||
content={message.message}
|
||||
files={message.files}
|
||||
nodeId={message.nodeId}
|
||||
/>
|
||||
);
|
||||
} else if (message.type === "assistant") {
|
||||
return (
|
||||
<AgentMessage
|
||||
key={message.messageId}
|
||||
rawPackets={message.packets}
|
||||
chatState={{
|
||||
assistant: persona,
|
||||
docs: message.documents,
|
||||
citations: message.citations,
|
||||
setPresentingDocument: setPresentingDocument,
|
||||
overriddenModel: message.overridden_model,
|
||||
}}
|
||||
nodeId={message.nodeId}
|
||||
llmManager={null}
|
||||
otherMessagesCanSwitchTo={undefined}
|
||||
onMessageSelection={undefined}
|
||||
/>
|
||||
);
|
||||
} else {
|
||||
// Error message case
|
||||
return (
|
||||
<div key={message.messageId} className="py-5 ml-4 lg:px-5">
|
||||
<div className="mx-auto w-[90%] max-w-message-max">
|
||||
<p className="text-status-text-error-05 text-sm my-auto">
|
||||
{message.message}
|
||||
</p>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
})}
|
||||
</div>
|
||||
) : (
|
||||
<div className="h-full w-full flex items-center justify-center">
|
||||
<OnyxInitializingLoader />
|
||||
</div>
|
||||
)}
|
||||
<div className="w-full max-w-[50rem] mx-auto px-4 pb-4">
|
||||
<SharedAppInputBar />
|
||||
</div>
|
||||
</div>
|
||||
</>
|
||||
);
|
||||
|
||||
@@ -12,7 +12,8 @@ export type AppFocusType =
|
||||
| { type: "agent" | "project" | "chat"; id: string }
|
||||
| "new-session"
|
||||
| "more-agents"
|
||||
| "user-settings";
|
||||
| "user-settings"
|
||||
| "shared-chat";
|
||||
|
||||
export class AppFocus {
|
||||
constructor(public value: AppFocusType) {}
|
||||
@@ -29,6 +30,10 @@ export class AppFocus {
|
||||
return typeof this.value === "object" && this.value.type === "chat";
|
||||
}
|
||||
|
||||
isSharedChat(): boolean {
|
||||
return this.value === "shared-chat";
|
||||
}
|
||||
|
||||
isNewSession(): boolean {
|
||||
return this.value === "new-session";
|
||||
}
|
||||
@@ -49,6 +54,7 @@ export class AppFocus {
|
||||
| "agent"
|
||||
| "project"
|
||||
| "chat"
|
||||
| "shared-chat"
|
||||
| "new-session"
|
||||
| "more-agents"
|
||||
| "user-settings" {
|
||||
@@ -60,6 +66,11 @@ export default function useAppFocus(): AppFocus {
|
||||
const pathname = usePathname();
|
||||
const searchParams = useSearchParams();
|
||||
|
||||
// Check if we're viewing a shared chat
|
||||
if (pathname.startsWith("/app/shared/")) {
|
||||
return new AppFocus("shared-chat");
|
||||
}
|
||||
|
||||
// Check if we're on the user settings page
|
||||
if (pathname.startsWith("/app/settings")) {
|
||||
return new AppFocus("user-settings");
|
||||
|
||||
63
web/src/hooks/useContainerCenter.ts
Normal file
@@ -0,0 +1,63 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useEffect } from "react";
|
||||
|
||||
const SELECTOR = "[data-main-container]";
|
||||
|
||||
interface ContainerCenter {
|
||||
centerX: number | null;
|
||||
centerY: number | null;
|
||||
hasContainerCenter: boolean;
|
||||
}
|
||||
|
||||
function measure(el: HTMLElement): { x: number; y: number } {
|
||||
const rect = el.getBoundingClientRect();
|
||||
return { x: rect.left + rect.width / 2, y: rect.top + rect.height / 2 };
|
||||
}
|
||||
|
||||
/**
|
||||
* Tracks the center point of the `[data-main-container]` element so that
|
||||
* portaled overlays (modals, command menus) can center relative to the main
|
||||
* content area rather than the full viewport.
|
||||
*
|
||||
* Returns `{ centerX, centerY, hasContainerCenter }`.
|
||||
* When the container is not present (e.g. pages without `AppLayouts.Root`),
|
||||
* both center values are `null` and `hasContainerCenter` is `false`, allowing
|
||||
* callers to fall back to standard viewport centering.
|
||||
*
|
||||
* Uses a lazy `useState` initializer so the first render already has the
|
||||
* correct values (no flash), and a `ResizeObserver` to stay reactive when
|
||||
* the sidebar folds/unfolds.
|
||||
*/
|
||||
export default function useContainerCenter(): ContainerCenter {
|
||||
const [center, setCenter] = useState<{ x: number | null; y: number | null }>(
|
||||
() => {
|
||||
if (typeof document === "undefined") return { x: null, y: null };
|
||||
const el = document.querySelector<HTMLElement>(SELECTOR);
|
||||
if (!el) return { x: null, y: null };
|
||||
const m = measure(el);
|
||||
return { x: m.x, y: m.y };
|
||||
}
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const container = document.querySelector<HTMLElement>(SELECTOR);
|
||||
if (!container) return;
|
||||
|
||||
const update = () => {
|
||||
const m = measure(container);
|
||||
setCenter({ x: m.x, y: m.y });
|
||||
};
|
||||
|
||||
update();
|
||||
const observer = new ResizeObserver(update);
|
||||
observer.observe(container);
|
||||
return () => observer.disconnect();
|
||||
}, []);
|
||||
|
||||
return {
|
||||
centerX: center.x,
|
||||
centerY: center.y,
|
||||
hasContainerCenter: center.x !== null && center.y !== null,
|
||||
};
|
||||
}
|
||||
@@ -27,7 +27,7 @@ import Button from "@/refresh-components/buttons/Button";
|
||||
import { useCallback, useMemo, useState, useEffect } from "react";
|
||||
import { useAppBackground } from "@/providers/AppBackgroundProvider";
|
||||
import { useTheme } from "next-themes";
|
||||
import ShareChatSessionModal from "@/app/app/components/modal/ShareChatSessionModal";
|
||||
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import LineItem from "@/refresh-components/buttons/LineItem";
|
||||
import { useProjectsContext } from "@/providers/ProjectsContext";
|
||||
@@ -375,6 +375,7 @@ function Header() {
|
||||
transient={showShareModal}
|
||||
tertiary
|
||||
onClick={() => setShowShareModal(true)}
|
||||
aria-label="share-chat-button"
|
||||
>
|
||||
Share Chat
|
||||
</Button>
|
||||
@@ -510,8 +511,12 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
return (
|
||||
/* NOTE: Some elements, markdown tables in particular, refer to this `@container` in order to
|
||||
breakout of their immediate containers using cqw units.
|
||||
The `data-main-container` attribute is used by portaled elements (e.g. CommandMenu) to
|
||||
render inside this container so they can be centered relative to the main content area
|
||||
rather than the full viewport (which would include the sidebar).
|
||||
*/
|
||||
<div
|
||||
data-main-container
|
||||
className={cn(
|
||||
"@container flex flex-col h-full w-full relative overflow-hidden",
|
||||
showBackground && "bg-cover bg-center bg-fixed"
|
||||
@@ -564,7 +569,7 @@ function Root({ children, enableBackground }: AppRootProps) {
|
||||
)}
|
||||
|
||||
<div className="z-app-layout">
|
||||
<Header />
|
||||
{!appFocus.isSharedChat() && <Header />}
|
||||
</div>
|
||||
<div className="z-app-layout flex-1 overflow-auto h-full w-full">
|
||||
{children}
|
||||
|
||||
@@ -1,40 +0,0 @@
|
||||
import * as languages from "linguist-languages";
|
||||
|
||||
interface LinguistLanguage {
|
||||
name: string;
|
||||
type: string;
|
||||
extensions?: string[];
|
||||
filenames?: string[];
|
||||
}
|
||||
|
||||
// Build extension → language name and filename → language name maps at module load
|
||||
const extensionMap = new Map<string, string>();
|
||||
const filenameMap = new Map<string, string>();
|
||||
|
||||
for (const lang of Object.values(languages) as LinguistLanguage[]) {
|
||||
if (lang.type !== "programming") continue;
|
||||
|
||||
const name = lang.name.toLowerCase();
|
||||
for (const ext of lang.extensions ?? []) {
|
||||
// First language to claim an extension wins
|
||||
if (!extensionMap.has(ext)) {
|
||||
extensionMap.set(ext, name);
|
||||
}
|
||||
}
|
||||
for (const filename of lang.filenames ?? []) {
|
||||
if (!filenameMap.has(filename.toLowerCase())) {
|
||||
filenameMap.set(filename.toLowerCase(), name);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the language name for a given file name, or null if it's not a
|
||||
* recognised code file. Looks up by extension first, then by exact filename
|
||||
* (e.g. "Dockerfile", "Makefile"). Runs in O(1).
|
||||
*/
|
||||
export function getCodeLanguage(name: string): string | null {
|
||||
const lower = name.toLowerCase();
|
||||
const ext = lower.match(/\.[^.]+$/)?.[0];
|
||||
return (ext && extensionMap.get(ext)) ?? filenameMap.get(lower) ?? null;
|
||||
}
|
||||
@@ -9,6 +9,7 @@ import { Button } from "@opal/components";
|
||||
import { SvgX } from "@opal/icons";
|
||||
import { WithoutStyles } from "@/types";
|
||||
import { Section, SectionProps } from "@/layouts/general-layouts";
|
||||
import useContainerCenter from "@/hooks/useContainerCenter";
|
||||
|
||||
/**
|
||||
* Modal Root Component
|
||||
@@ -264,6 +265,8 @@ const ModalContent = React.forwardRef<
|
||||
contentRef(node);
|
||||
};
|
||||
|
||||
const { centerX, centerY, hasContainerCenter } = useContainerCenter();
|
||||
|
||||
const animationClasses = cn(
|
||||
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0",
|
||||
"data-[state=open]:zoom-in-95 data-[state=closed]:zoom-out-95",
|
||||
@@ -271,6 +274,22 @@ const ModalContent = React.forwardRef<
|
||||
"duration-200"
|
||||
);
|
||||
|
||||
const containerStyle: React.CSSProperties | undefined = hasContainerCenter
|
||||
? ({
|
||||
left: centerX,
|
||||
top: centerY,
|
||||
"--tw-enter-translate-x": "-50%",
|
||||
"--tw-exit-translate-x": "-50%",
|
||||
"--tw-enter-translate-y": "-50%",
|
||||
"--tw-exit-translate-y": "-50%",
|
||||
} as React.CSSProperties)
|
||||
: undefined;
|
||||
|
||||
const positionClasses = cn(
|
||||
"fixed -translate-x-1/2 -translate-y-1/2",
|
||||
!hasContainerCenter && "left-1/2 top-1/2"
|
||||
);
|
||||
|
||||
const dialogEventHandlers = {
|
||||
onOpenAutoFocus: (e: Event) => {
|
||||
resetState();
|
||||
@@ -315,8 +334,9 @@ const ModalContent = React.forwardRef<
|
||||
{...dialogEventHandlers}
|
||||
>
|
||||
<div
|
||||
style={containerStyle}
|
||||
className={cn(
|
||||
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2",
|
||||
positionClasses,
|
||||
"z-modal",
|
||||
"flex flex-col gap-4 items-center",
|
||||
"max-w-[calc(100dvw-2rem)] max-h-[calc(100dvh-2rem)]",
|
||||
@@ -334,8 +354,10 @@ const ModalContent = React.forwardRef<
|
||||
// Without bottomSlot: original single-element rendering
|
||||
<DialogPrimitive.Content
|
||||
ref={handleRef}
|
||||
style={containerStyle}
|
||||
className={cn(
|
||||
"fixed left-1/2 top-1/2 -translate-x-1/2 -translate-y-1/2 overflow-hidden",
|
||||
positionClasses,
|
||||
"overflow-hidden",
|
||||
"z-modal",
|
||||
background === "gray"
|
||||
? "bg-background-tint-01"
|
||||
|
||||
@@ -10,6 +10,7 @@ import React, {
|
||||
} from "react";
|
||||
import * as DialogPrimitive from "@radix-ui/react-dialog";
|
||||
import * as VisuallyHidden from "@radix-ui/react-visually-hidden";
|
||||
import useContainerCenter from "@/hooks/useContainerCenter";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
@@ -366,10 +367,11 @@ const CommandMenuContent = React.forwardRef<
|
||||
CommandMenuContentProps
|
||||
>(({ children }, ref) => {
|
||||
const { handleKeyDown } = useCommandMenuContext();
|
||||
const { centerX, hasContainerCenter } = useContainerCenter();
|
||||
|
||||
return (
|
||||
<DialogPrimitive.Portal>
|
||||
{/* Overlay - hidden from assistive technology */}
|
||||
{/* Overlay - fixed to full viewport, hidden from assistive technology */}
|
||||
<DialogPrimitive.Overlay
|
||||
aria-hidden="true"
|
||||
className={cn(
|
||||
@@ -378,12 +380,23 @@ const CommandMenuContent = React.forwardRef<
|
||||
"data-[state=open]:fade-in-0 data-[state=closed]:fade-out-0"
|
||||
)}
|
||||
/>
|
||||
{/* Content */}
|
||||
{/* Content - centered within the main container when available,
|
||||
otherwise falls back to viewport centering */}
|
||||
<DialogPrimitive.Content
|
||||
ref={ref}
|
||||
onKeyDown={handleKeyDown}
|
||||
style={
|
||||
hasContainerCenter
|
||||
? ({
|
||||
left: centerX,
|
||||
"--tw-enter-translate-x": "-50%",
|
||||
"--tw-exit-translate-x": "-50%",
|
||||
} as React.CSSProperties)
|
||||
: undefined
|
||||
}
|
||||
className={cn(
|
||||
"fixed inset-x-0 top-[72px] mx-auto",
|
||||
"fixed top-[72px]",
|
||||
hasContainerCenter ? "-translate-x-1/2" : "inset-x-0 mx-auto",
|
||||
"z-modal",
|
||||
"bg-background-tint-00 border rounded-16 shadow-2xl outline-none",
|
||||
"flex flex-col overflow-hidden",
|
||||
|
||||
@@ -26,8 +26,6 @@ import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import { useUser } from "@/providers/UserProvider";
|
||||
import NoAssistantModal from "@/components/modals/NoAssistantModal";
|
||||
import TextViewModal from "@/sections/modals/TextViewModal";
|
||||
import CodeViewModal from "@/sections/modals/CodeViewModal";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import { useSendMessageToParent } from "@/lib/extension/utils";
|
||||
import { SUBMIT_MESSAGE_TYPES } from "@/lib/extension/constants";
|
||||
@@ -684,18 +682,12 @@ export default function AppPage({ firstMessage }: ChatPageProps) {
|
||||
</div>
|
||||
)}
|
||||
|
||||
{presentingDocument &&
|
||||
(getCodeLanguage(presentingDocument.semantic_identifier || "") ? (
|
||||
<CodeViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
) : (
|
||||
<TextViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
))}
|
||||
{presentingDocument && (
|
||||
<TextViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={() => setPresentingDocument(null)}
|
||||
/>
|
||||
)}
|
||||
|
||||
{stackTraceModalContent && (
|
||||
<ExceptionTraceModal
|
||||
|
||||
@@ -346,6 +346,7 @@ const ChatScrollContainer = React.memo(
|
||||
<div
|
||||
key={sessionId}
|
||||
ref={scrollContainerRef}
|
||||
data-testid="chat-scroll-container"
|
||||
className="flex flex-col flex-1 min-h-0 overflow-y-auto overflow-x-hidden default-scrollbar"
|
||||
onScroll={handleScroll}
|
||||
style={{
|
||||
|
||||
55
web/src/sections/input/SharedAppInputBar.tsx
Normal file
@@ -0,0 +1,55 @@
|
||||
"use client";
|
||||
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { Button, OpenButton } from "@opal/components";
|
||||
import { OpenAISVG } from "@/components/icons/icons";
|
||||
import {
|
||||
SvgPlusCircle,
|
||||
SvgArrowUp,
|
||||
SvgSliders,
|
||||
SvgHourglass,
|
||||
SvgEditBig,
|
||||
} from "@opal/icons";
|
||||
|
||||
export default function SharedAppInputBar() {
|
||||
return (
|
||||
<div className="relative w-full">
|
||||
<div className="w-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16">
|
||||
{/* Textarea area */}
|
||||
<div className="flex flex-row items-center w-full">
|
||||
<Text text03 className="w-full px-3 pt-3 pb-2 select-none">
|
||||
How can Onyx help you today
|
||||
</Text>
|
||||
</div>
|
||||
|
||||
{/* Bottom toolbar */}
|
||||
<div className="flex justify-between items-center w-full p-1 min-h-[40px]">
|
||||
{/* Left side controls */}
|
||||
<div className="flex flex-row items-center">
|
||||
<Button icon={SvgPlusCircle} prominence="tertiary" disabled />
|
||||
<Button icon={SvgSliders} prominence="tertiary" disabled />
|
||||
<Button icon={SvgHourglass} variant="select" disabled />
|
||||
</div>
|
||||
|
||||
{/* Right side controls */}
|
||||
<div className="flex flex-row items-center gap-1">
|
||||
<OpenButton icon={OpenAISVG} foldable disabled>
|
||||
GPT-4o
|
||||
</OpenButton>
|
||||
<Button icon={SvgArrowUp} disabled />
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
{/* Fade overlay */}
|
||||
<div className="absolute inset-0 rounded-16 backdrop-blur-sm bg-background-neutral-00/50" />
|
||||
|
||||
{/* CTA button */}
|
||||
<div className="absolute inset-0 flex items-center justify-center">
|
||||
<Button prominence="secondary" icon={SvgEditBig} href="/app">
|
||||
Start New Session
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
@@ -1,42 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import DocumentViewModal, {
|
||||
DocumentData,
|
||||
} from "@/sections/modals/DocumentViewModal";
|
||||
import { getCodeLanguage } from "@/lib/languages";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import "@/app/app/message/custom-code-styles.css";
|
||||
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
|
||||
|
||||
export interface CodeViewProps {
|
||||
presentingDocument: MinimalOnyxDocument;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function CodeViewModal({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
}: CodeViewProps) {
|
||||
const language =
|
||||
getCodeLanguage(presentingDocument.semantic_identifier || "") ||
|
||||
"plaintext";
|
||||
|
||||
const renderContent = (data: DocumentData) => (
|
||||
<ScrollIndicatorDiv className="flex-1 min-h-0 p-4" variant="shadow">
|
||||
<MinimalMarkdown
|
||||
content={`\`\`\`${language}\n${data.fileContent}\n\`\`\``}
|
||||
className="w-full pb-4 h-full break-words"
|
||||
/>
|
||||
</ScrollIndicatorDiv>
|
||||
);
|
||||
|
||||
return (
|
||||
<DocumentViewModal
|
||||
presentingDocument={presentingDocument}
|
||||
onClose={onClose}
|
||||
renderContent={renderContent}
|
||||
width="md"
|
||||
/>
|
||||
);
|
||||
}
|
||||
@@ -1,186 +0,0 @@
|
||||
"use client";
|
||||
|
||||
import React, { useState, useEffect, useCallback } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgDownloadCloud, SvgFileText } from "@opal/icons";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
export interface DocumentData {
|
||||
fileContent: string;
|
||||
fileUrl: string;
|
||||
fileName: string;
|
||||
fileType: string;
|
||||
handleDownload: () => void;
|
||||
}
|
||||
|
||||
export interface DocumentViewModalProps {
|
||||
presentingDocument: MinimalOnyxDocument;
|
||||
onClose: () => void;
|
||||
headerExtras?: (data: DocumentData) => React.ReactNode;
|
||||
renderContent: (data: DocumentData) => React.ReactNode;
|
||||
width?: "sm" | "md" | "lg";
|
||||
}
|
||||
|
||||
export default function DocumentViewModal({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
headerExtras,
|
||||
renderContent,
|
||||
width = "lg",
|
||||
}: DocumentViewModalProps) {
|
||||
const [fileContent, setFileContent] = useState("");
|
||||
const [fileUrl, setFileUrl] = useState("");
|
||||
const [fileName, setFileName] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const [fileType, setFileType] = useState("application/octet-stream");
|
||||
|
||||
const fetchFile = useCallback(
|
||||
async (signal?: AbortSignal) => {
|
||||
setIsLoading(true);
|
||||
setLoadError(null);
|
||||
setFileContent("");
|
||||
const fileIdLocal =
|
||||
presentingDocument.document_id.split("__")[1] ||
|
||||
presentingDocument.document_id;
|
||||
|
||||
try {
|
||||
const response = await fetch(
|
||||
`/api/chat/file/${encodeURIComponent(fileIdLocal)}`,
|
||||
{
|
||||
method: "GET",
|
||||
signal,
|
||||
cache: "force-cache",
|
||||
}
|
||||
);
|
||||
|
||||
if (!response.ok) {
|
||||
setLoadError("Failed to load document.");
|
||||
return;
|
||||
}
|
||||
|
||||
const blob = await response.blob();
|
||||
const url = window.URL.createObjectURL(blob);
|
||||
setFileUrl((prev) => {
|
||||
if (prev) {
|
||||
window.URL.revokeObjectURL(prev);
|
||||
}
|
||||
return url;
|
||||
});
|
||||
|
||||
const originalFileName =
|
||||
presentingDocument.semantic_identifier || "document";
|
||||
setFileName(originalFileName);
|
||||
|
||||
const contentType =
|
||||
response.headers.get("Content-Type") || "application/octet-stream";
|
||||
setFileType(contentType);
|
||||
|
||||
const text = await blob.text();
|
||||
setFileContent(text);
|
||||
} catch (error) {
|
||||
if (signal?.aborted) {
|
||||
return;
|
||||
}
|
||||
setLoadError("Failed to load document.");
|
||||
} finally {
|
||||
if (!signal?.aborted) {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
},
|
||||
[presentingDocument]
|
||||
);
|
||||
|
||||
useEffect(() => {
|
||||
const controller = new AbortController();
|
||||
fetchFile(controller.signal);
|
||||
return () => {
|
||||
controller.abort();
|
||||
};
|
||||
}, [fetchFile]);
|
||||
|
||||
useEffect(() => {
|
||||
return () => {
|
||||
if (fileUrl) {
|
||||
window.URL.revokeObjectURL(fileUrl);
|
||||
}
|
||||
};
|
||||
}, [fileUrl]);
|
||||
|
||||
const handleDownload = () => {
|
||||
const link = document.createElement("a");
|
||||
link.href = fileUrl;
|
||||
link.download = fileName || presentingDocument.document_id;
|
||||
document.body.appendChild(link);
|
||||
link.click();
|
||||
document.body.removeChild(link);
|
||||
};
|
||||
|
||||
const data: DocumentData = {
|
||||
fileContent,
|
||||
fileUrl,
|
||||
fileName,
|
||||
fileType,
|
||||
handleDownload,
|
||||
};
|
||||
|
||||
return (
|
||||
<Modal
|
||||
open
|
||||
onOpenChange={(open) => {
|
||||
if (!open) {
|
||||
onClose();
|
||||
}
|
||||
}}
|
||||
>
|
||||
<Modal.Content
|
||||
width={width}
|
||||
height="full"
|
||||
preventAccidentalClose={false}
|
||||
onOpenAutoFocus={(e) => e.preventDefault()}
|
||||
>
|
||||
<Modal.Header
|
||||
icon={SvgFileText}
|
||||
title={fileName || "Document"}
|
||||
onClose={onClose}
|
||||
>
|
||||
<Section flexDirection="row" justifyContent="start" gap={0.25}>
|
||||
{headerExtras?.(data)}
|
||||
<OpalButton
|
||||
prominence="tertiary"
|
||||
onClick={handleDownload}
|
||||
icon={SvgDownloadCloud}
|
||||
tooltip="Download"
|
||||
/>
|
||||
</Section>
|
||||
</Modal.Header>
|
||||
|
||||
<Modal.Body>
|
||||
<Section>
|
||||
{isLoading ? (
|
||||
<SimpleLoader className="h-8 w-8" />
|
||||
) : loadError ? (
|
||||
<Text text03 mainUiBody>
|
||||
{loadError}
|
||||
</Text>
|
||||
) : (
|
||||
renderContent(data)
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
|
||||
<Modal.Footer>
|
||||
<BasicModalFooter
|
||||
submit={<Button onClick={handleDownload}>Download File</Button>}
|
||||
/>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
255
web/src/sections/modals/ShareChatSessionModal.tsx
Normal file
@@ -0,0 +1,255 @@
|
||||
"use client";
|
||||
|
||||
import { useState } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { ChatSession, ChatSessionSharedStatus } from "@/app/app/interfaces";
|
||||
import { toast } from "@/hooks/useToast";
|
||||
import { useChatSessionStore } from "@/app/app/stores/useChatSessionStore";
|
||||
import { copyAll } from "@/app/app/message/copyingUtils";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
import Modal from "@/refresh-components/Modal";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import CopyIconButton from "@/refresh-components/buttons/CopyIconButton";
|
||||
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgLink, SvgShare, SvgUsers } from "@opal/icons";
|
||||
import SvgCheck from "@opal/icons/check";
|
||||
import SvgLock from "@opal/icons/lock";
|
||||
|
||||
import type { IconProps } from "@opal/types";
|
||||
import useChatSessions from "@/hooks/useChatSessions";
|
||||
|
||||
function buildShareLink(chatSessionId: string) {
|
||||
const baseUrl = `${window.location.protocol}//${window.location.host}`;
|
||||
return `${baseUrl}/app/shared/${chatSessionId}`;
|
||||
}
|
||||
|
||||
async function generateShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ sharing_status: "public" }),
|
||||
});
|
||||
|
||||
if (response.ok) {
|
||||
return buildShareLink(chatSessionId);
|
||||
}
|
||||
return null;
|
||||
}
|
||||
|
||||
async function deleteShareLink(chatSessionId: string) {
|
||||
const response = await fetch(`/api/chat/chat-session/${chatSessionId}`, {
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
body: JSON.stringify({ sharing_status: "private" }),
|
||||
});
|
||||
|
||||
return response.ok;
|
||||
}
|
||||
|
||||
interface PrivacyOptionProps {
|
||||
icon: React.FunctionComponent<IconProps>;
|
||||
title: string;
|
||||
description: string;
|
||||
selected: boolean;
|
||||
onClick: () => void;
|
||||
ariaLabel?: string;
|
||||
}
|
||||
|
||||
function PrivacyOption({
|
||||
icon: Icon,
|
||||
title,
|
||||
description,
|
||||
selected,
|
||||
onClick,
|
||||
ariaLabel,
|
||||
}: PrivacyOptionProps) {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
"p-1.5 rounded-08 cursor-pointer ",
|
||||
selected ? "bg-background-tint-00" : "bg-transparent",
|
||||
"hover:bg-background-tint-02"
|
||||
)}
|
||||
onClick={onClick}
|
||||
aria-label={ariaLabel}
|
||||
>
|
||||
<div className="flex flex-row gap-1 items-center">
|
||||
<div className="flex w-5 p-[2px] self-stretch justify-center">
|
||||
<Icon
|
||||
size={16}
|
||||
className={cn(selected ? "stroke-text-05" : "stroke-text-03")}
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col flex-1 px-0.5">
|
||||
<Text mainUiBody text05={selected} text03={!selected}>
|
||||
{title}
|
||||
</Text>
|
||||
<Text secondaryBody text03>
|
||||
{description}
|
||||
</Text>
|
||||
</div>
|
||||
{selected && (
|
||||
<div className="flex w-5 self-stretch justify-center">
|
||||
<SvgCheck size={16} className="stroke-action-link-05" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
|
||||
interface ShareChatSessionModalProps {
|
||||
chatSession: ChatSession;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
export default function ShareChatSessionModal({
|
||||
chatSession,
|
||||
onClose,
|
||||
}: ShareChatSessionModalProps) {
|
||||
const isCurrentlyPublic =
|
||||
chatSession.shared_status === ChatSessionSharedStatus.Public;
|
||||
|
||||
const [selectedPrivacy, setSelectedPrivacy] = useState<"private" | "public">(
|
||||
isCurrentlyPublic ? "public" : "private"
|
||||
);
|
||||
const [shareLink, setShareLink] = useState<string>(
|
||||
isCurrentlyPublic ? buildShareLink(chatSession.id) : ""
|
||||
);
|
||||
const [isLoading, setIsLoading] = useState(false);
|
||||
const updateCurrentChatSessionSharedStatus = useChatSessionStore(
|
||||
(state) => state.updateCurrentChatSessionSharedStatus
|
||||
);
|
||||
const { refreshChatSessions } = useChatSessions();
|
||||
|
||||
const wantsPublic = selectedPrivacy === "public";
|
||||
|
||||
const isShared = shareLink && selectedPrivacy === "public";
|
||||
|
||||
let submitButtonText = "Done";
|
||||
if (wantsPublic && !isCurrentlyPublic && !shareLink) {
|
||||
submitButtonText = "Create Share Link";
|
||||
} else if (!wantsPublic && isCurrentlyPublic) {
|
||||
submitButtonText = "Make Private";
|
||||
} else if (isShared) {
|
||||
submitButtonText = "Copy Link";
|
||||
}
|
||||
|
||||
async function handleSubmit() {
|
||||
setIsLoading(true);
|
||||
try {
|
||||
if (wantsPublic && !isCurrentlyPublic && !shareLink) {
|
||||
const link = await generateShareLink(chatSession.id);
|
||||
if (link) {
|
||||
setShareLink(link);
|
||||
updateCurrentChatSessionSharedStatus(ChatSessionSharedStatus.Public);
|
||||
await refreshChatSessions();
|
||||
copyAll(link);
|
||||
toast.success("Share link copied to clipboard!");
|
||||
} else {
|
||||
toast.error("Failed to generate share link");
|
||||
}
|
||||
} else if (!wantsPublic && isCurrentlyPublic) {
|
||||
const success = await deleteShareLink(chatSession.id);
|
||||
if (success) {
|
||||
setShareLink("");
|
||||
updateCurrentChatSessionSharedStatus(ChatSessionSharedStatus.Private);
|
||||
await refreshChatSessions();
|
||||
toast.success("Chat is now private");
|
||||
onClose();
|
||||
} else {
|
||||
toast.error("Failed to make chat private");
|
||||
}
|
||||
} else if (wantsPublic && shareLink) {
|
||||
copyAll(shareLink);
|
||||
toast.success("Share link copied to clipboard!");
|
||||
} else {
|
||||
onClose();
|
||||
}
|
||||
} catch (e) {
|
||||
console.error(e);
|
||||
toast.error("An error occurred");
|
||||
} finally {
|
||||
setIsLoading(false);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
<Modal open onOpenChange={(isOpen) => !isOpen && onClose()}>
|
||||
<Modal.Content width="sm">
|
||||
<Modal.Header
|
||||
icon={SvgShare}
|
||||
title={isShared ? "Chat shared" : "Share this chat"}
|
||||
description="All existing and future messages in this chat will be shared."
|
||||
onClose={onClose}
|
||||
/>
|
||||
<Modal.Body twoTone>
|
||||
<Section
|
||||
justifyContent="start"
|
||||
alignItems="stretch"
|
||||
gap={1}
|
||||
height="auto"
|
||||
>
|
||||
<Section
|
||||
justifyContent="start"
|
||||
alignItems="stretch"
|
||||
height="auto"
|
||||
gap={0.12}
|
||||
>
|
||||
<PrivacyOption
|
||||
icon={SvgLock}
|
||||
title="Private"
|
||||
description="Only you have access to this chat."
|
||||
selected={selectedPrivacy === "private"}
|
||||
onClick={() => setSelectedPrivacy("private")}
|
||||
ariaLabel="share-modal-option-private"
|
||||
/>
|
||||
<PrivacyOption
|
||||
icon={SvgUsers}
|
||||
title="Your Organization"
|
||||
description="Anyone in your organization can view this chat."
|
||||
selected={selectedPrivacy === "public"}
|
||||
onClick={() => setSelectedPrivacy("public")}
|
||||
ariaLabel="share-modal-option-public"
|
||||
/>
|
||||
</Section>
|
||||
|
||||
{isShared && (
|
||||
<div aria-label="share-modal-link-input">
|
||||
<InputTypeIn
|
||||
readOnly
|
||||
value={shareLink}
|
||||
rightSection={
|
||||
<CopyIconButton
|
||||
getCopyText={() => shareLink}
|
||||
tooltip="Copy link"
|
||||
size="sm"
|
||||
aria-label="share-modal-copy-link"
|
||||
/>
|
||||
}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</Section>
|
||||
</Modal.Body>
|
||||
<Modal.Footer>
|
||||
{!isShared && (
|
||||
<Button secondary onClick={onClose} aria-label="share-modal-cancel">
|
||||
Cancel
|
||||
</Button>
|
||||
)}
|
||||
<Button
|
||||
onClick={handleSubmit}
|
||||
disabled={isLoading}
|
||||
leftIcon={isShared ? SvgLink : undefined}
|
||||
className={isShared ? "w-full" : undefined}
|
||||
aria-label="share-modal-submit"
|
||||
>
|
||||
{submitButtonText}
|
||||
</Button>
|
||||
</Modal.Footer>
|
||||
</Modal.Content>
|
||||
</Modal>
|
||||
);
|
||||
}
|
||||
@@ -1,14 +1,7 @@
|
||||
"use client";
|
||||
|
||||
import { useState, useMemo } from "react";
|
||||
import { useState, useEffect, useCallback, useMemo } from "react";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { SvgZoomIn, SvgZoomOut } from "@opal/icons";
|
||||
import DocumentViewModal, {
|
||||
DocumentData,
|
||||
} from "@/sections/modals/DocumentViewModal";
|
||||
import {
|
||||
Table,
|
||||
TableBody,
|
||||
@@ -17,73 +10,38 @@ import {
|
||||
TableHeader,
|
||||
TableRow,
|
||||
} from "@/components/ui/table";
|
||||
import { MinimalOnyxDocument } from "@/lib/search/interfaces";
|
||||
import MinimalMarkdown from "@/components/chat/MinimalMarkdown";
|
||||
import { Button as OpalButton } from "@opal/components";
|
||||
import Modal, { BasicModalFooter } from "@/refresh-components/Modal";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import {
|
||||
SvgDownloadCloud,
|
||||
SvgFileText,
|
||||
SvgZoomIn,
|
||||
SvgZoomOut,
|
||||
} from "@opal/icons";
|
||||
import SimpleLoader from "@/refresh-components/loaders/SimpleLoader";
|
||||
import ScrollIndicatorDiv from "@/refresh-components/ScrollIndicatorDiv";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { Section } from "@/layouts/general-layouts";
|
||||
|
||||
export interface TextViewProps {
|
||||
presentingDocument: MinimalOnyxDocument;
|
||||
onClose: () => void;
|
||||
}
|
||||
|
||||
/**
|
||||
* Resolves an effective content type by overriding generic octet-stream
|
||||
* when the file name suggests a known text-based extension.
|
||||
*/
|
||||
function resolveTextContentType(
|
||||
rawContentType: string,
|
||||
fileName: string
|
||||
): string {
|
||||
if (rawContentType !== "application/octet-stream") return rawContentType;
|
||||
const lowerName = fileName.toLowerCase();
|
||||
if (lowerName.endsWith(".md") || lowerName.endsWith(".markdown"))
|
||||
return "text/markdown";
|
||||
if (lowerName.endsWith(".txt")) return "text/plain";
|
||||
if (lowerName.endsWith(".csv")) return "text/csv";
|
||||
return rawContentType;
|
||||
}
|
||||
|
||||
function isMarkdownFormat(mimeType: string): boolean {
|
||||
const markdownFormats = [
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"text/plain",
|
||||
"text/csv",
|
||||
"text/x-rst",
|
||||
"text/x-org",
|
||||
"txt",
|
||||
];
|
||||
return markdownFormats.some((format) => mimeType.startsWith(format));
|
||||
}
|
||||
|
||||
function isImageFormat(mimeType: string): boolean {
|
||||
const imageFormats = [
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
];
|
||||
return imageFormats.some((format) => mimeType.startsWith(format));
|
||||
}
|
||||
|
||||
function isSupportedIframeFormat(mimeType: string): boolean {
|
||||
const supportedFormats = [
|
||||
"application/pdf",
|
||||
"image/png",
|
||||
"image/jpeg",
|
||||
"image/gif",
|
||||
"image/svg+xml",
|
||||
];
|
||||
return supportedFormats.some((format) => mimeType.startsWith(format));
|
||||
}
|
||||
|
||||
function TextViewContent({
|
||||
fileContent,
|
||||
fileType,
|
||||
}: {
|
||||
fileContent: string;
|
||||
fileType: string;
|
||||
}) {
|
||||
export default function TextViewModal({
|
||||
presentingDocument,
|
||||
onClose,
|
||||
}: TextViewProps) {
|
||||
const [zoom, setZoom] = useState(100);
|
||||
const [fileContent, setFileContent] = useState("");
|
||||
const [fileUrl, setFileUrl] = useState("");
|
||||
const [fileName, setFileName] = useState("");
|
||||
const [isLoading, setIsLoading] = useState(true);
|
||||
const [loadError, setLoadError] = useState<string | null>(null);
|
||||
const [fileType, setFileType] = useState("application/octet-stream");
|
||||
const csvData = useMemo(() => {
|
||||
if (!fileType.startsWith("text/csv")) {
|
||||
return null;
|
||||
@@ -96,122 +54,280 @@ function TextViewContent({
|
||||
return { headers, rows } as { headers: string[]; rows: string[][] };
|
||||
}, [fileContent, fileType]);
|
||||
|
||||
return (
|
||||
<ScrollIndicatorDiv className="flex-1 min-h-0 p-4" variant="shadow">
|
||||
{csvData ? (
|
||||
<Table>
|
||||
<TableHeader className="sticky top-0 z-sticky">
|
||||
<TableRow className="bg-background-tint-02">
|
||||
{csvData.headers.map((h, i) => (
|
||||
<TableHead key={i}>
|
||||
<Text
|
||||
as="p"
|
||||
className="line-clamp-2 font-medium"
|
||||
text03
|
||||
mainUiBody
|
||||
>
|
||||
{h}
|
||||
</Text>
|
||||
</TableHead>
|
||||
))}
|
||||
</TableRow>
|
||||
</TableHeader>
|
||||
<TableBody>
|
||||
{csvData.rows.map((row, rIdx) => (
|
||||
<TableRow key={rIdx}>
|
||||
{csvData.headers.map((_, cIdx) => (
|
||||
<TableCell
|
||||
key={cIdx}
|
||||
className={cn(
|
||||
cIdx === 0 && "sticky left-0 bg-background-tint-01",
|
||||
"py-0 px-4 whitespace-normal break-words"
|
||||
)}
|
||||
>
|
||||
{row?.[cIdx] ?? ""}
|
||||
</TableCell>
|
||||
))}
|
||||
</TableRow>
|
||||
))}
|
||||
</TableBody>
|
||||
</Table>
|
||||
) : (
|
||||
<MinimalMarkdown
|
||||
content={fileContent}
|
||||
className="w-full pb-4 h-full text-lg break-words"
|
||||
/>
|
||||
)}
|
||||
</ScrollIndicatorDiv>
|
||||
);
|
||||
}
|
||||
// Detect if a given MIME type is one of the recognized markdown formats
const isMarkdownFormat = (mimeType: string): boolean => {
const markdownFormats = [
"text/markdown",
"text/x-markdown",
"text/plain",
"text/csv",
"text/x-rst",
"text/x-org",
"txt",
];
return markdownFormats.some((format) => mimeType.startsWith(format));
};

export default function TextViewModal({
presentingDocument,
onClose,
}: TextViewProps) {
const [zoom, setZoom] = useState(100);
const isImageFormat = (mimeType: string) => {
const imageFormats = [
"image/png",
"image/jpeg",
"image/gif",
"image/svg+xml",
];
return imageFormats.some((format) => mimeType.startsWith(format));
};
// Detect if a given MIME type can be rendered in an <iframe>
const isSupportedIframeFormat = (mimeType: string): boolean => {
const supportedFormats = [
"application/pdf",
"image/png",
"image/jpeg",
"image/gif",
"image/svg+xml",
];
return supportedFormats.some((format) => mimeType.startsWith(format));
};

const fetchFile = useCallback(
async (signal?: AbortSignal) => {
setIsLoading(true);
setLoadError(null);
setFileContent("");
const fileIdLocal =
presentingDocument.document_id.split("__")[1] ||
presentingDocument.document_id;

try {
const response = await fetch(
`/api/chat/file/${encodeURIComponent(fileIdLocal)}`,
{
method: "GET",
signal,
cache: "force-cache",
}
);

if (!response.ok) {
setLoadError("Failed to load document.");
return;
}

const blob = await response.blob();
const url = window.URL.createObjectURL(blob);
setFileUrl((prev) => {
if (prev) {
window.URL.revokeObjectURL(prev);
}
return url;
});

const originalFileName =
presentingDocument.semantic_identifier || "document";
setFileName(originalFileName);

let contentType =
response.headers.get("Content-Type") || "application/octet-stream";

// If it's octet-stream but file name suggests a text-based extension, override accordingly
if (contentType === "application/octet-stream") {
const lowerName = originalFileName.toLowerCase();
if (lowerName.endsWith(".md") || lowerName.endsWith(".markdown")) {
contentType = "text/markdown";
} else if (lowerName.endsWith(".txt")) {
contentType = "text/plain";
} else if (lowerName.endsWith(".csv")) {
contentType = "text/csv";
}
}
setFileType(contentType);

// If the final content type looks like markdown, read its text
if (isMarkdownFormat(contentType)) {
const text = await blob.text();
setFileContent(text);
}
} catch (error) {
// Abort is expected on unmount / doc change
if (signal?.aborted) {
return;
}
setLoadError("Failed to load document.");
} finally {
// Prevent stale/aborted requests from clobbering the loading state.
// This is especially important in React StrictMode where effects can run twice.
if (!signal?.aborted) {
setIsLoading(false);
}
}
},
[presentingDocument]
);

useEffect(() => {
const controller = new AbortController();
fetchFile(controller.signal);
return () => {
controller.abort();
};
}, [fetchFile]);

useEffect(() => {
return () => {
if (fileUrl) {
window.URL.revokeObjectURL(fileUrl);
}
};
}, [fileUrl]);

const handleDownload = () => {
const link = document.createElement("a");
link.href = fileUrl;
link.download = fileName || presentingDocument.document_id;
document.body.appendChild(link);
link.click();
document.body.removeChild(link);
};

const handleZoomIn = () => setZoom((prev) => Math.min(prev + 25, 200));
const handleZoomOut = () => setZoom((prev) => Math.max(prev - 25, 100));

const renderHeaderExtras = () => (
<>
<OpalButton
prominence="tertiary"
onClick={handleZoomOut}
icon={SvgZoomOut}
tooltip="Zoom Out"
/>
<Text mainUiBody>{zoom}%</Text>
<OpalButton
prominence="tertiary"
onClick={handleZoomIn}
icon={SvgZoomIn}
tooltip="Zoom In"
/>
</>
);

const renderContent = (data: DocumentData) => {
const { fileUrl, fileName, fileContent, handleDownload } = data;
const fileType = resolveTextContentType(data.fileType, fileName);

return (
<div
className="flex flex-col flex-1 min-h-0 min-w-0 w-full transform origin-center transition-transform duration-300 ease-in-out"
style={{ transform: `scale(${zoom / 100})` }}
>
{isImageFormat(fileType) ? (
<img
src={fileUrl}
alt={fileName}
className="w-full flex-1 min-h-0 object-contain object-center"
/>
) : isSupportedIframeFormat(fileType) ? (
<iframe
src={`${fileUrl}#toolbar=0`}
className="w-full h-full flex-1 min-h-0 border-none"
title="File Viewer"
/>
) : isMarkdownFormat(fileType) ? (
<TextViewContent fileContent={fileContent} fileType={fileType} />
) : (
<div className="flex flex-col items-center justify-center flex-1 min-h-0 p-6 gap-4">
<Text as="p" text03 mainUiBody>
This file format is not supported for preview.
</Text>
<Button onClick={handleDownload}>Download File</Button>
</div>
)}
</div>
);
};
return (
<DocumentViewModal
presentingDocument={presentingDocument}
onClose={onClose}
headerExtras={renderHeaderExtras}
renderContent={renderContent}
/>
<Modal
open
onOpenChange={(open) => {
if (!open) {
onClose();
}
}}
>
<Modal.Content
width="lg"
height="full"
preventAccidentalClose={false}
onOpenAutoFocus={(e) => e.preventDefault()}
>
<Modal.Header
icon={SvgFileText}
title={fileName || "Document"}
onClose={onClose}
>
<Section flexDirection="row" justifyContent="start" gap={0.25}>
<OpalButton
prominence="tertiary"
onClick={handleZoomOut}
icon={SvgZoomOut}
tooltip="Zoom Out"
/>
<Text mainUiBody>{zoom}%</Text>
<OpalButton
prominence="tertiary"
onClick={handleZoomIn}
icon={SvgZoomIn}
tooltip="Zoom In"
/>
<OpalButton
prominence="tertiary"
onClick={handleDownload}
icon={SvgDownloadCloud}
tooltip="Download"
/>
</Section>
</Modal.Header>

<Modal.Body>
<Section>
{isLoading ? (
<SimpleLoader className="h-8 w-8" />
) : loadError ? (
<Text text03 mainUiBody>
{loadError}
</Text>
) : (
<div
className="flex flex-col flex-1 min-h-0 min-w-0 w-full transform origin-center transition-transform duration-300 ease-in-out"
style={{ transform: `scale(${zoom / 100})` }}
>
{isImageFormat(fileType) ? (
<img
src={fileUrl}
alt={fileName}
className="w-full flex-1 min-h-0 object-contain object-center"
/>
) : isSupportedIframeFormat(fileType) ? (
<iframe
src={`${fileUrl}#toolbar=0`}
className="w-full h-full flex-1 min-h-0 border-none"
title="File Viewer"
/>
) : isMarkdownFormat(fileType) ? (
<ScrollIndicatorDiv
className="flex-1 min-h-0 p-4"
variant="shadow"
>
{csvData ? (
<Table>
<TableHeader className="sticky top-0 z-sticky">
<TableRow className="bg-background-tint-02">
{csvData.headers.map((h, i) => (
<TableHead key={i}>
<Text
as="p"
className="line-clamp-2 font-medium"
text03
mainUiBody
>
{h}
</Text>
</TableHead>
))}
</TableRow>
</TableHeader>
<TableBody>
{csvData.rows.map((row, rIdx) => (
<TableRow key={rIdx}>
{csvData.headers.map((_, cIdx) => (
<TableCell
key={cIdx}
className={cn(
cIdx === 0 &&
"sticky left-0 bg-background-tint-01",
"py-0 px-4 whitespace-normal break-words"
)}
>
{row?.[cIdx] ?? ""}
</TableCell>
))}
</TableRow>
))}
</TableBody>
</Table>
) : (
<MinimalMarkdown
content={fileContent}
className="w-full pb-4 h-full text-lg break-words"
/>
)}
</ScrollIndicatorDiv>
) : (
<div className="flex flex-col items-center justify-center flex-1 min-h-0 p-6 gap-4">
<Text as="p" text03 mainUiBody>
This file format is not supported for preview.
</Text>
<Button onClick={handleDownload}>Download File</Button>
</div>
)}
</div>
)}
</Section>
</Modal.Body>

<Modal.Footer>
<BasicModalFooter
submit={<Button onClick={handleDownload}>Download File</Button>}
/>
</Modal.Footer>
</Modal.Content>
</Modal>
);
}
@@ -18,7 +18,7 @@ import {
import { useProjectsContext } from "@/providers/ProjectsContext";
import MoveCustomAgentChatModal from "@/components/modals/MoveCustomAgentChatModal";
import { UNNAMED_CHAT } from "@/lib/constants";
import ShareChatSessionModal from "@/app/app/components/modal/ShareChatSessionModal";
import ShareChatSessionModal from "@/sections/modals/ShareChatSessionModal";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Button as OpalButton } from "@opal/components";
@@ -1,5 +1,6 @@
import { test, expect } from "@playwright/test";
import type { Page } from "@playwright/test";
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";

test.use({ storageState: "admin_auth.json" });

@@ -162,16 +163,10 @@ async function verifyAdminPageNavigation(
}
}

const THEMES = ["light", "dark"] as const;

for (const theme of THEMES) {
test.describe(`Admin pages (${theme} mode)`, () => {
// Inject the theme into localStorage before every navigation so
// next-themes picks it up on first render.
test.beforeEach(async ({ page }) => {
await page.addInitScript((t: string) => {
localStorage.setItem("theme", t);
}, theme);
await setThemeBeforeNavigation(page, theme);
});

for (const snapshot of ADMIN_PAGES) {
@@ -141,7 +141,7 @@ test.describe("Web Content Provider Configuration", () => {
await modalDialog
.getByRole("button", { name: "Connect", exact: true })
.click();
await expect(modalDialog).not.toBeVisible({ timeout: 30000 });
await expect(modalDialog).not.toBeVisible({ timeout: 60000 });
await page.waitForLoadState("networkidle");
} else if (await setDefaultButton.isVisible()) {
// If already configured but not active, set as default
@@ -1,10 +1,10 @@
import { test, expect, Page, Locator } from "@playwright/test";
import { OnyxApiClient } from "@tests/e2e/utils/onyxApiClient";

test.use({ storageState: "admin_auth.json" });
import { loginAs } from "@tests/e2e/utils/auth";
import { expectScreenshot } from "@tests/e2e/utils/visualRegression";

// Test data storage
const TEST_PREFIX = `E2E-CMD-${Date.now()}`;
const TEST_PREFIX = "E2E-CMD";
let chatSessionIds: string[] = [];
let projectIds: number[] = [];

@@ -12,17 +12,9 @@ let projectIds: number[] = [];
* Helper to get the command menu dialog locator (using the content wrapper)
*/
function getCommandMenuContent(page: Page): Locator {
// Use DialogPrimitive.Content which has role="dialog" and contains the visually-hidden title
return page.locator('[role="dialog"]:has([data-command-menu-list])');
}

/**
* Helper to get the command menu list locator
*/
function getCommandMenuList(page: Page): Locator {
return page.locator("[data-command-menu-list]");
}

/**
* Helper to open the command menu and return a scoped locator
*/

@@ -36,25 +28,20 @@ async function openCommandMenu(page: Page): Promise<Locator> {
}

test.describe("Chat Search Command Menu", () => {
// Create all test data ONCE before all tests
test.beforeAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
const context = await browser.newContext();
const page = await context.newPage();
await loginAs(page, "user");
const client = new OnyxApiClient(page.request);

// Navigate to app to establish session
await page.goto("/app");
await page.waitForLoadState("networkidle");

// Create 5 chat sessions
for (let i = 1; i <= 5; i++) {
const id = await client.createChatSession(`${TEST_PREFIX} Chat ${i}`);
chatSessionIds.push(id);
}

// Create 4 projects
for (let i = 1; i <= 4; i++) {
const id = await client.createProject(`${TEST_PREFIX} Project ${i}`);
projectIds.push(id);

@@ -63,24 +50,18 @@ test.describe("Chat Search Command Menu", () => {
await context.close();
});

// Cleanup all test data ONCE after all tests
test.afterAll(async ({ browser }) => {
const context = await browser.newContext({
storageState: "admin_auth.json",
});
const context = await browser.newContext();
const page = await context.newPage();
await loginAs(page, "user");
const client = new OnyxApiClient(page.request);

// Navigate to app to establish session
await page.goto("/app");
await page.waitForLoadState("networkidle");

// Delete chat sessions
for (const id of chatSessionIds) {
await client.deleteChatSession(id);
}

// Delete projects
for (const id of projectIds) {
await client.deleteProject(id);
}

@@ -88,472 +69,269 @@ test.describe("Chat Search Command Menu", () => {
await context.close();
});
test.describe("Menu Opening", () => {
|
||||
test("Opens when clicking sidebar search trigger", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
await expect(
|
||||
dialog.getByPlaceholder("Search chat sessions, projects...")
|
||||
).toBeVisible();
|
||||
// "New Session" action should be visible within the command menu
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Shows search input with placeholder and focus", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await expect(input).toBeVisible();
|
||||
await expect(input).toBeFocused();
|
||||
});
|
||||
|
||||
test('Shows "New Session" action when no search term', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Use data-command-item attribute to target the action
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
test.beforeEach(async ({ page }) => {
|
||||
await page.context().clearCookies();
|
||||
await loginAs(page, "user");
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
});
|
||||
|
||||
test.describe("Preview Display", () => {
|
||||
test("Shows at most 4 chat sessions (PREVIEW_CHATS_LIMIT)", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// -- Opening --
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
test("Opens with search input, New Session action, and correct positioning", async ({
|
||||
page,
|
||||
}) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Should show at most 4 chat sessions in preview mode
|
||||
const chatItems = dialog.locator('[data-command-item^="chat-"]');
|
||||
const chatCount = await chatItems.count();
|
||||
// In "all" filter with no search, should show max 4 chats
|
||||
expect(chatCount).toBeLessThanOrEqual(4);
|
||||
});
|
||||
await expect(
|
||||
dialog.getByPlaceholder("Search chat sessions, projects...")
|
||||
).toBeFocused();
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
|
||||
test("Shows at most 3 projects (PREVIEW_PROJECTS_LIMIT)", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Should show at most 3 projects in preview mode
|
||||
const projectItems = dialog.locator('[data-command-item^="project-"]');
|
||||
const projectCount = await projectItems.count();
|
||||
// In "all" filter with no search, should show max 3 projects
|
||||
expect(projectCount).toBeLessThanOrEqual(3);
|
||||
});
|
||||
|
||||
test('Shows "Recent Sessions" filter', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// The "Recent Sessions" filter should be visible
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="recent-sessions"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test('Shows "Projects" filter', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// The "Projects" filter should be visible
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="projects"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test('Shows "New Project" action', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-project"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
await expectScreenshot(page, { name: "command-menu-default-open" });
|
||||
});
|
||||
|
||||
test.describe("Filter Expansion", () => {
|
||||
test('Click "Recent Sessions" filter shows all 5 chats', async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// -- Preview limits --
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
test("Shows at most 4 chats and 3 projects in preview", async ({ page }) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click on Recent Sessions filter to expand
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
const chatCount = await dialog
|
||||
.locator('[data-command-item^="chat-"]')
|
||||
.count();
|
||||
expect(chatCount).toBeLessThanOrEqual(4);
|
||||
|
||||
// Wait for the filter to be applied and all chats to load
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should now show all 5 test chats - use data-command-item to find them
|
||||
for (let i = 1; i <= 5; i++) {
|
||||
await expect(
|
||||
dialog.locator(`[data-command-item="chat-${chatSessionIds[i - 1]}"]`)
|
||||
).toBeVisible();
|
||||
}
|
||||
});
|
||||
|
||||
test('Filter chip "Sessions" appears in header when chats filter is active', async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click on Recent Sessions filter
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
|
||||
// The filter chip should appear (look for the editable tag with "Sessions")
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
});
|
||||
|
||||
test('Click "Projects" filter shows all 4 projects', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click on Projects filter to expand
|
||||
await dialog.locator('[data-command-item="projects"]').click();
|
||||
|
||||
// Wait for the filter to be applied
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should now show all 4 test projects - use data-command-item to find them
|
||||
for (let i = 1; i <= 4; i++) {
|
||||
await expect(
|
||||
dialog.locator(`[data-command-item="project-${projectIds[i - 1]}"]`)
|
||||
).toBeVisible();
|
||||
}
|
||||
});
|
||||
|
||||
test("Clicking filter chip X removes filter and returns to 'all'", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click on Recent Sessions filter
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
|
||||
// Wait for the filter to be applied
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
// Click the X on the filter tag to remove it (aria-label is "Remove Sessions filter")
|
||||
await dialog
|
||||
.locator('button[aria-label="Remove Sessions filter"]')
|
||||
.click();
|
||||
|
||||
// Should be back to "all" view - "New Session" action should be visible again
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Backspace on empty input removes active filter", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click on Recent Sessions filter
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
|
||||
// Wait for the filter to be applied
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
// Ensure focus is on the input field
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.focus();
|
||||
|
||||
// Press backspace on empty input to remove filter
|
||||
await page.keyboard.press("Backspace");
|
||||
|
||||
// Should be back to "all" view
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Backspace on empty input with no filter closes menu", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
await openCommandMenu(page);
|
||||
|
||||
// Press backspace on empty input (no filter active)
|
||||
await page.keyboard.press("Backspace");
|
||||
|
||||
// Menu should close
|
||||
await expect(getCommandMenuContent(page)).not.toBeVisible();
|
||||
});
|
||||
const projectCount = await dialog
|
||||
.locator('[data-command-item^="project-"]')
|
||||
.count();
|
||||
expect(projectCount).toBeLessThanOrEqual(3);
|
||||
});
|
||||
|
||||
test.describe("Search Filtering", () => {
|
||||
test("Search finds matching chat session", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
test('Shows "Recent Sessions", "Projects" filters and "New Project" action', async ({
|
||||
page,
|
||||
}) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill(`${TEST_PREFIX} Chat 3`);
|
||||
|
||||
// Wait for search results
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should show the matching chat - use specific data-command-item
|
||||
await expect(
|
||||
dialog.locator(`[data-command-item="chat-${chatSessionIds[2]}"]`)
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Search finds matching project", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill(`${TEST_PREFIX} Project 2`);
|
||||
|
||||
// Wait for search results
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should show the matching project - use specific data-command-item
|
||||
await expect(
|
||||
dialog.locator(`[data-command-item="project-${projectIds[1]}"]`)
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test('Search shows "Create New Project" action with typed name', async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill("my custom project name");
|
||||
|
||||
// Should show create project action with the search term
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="create-project-with-name"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test('Search with no results shows "No results found"', async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill("xyz123nonexistent9999");
|
||||
|
||||
// Wait for search to complete
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should show no results message or the "No more results" separator
|
||||
// The component shows "No results found" when there are no matches
|
||||
const noResults = dialog.getByText("No results found");
|
||||
const noMore = dialog.getByText("No more results");
|
||||
await expect(noResults.or(noMore)).toBeVisible();
|
||||
});
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="recent-sessions"]')
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="projects"]')
|
||||
).toBeVisible();
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-project"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test.describe("Navigation Actions", () => {
|
||||
test('"New Session" click navigates to /app', async ({ page }) => {
|
||||
await page.goto("/chat");
|
||||
await page.waitForLoadState("networkidle");
|
||||
// -- Filter expansion --
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
test('"Recent Sessions" filter expands to show all 5 chats', async ({
|
||||
page,
|
||||
}) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
|
||||
// Click New Session action
|
||||
await dialog.locator('[data-command-item="new-session"]').click();
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Should navigate to /app
|
||||
await page.waitForURL(/\/app/);
|
||||
expect(page.url()).toContain("/app");
|
||||
});
|
||||
for (let i = 1; i <= 5; i++) {
|
||||
await expect(
|
||||
dialog.locator(`[data-command-item="chat-${chatSessionIds[i - 1]}"]`)
|
||||
).toBeVisible();
|
||||
}
|
||||
|
||||
test("Click chat session navigates to /chat?chatId={id}", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Search for a specific chat
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill(`${TEST_PREFIX} Chat 1`);
|
||||
|
||||
// Wait for search results
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Click on the chat using data-command-item
|
||||
await dialog
|
||||
.locator(`[data-command-item="chat-${chatSessionIds[0]}"]`)
|
||||
.click();
|
||||
|
||||
// Should navigate to the chat URL
|
||||
await page.waitForURL(/chatId=/);
|
||||
expect(page.url()).toContain(`chatId=${chatSessionIds[0]}`);
|
||||
});
|
||||
|
||||
test("Click project navigates to /chat?projectId={id}", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Search for a specific project
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill(`${TEST_PREFIX} Project 1`);
|
||||
|
||||
// Wait for search results
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Click on the project using data-command-item
|
||||
await dialog
|
||||
.locator(`[data-command-item="project-${projectIds[0]}"]`)
|
||||
.click();
|
||||
|
||||
// Should navigate to the project URL
|
||||
await page.waitForURL(/projectId=/);
|
||||
expect(page.url()).toContain(`projectId=${projectIds[0]}`);
|
||||
});
|
||||
|
||||
test('"New Project" click opens create project modal', async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
|
||||
// Click New Project action
|
||||
await dialog.locator('[data-command-item="new-project"]').click();
|
||||
|
||||
// Should open the create project modal
|
||||
await expect(page.getByText("Create New Project")).toBeVisible();
|
||||
});
|
||||
await expectScreenshot(page, { name: "command-menu-sessions-filter" });
|
||||
});
|
||||
|
||||
test.describe("Menu State", () => {
|
||||
test("Menu closes after selecting an action/item", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
test('"Projects" filter expands to show all 4 projects', async ({ page }) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
await dialog.locator('[data-command-item="projects"]').click();
|
||||
|
||||
const dialog = await openCommandMenu(page);
|
||||
await page.waitForTimeout(500);
|
||||
|
||||
// Select New Session
|
||||
await dialog.locator('[data-command-item="new-session"]').click();
|
||||
|
||||
// Menu should close
|
||||
await expect(getCommandMenuContent(page)).not.toBeVisible();
|
||||
});
|
||||
|
||||
test("Menu state resets when reopened (search cleared, filter reset)", async ({
|
||||
page,
|
||||
}) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
|
||||
// Open menu and apply a filter first (filter is only visible when search is empty)
|
||||
let dialog = await openCommandMenu(page);
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
|
||||
// Wait for the filter to be applied
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
// Now type something in the search
|
||||
const input = dialog.getByPlaceholder(
|
||||
"Search chat sessions, projects..."
|
||||
);
|
||||
await input.fill("test query");
|
||||
|
||||
// Close with Escape
|
||||
await page.keyboard.press("Escape");
|
||||
|
||||
// Wait for menu to close
|
||||
await expect(getCommandMenuContent(page)).not.toBeVisible();
|
||||
|
||||
// Reopen
|
||||
dialog = await openCommandMenu(page);
|
||||
|
||||
// Search input should be empty
|
||||
for (let i = 1; i <= 4; i++) {
|
||||
await expect(
|
||||
dialog.getByPlaceholder("Search chat sessions, projects...")
|
||||
).toHaveValue("");
|
||||
|
||||
// Should be back to "all" view with "New Session" action visible
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
dialog.locator(`[data-command-item="project-${projectIds[i - 1]}"]`)
|
||||
).toBeVisible();
|
||||
});
|
||||
}
|
||||
|
||||
test("Escape closes menu", async ({ page }) => {
|
||||
await page.goto("/app");
|
||||
await page.waitForLoadState("networkidle");
|
||||
await expectScreenshot(page, { name: "command-menu-projects-filter" });
|
||||
});
|
||||
|
||||
await openCommandMenu(page);
|
||||
test("Filter chip X removes filter and returns to all", async ({ page }) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
// Press Escape
|
||||
await page.keyboard.press("Escape");
|
||||
await dialog.locator('button[aria-label="Remove Sessions filter"]').click();
|
||||
|
||||
// Menu should close
|
||||
await expect(getCommandMenuContent(page)).not.toBeVisible();
|
||||
});
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Backspace on empty input removes active filter", async ({ page }) => {
|
||||
const dialog = await openCommandMenu(page);
|
||||
await dialog.locator('[data-command-item="recent-sessions"]').click();
|
||||
await expect(dialog.getByText("Sessions")).toBeVisible();
|
||||
|
||||
const input = dialog.getByPlaceholder("Search chat sessions, projects...");
|
||||
await input.focus();
|
||||
await page.keyboard.press("Backspace");
|
||||
|
||||
await expect(
|
||||
dialog.locator('[data-command-item="new-session"]')
|
||||
).toBeVisible();
|
||||
});
|
||||
|
||||
test("Backspace on empty input with no filter closes menu", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openCommandMenu(page);
|
||||
await page.keyboard.press("Backspace");
|
||||
await expect(getCommandMenuContent(page)).not.toBeVisible();
|
||||
});
|
||||
|
||||
// -- Search --

test("Search finds matching chat session", async ({ page }) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Chat 3`);
await page.waitForTimeout(500);

await expect(
dialog.locator(`[data-command-item="chat-${chatSessionIds[2]}"]`)
).toBeVisible();

await expectScreenshot(page, { name: "command-menu-search-results" });
});

test("Search finds matching project", async ({ page }) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Project 2`);
await page.waitForTimeout(500);

await expect(
dialog.locator(`[data-command-item="project-${projectIds[1]}"]`)
).toBeVisible();
});

test('Search shows "Create New Project" action with typed name', async ({
page,
}) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("my custom project name");

await expect(
dialog.locator('[data-command-item="create-project-with-name"]')
).toBeVisible();
});

test("Search with no results shows empty state", async ({ page }) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("xyz123nonexistent9999");
await page.waitForTimeout(500);

const noResults = dialog.getByText("No results found");
const noMore = dialog.getByText("No more results");
await expect(noResults.or(noMore)).toBeVisible();

await expectScreenshot(page, { name: "command-menu-no-results" });
});

// -- Navigation --

test('"New Session" navigates to /app', async ({ page }) => {
// Start from /chat so navigation is observable
await page.goto("/chat");
await page.waitForLoadState("networkidle");

const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-session"]').click();

await page.waitForURL(/\/app/);
expect(page.url()).toContain("/app");
});

test("Clicking a chat session navigates to its URL", async ({ page }) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Chat 1`);
await page.waitForTimeout(500);

await dialog
.locator(`[data-command-item="chat-${chatSessionIds[0]}"]`)
.click();

await page.waitForURL(/chatId=/);
expect(page.url()).toContain(`chatId=${chatSessionIds[0]}`);
});

test("Clicking a project navigates to its URL", async ({ page }) => {
const dialog = await openCommandMenu(page);

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill(`${TEST_PREFIX} Project 1`);
await page.waitForTimeout(500);

await dialog
.locator(`[data-command-item="project-${projectIds[0]}"]`)
.click();

await page.waitForURL(/projectId=/);
expect(page.url()).toContain(`projectId=${projectIds[0]}`);
});

test('"New Project" opens create project modal', async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-project"]').click();
await expect(page.getByText("Create New Project")).toBeVisible();
});

// -- Menu state --

test("Menu closes after selecting an item", async ({ page }) => {
const dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="new-session"]').click();
await expect(getCommandMenuContent(page)).not.toBeVisible();
});

test("Escape closes menu", async ({ page }) => {
await openCommandMenu(page);
await page.keyboard.press("Escape");
await expect(getCommandMenuContent(page)).not.toBeVisible();
});

test("Menu state resets when reopened", async ({ page }) => {
let dialog = await openCommandMenu(page);
await dialog.locator('[data-command-item="recent-sessions"]').click();
await expect(dialog.getByText("Sessions")).toBeVisible();

const input = dialog.getByPlaceholder("Search chat sessions, projects...");
await input.fill("test query");

await page.keyboard.press("Escape");
await expect(getCommandMenuContent(page)).not.toBeVisible();

dialog = await openCommandMenu(page);

await expect(
dialog.getByPlaceholder("Search chat sessions, projects...")
).toHaveValue("");
await expect(
dialog.locator('[data-command-item="new-session"]')
).toBeVisible();
});
});
web/tests/e2e/chat/chat_message_rendering.spec.ts (new file, 662 lines)

@@ -0,0 +1,662 @@
import { expect, Page, test } from "@playwright/test";
import { loginAs } from "@tests/e2e/utils/auth";
import { sendMessage } from "@tests/e2e/utils/chatActions";
import { THEMES, setThemeBeforeNavigation } from "@tests/e2e/utils/theme";
import { expectElementScreenshot } from "@tests/e2e/utils/visualRegression";

const SHORT_USER_MESSAGE = "What is Onyx?";

const LONG_USER_MESSAGE = `I've been evaluating several enterprise search and AI platforms for our organization, and I have a number of detailed questions about Onyx that I'd like to understand before we make a decision.

First, can you explain how Onyx handles document indexing across multiple data sources? We currently use Confluence, Google Drive, Slack, and GitHub, and we need to ensure that all of these can be indexed simultaneously without performance degradation.

Second, I'm interested in understanding the security model. Specifically, how does Onyx handle document-level permissions when syncing from sources that have their own ACL systems? Does it respect the original source permissions, or does it create its own permission layer?

Third, we have a requirement for real-time or near-real-time indexing. What is the typical latency between a document being updated in a source system and it becoming searchable in Onyx?

Finally, could you walk me through the architecture of the AI chat system? How does it decide which documents to reference when answering a question, and how does it handle cases where the retrieved documents might contain conflicting information?`;

const SHORT_AI_RESPONSE =
"Onyx is an open-source AI-powered enterprise search platform that connects to your company's documents, apps, and people.";

const LONG_AI_RESPONSE = `Onyx is an open-source Gen-AI and Enterprise Search platform designed to connect to your company's documents, applications, and people. Let me address each of your questions in detail.

## Document Indexing

Onyx uses a **connector-based architecture** where each data source has a dedicated connector. These connectors run as background workers and can index simultaneously without interfering with each other. The supported connectors include:

- **Confluence** — Full page and space indexing with attachment support
- **Google Drive** — File and folder indexing with shared drive support
- **Slack** — Channel message indexing with thread support
- **GitHub** — Repository, issue, and pull request indexing

Each connector runs on its own schedule and can be configured independently for polling frequency.

## Security Model

Onyx implements a **document-level permission system** that syncs with source ACLs. When documents are indexed, their permissions are preserved:

\`\`\`
Source Permission → Onyx ACL Sync → Query-time Filtering
\`\`\`

This means that when a user searches, they only see documents they have access to in the original source system. The permission sync runs periodically to stay up to date.

## Indexing Latency

The typical indexing latency depends on your configuration:

1. **Polling mode**: Documents are picked up on the next polling cycle (configurable, default 10 minutes)
2. **Webhook mode**: Near real-time, typically under 30 seconds
3. **Manual trigger**: Immediate indexing on demand

## AI Chat Architecture

The chat system uses a **Retrieval-Augmented Generation (RAG)** pipeline:

1. User query is analyzed and expanded
2. Relevant documents are retrieved from the vector database (Vespa)
3. Documents are ranked and filtered by relevance and permissions
4. The LLM generates a response grounded in the retrieved documents
5. Citations are attached to specific claims in the response

When documents contain conflicting information, the system presents the most relevant and recent information first, and includes citations so users can verify the source material themselves.`;

const MARKDOWN_AI_RESPONSE = `Here's a quick overview with various formatting:

### Key Features

| Feature | Status | Notes |
|---------|--------|-------|
| Enterprise Search | ✅ Available | Full-text and semantic |
| AI Chat | ✅ Available | Multi-model support |
| Connectors | ✅ Available | 30+ integrations |
| Permissions | ✅ Available | Source ACL sync |

### Code Example

\`\`\`python
from onyx import OnyxClient

client = OnyxClient(api_key="your-key")
results = client.search("quarterly revenue report")

for doc in results:
    print(f"{doc.title}: {doc.score:.2f}")
\`\`\`

> **Note**: Onyx supports both cloud and self-hosted deployments. The self-hosted option gives you full control over your data.

Key benefits include:

- **Privacy**: Your data stays within your infrastructure
- **Flexibility**: Connect any data source via custom connectors
- **Extensibility**: Open-source codebase with active community`;

interface MockDocument {
document_id: string;
semantic_identifier: string;
link: string;
source_type: string;
blurb: string;
is_internet: boolean;
}

interface SearchMockOptions {
content: string;
queries: string[];
documents: MockDocument[];
/** Maps citation number -> document_id */
citations: Record<number, string>;
isInternetSearch?: boolean;
}

let turnCounter = 0;

function buildMockStream(content: string): string {
turnCounter += 1;
const userMessageId = turnCounter * 100 + 1;
const assistantMessageId = turnCounter * 100 + 2;

const packets = [
{
user_message_id: userMessageId,
reserved_assistant_message_id: assistantMessageId,
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: {
type: "message_start",
id: `mock-${assistantMessageId}`,
content,
final_documents: null,
},
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "stop", stop_reason: "finished" },
},
{
message_id: assistantMessageId,
citations: {},
files: [],
},
];

return `${packets.map((p) => JSON.stringify(p)).join("\n")}\n`;
}

function buildMockSearchStream(options: SearchMockOptions): string {
turnCounter += 1;
const userMessageId = turnCounter * 100 + 1;
const assistantMessageId = turnCounter * 100 + 2;

const fullDocs = options.documents.map((doc) => ({
...doc,
boost: 0,
hidden: false,
score: 0.95,
chunk_ind: 0,
match_highlights: [],
metadata: {},
updated_at: null,
}));

// Turn 0: search tool
// Turn 1: answer + citations
const packets: Record<string, unknown>[] = [
{
user_message_id: userMessageId,
reserved_assistant_message_id: assistantMessageId,
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: {
type: "search_tool_start",
...(options.isInternetSearch !== undefined && {
is_internet_search: options.isInternetSearch,
}),
},
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "search_tool_queries_delta", queries: options.queries },
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "search_tool_documents_delta", documents: fullDocs },
},
{
placement: { turn_index: 0, tab_index: 0 },
obj: { type: "section_end" },
},
{
placement: { turn_index: 1, tab_index: 0 },
obj: {
type: "message_start",
id: `mock-${assistantMessageId}`,
content: options.content,
final_documents: fullDocs,
},
},
...Object.entries(options.citations).map(([num, docId]) => ({
placement: { turn_index: 1, tab_index: 0 },
obj: {
type: "citation_info",
citation_number: Number(num),
document_id: docId,
},
})),
{
placement: { turn_index: 1, tab_index: 0 },
obj: { type: "stop", stop_reason: "finished" },
},
{
message_id: assistantMessageId,
citations: options.citations,
files: [],
},
];

return `${packets.map((p) => JSON.stringify(p)).join("\n")}\n`;
}

async function openChat(page: Page): Promise<void> {
await page.goto("/app");
await page.waitForLoadState("networkidle");
await page.waitForSelector("#onyx-chat-input-textarea", { timeout: 15000 });
}
async function mockChatEndpoint(
page: Page,
responseContent: string
): Promise<void> {
await page.route("**/api/chat/send-chat-message", async (route) => {
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockStream(responseContent),
});
});
}

async function mockChatEndpointSequence(
page: Page,
responses: string[]
): Promise<void> {
let callIndex = 0;
await page.route("**/api/chat/send-chat-message", async (route) => {
const content =
responses[Math.min(callIndex, responses.length - 1)] ??
responses[responses.length - 1]!;
callIndex += 1;
await route.fulfill({
status: 200,
contentType: "text/plain",
body: buildMockStream(content),
});
});
}

async function screenshotChatContainer(
page: Page,
name: string
): Promise<void> {
const container = page.locator("[data-main-container]");
await expect(container).toBeVisible();
await expectElementScreenshot(container, { name });
}

/**
* Captures two screenshots of the chat container for long-content tests:
* one scrolled to the top and one scrolled to the bottom. Both are captured
* for the current theme, ensuring consistent scroll positions regardless of
* whether the page was just navigated to (top) or just finished streaming (bottom).
*/
async function screenshotChatContainerTopAndBottom(
page: Page,
name: string
): Promise<void> {
const container = page.locator("[data-main-container]");
await expect(container).toBeVisible();
const scrollContainer = page.getByTestId("chat-scroll-container");

await scrollContainer.evaluate((el) => el.scrollTo({ top: 0 }));
await expectElementScreenshot(container, { name: `${name}-top` });

await scrollContainer.evaluate((el) => el.scrollTo({ top: el.scrollHeight }));
await expectElementScreenshot(container, { name: `${name}-bottom` });
}
for (const theme of THEMES) {
|
||||
test.describe(`Chat Message Rendering (${theme} mode)`, () => {
|
||||
test.beforeEach(async ({ page }) => {
|
||||
turnCounter = 0;
|
||||
await page.context().clearCookies();
|
||||
await setThemeBeforeNavigation(page, theme);
|
||||
await loginAs(page, "user");
|
||||
});
|
||||
|
||||
test.describe("Short Messages", () => {
|
||||
test("short user message with short AI response renders correctly", async ({
|
||||
page,
|
||||
}) => {
|
||||
await openChat(page);
|
||||
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
|
||||
|
||||
await sendMessage(page, SHORT_USER_MESSAGE);
|
||||
|
||||
const userMessage = page.locator("#onyx-human-message").first();
|
||||
await expect(userMessage).toContainText(SHORT_USER_MESSAGE);
|
||||
|
||||
const aiMessage = page.getByTestId("onyx-ai-message").first();
|
||||
await expect(aiMessage).toContainText("open-source AI-powered");
|
||||
|
||||
await screenshotChatContainer(
|
||||
page,
|
||||
`chat-short-message-short-response-${theme}`
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test.describe("Long Messages", () => {
|
||||
test("long user message renders without truncation", async ({ page }) => {
|
||||
await openChat(page);
|
||||
await mockChatEndpoint(page, SHORT_AI_RESPONSE);
|
||||
|
||||
        await sendMessage(page, LONG_USER_MESSAGE);

        const userMessage = page.locator("#onyx-human-message").first();
        await expect(userMessage).toContainText("document indexing");
        await expect(userMessage).toContainText("security model");
        await expect(userMessage).toContainText("real-time or near-real-time");
        await expect(userMessage).toContainText("architecture of the AI chat");

        await screenshotChatContainer(
          page,
          `chat-long-user-message-short-response-${theme}`
        );
      });

      test("long AI response with markdown renders correctly", async ({
        page,
      }) => {
        await openChat(page);
        await mockChatEndpoint(page, LONG_AI_RESPONSE);

        await sendMessage(page, SHORT_USER_MESSAGE);

        const aiMessage = page.getByTestId("onyx-ai-message").first();
        await expect(aiMessage).toContainText("Document Indexing");
        await expect(aiMessage).toContainText("Security Model");
        await expect(aiMessage).toContainText("Indexing Latency");
        await expect(aiMessage).toContainText("AI Chat Architecture");

        await screenshotChatContainerTopAndBottom(
          page,
          `chat-short-message-long-response-${theme}`
        );
      });

      test("long user message with long AI response renders correctly", async ({
        page,
      }) => {
        await openChat(page);
        await mockChatEndpoint(page, LONG_AI_RESPONSE);

        await sendMessage(page, LONG_USER_MESSAGE);

        const userMessage = page.locator("#onyx-human-message").first();
        await expect(userMessage).toContainText("document indexing");

        const aiMessage = page.getByTestId("onyx-ai-message").first();
        await expect(aiMessage).toContainText("Retrieval-Augmented Generation");

        await screenshotChatContainerTopAndBottom(
          page,
          `chat-long-message-long-response-${theme}`
        );
      });
    });

    test.describe("Markdown and Code Rendering", () => {
      test("AI response with tables and code blocks renders correctly", async ({
        page,
      }) => {
        await openChat(page);
        await mockChatEndpoint(page, MARKDOWN_AI_RESPONSE);

        await sendMessage(page, "Give me an overview of Onyx features");

        const aiMessage = page.getByTestId("onyx-ai-message").first();
        await expect(aiMessage).toContainText("Key Features");
        await expect(aiMessage).toContainText("OnyxClient");
        await expect(aiMessage).toContainText("Privacy");

        await screenshotChatContainer(
          page,
          `chat-markdown-code-response-${theme}`
        );
      });
    });

    test.describe("Multi-Turn Conversation", () => {
      test("multi-turn conversation renders all messages correctly", async ({
        page,
      }) => {
        await openChat(page);

        const responses = [
          SHORT_AI_RESPONSE,
          "Yes, Onyx supports over 30 data source connectors including Confluence, Google Drive, Slack, GitHub, Jira, Notion, and many more.",
          "To get started, you can deploy Onyx using Docker Compose with a single command. The setup takes about 5 minutes.",
        ];

        await mockChatEndpointSequence(page, responses);

        await sendMessage(page, SHORT_USER_MESSAGE);
        await expect(page.getByTestId("onyx-ai-message").first()).toContainText(
          "open-source AI-powered"
        );

        await sendMessage(page, "What connectors does it support?");
        await expect(page.getByTestId("onyx-ai-message")).toHaveCount(2, {
          timeout: 30000,
        });

        await sendMessage(page, "How do I get started?");
        await expect(page.getByTestId("onyx-ai-message")).toHaveCount(3, {
          timeout: 30000,
        });

        const userMessages = page.locator("#onyx-human-message");
        await expect(userMessages).toHaveCount(3);

        await screenshotChatContainerTopAndBottom(
          page,
          `chat-multi-turn-conversation-${theme}`
        );
      });

      test("multi-turn with mixed message lengths renders correctly", async ({
        page,
      }) => {
        await openChat(page);

        const responses = [LONG_AI_RESPONSE, SHORT_AI_RESPONSE];

        await mockChatEndpointSequence(page, responses);

        await sendMessage(page, LONG_USER_MESSAGE);
        await expect(page.getByTestId("onyx-ai-message").first()).toContainText(
          "Document Indexing"
        );

        await sendMessage(page, SHORT_USER_MESSAGE);
        await expect(page.getByTestId("onyx-ai-message")).toHaveCount(2, {
          timeout: 30000,
        });

        await screenshotChatContainerTopAndBottom(
          page,
          `chat-multi-turn-mixed-lengths-${theme}`
        );
      });
    });

    test.describe("Web Search with Citations", () => {
      const WEB_SEARCH_DOCUMENTS: MockDocument[] = [
        {
          document_id: "web-doc-1",
          semantic_identifier: "Onyx Documentation - Getting Started",
          link: "https://docs.onyx.app/getting-started",
          source_type: "web",
          blurb:
            "Onyx is an open-source enterprise search and AI platform. Deploy in minutes with Docker Compose.",
          is_internet: true,
        },
        {
          document_id: "web-doc-2",
          semantic_identifier: "Onyx GitHub Repository",
          link: "https://github.com/onyx-dot-app/onyx",
          source_type: "web",
          blurb:
            "Open-source Gen-AI platform with 30+ connectors. MIT licensed community edition.",
          is_internet: true,
        },
        {
          document_id: "web-doc-3",
          semantic_identifier: "Enterprise Search Comparison 2025",
          link: "https://example.com/enterprise-search-comparison",
          source_type: "web",
          blurb:
            "Comparing top enterprise search platforms including Onyx, Glean, and Coveo.",
          is_internet: true,
        },
      ];
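
      // Note: the [[Dn]](link) markers in the mocked response below are the
      // inline citation tokens; their numbers correspond to the `citations`
      // map passed to buildMockSearchStream further down, which resolves them
      // to the mock documents defined above.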
      const WEB_SEARCH_RESPONSE = `Based on my web search, here's what I found about Onyx:

Onyx is an open-source enterprise search and AI platform that can be deployed in minutes using Docker Compose [[D1]](https://docs.onyx.app/getting-started). The project is hosted on GitHub and is MIT licensed for the community edition, with over 30 connectors available [[D2]](https://github.com/onyx-dot-app/onyx).

In comparisons with other enterprise search platforms, Onyx stands out for its open-source nature and self-hosted deployment option [[D3]](https://example.com/enterprise-search-comparison). Unlike proprietary alternatives, you maintain full control over your data and infrastructure.

Key advantages include:

- **Self-hosted**: Deploy on your own infrastructure
- **Open source**: Full visibility into the codebase [[D2]](https://github.com/onyx-dot-app/onyx)
- **Quick setup**: Get running in under 5 minutes [[D1]](https://docs.onyx.app/getting-started)
- **Extensible**: 30+ pre-built connectors with custom connector support`;

      test("web search response with citations renders correctly", async ({
        page,
      }) => {
        await openChat(page);

        await page.route("**/api/chat/send-chat-message", async (route) => {
          await route.fulfill({
            status: 200,
            contentType: "text/plain",
            body: buildMockSearchStream({
              content: WEB_SEARCH_RESPONSE,
              queries: ["Onyx enterprise search platform overview"],
              documents: WEB_SEARCH_DOCUMENTS,
              citations: {
                1: "web-doc-1",
                2: "web-doc-2",
                3: "web-doc-3",
              },
              isInternetSearch: true,
            }),
          });
        });

        await sendMessage(page, "Search the web for information about Onyx");

        const aiMessage = page.getByTestId("onyx-ai-message").first();
        await expect(aiMessage).toContainText("open-source enterprise search");
        await expect(aiMessage).toContainText("Docker Compose");
        await expect(aiMessage).toContainText("MIT licensed");

        await screenshotChatContainer(
          page,
          `chat-web-search-with-citations-${theme}`
        );
      });

      test("internal document search response renders correctly", async ({
        page,
      }) => {
        const internalDocs: MockDocument[] = [
          {
            document_id: "confluence-doc-1",
            semantic_identifier: "Q3 2025 Engineering Roadmap",
            link: "https://company.atlassian.net/wiki/spaces/ENG/pages/123",
            source_type: "confluence",
            blurb:
              "Engineering priorities for Q3 include platform stability, new connector integrations, and performance improvements.",
            is_internet: false,
          },
          {
            document_id: "gdrive-doc-1",
            semantic_identifier: "Platform Architecture Overview.pdf",
            link: "https://drive.google.com/file/d/abc123",
            source_type: "google_drive",
            blurb:
              "Onyx platform architecture document covering microservices, data flow, and deployment topology.",
            is_internet: false,
          },
        ];

        const internalResponse = `Based on your company's internal documents, here is the engineering roadmap:

The Q3 2025 priorities focus on three main areas [[D1]](https://company.atlassian.net/wiki/spaces/ENG/pages/123):

1. **Platform stability** — Improving error handling and retry mechanisms across all connectors
2. **New integrations** — Adding support for ServiceNow and Zendesk connectors
3. **Performance** — Optimizing vector search latency and reducing indexing time

The platform architecture document provides additional context on how these improvements fit into the overall system design [[D2]](https://drive.google.com/file/d/abc123). The microservices architecture allows each component to be scaled independently.`;

        await openChat(page);

        await page.route("**/api/chat/send-chat-message", async (route) => {
          await route.fulfill({
            status: 200,
            contentType: "text/plain",
            body: buildMockSearchStream({
              content: internalResponse,
              queries: ["Q3 engineering roadmap priorities"],
              documents: internalDocs,
              citations: {
                1: "confluence-doc-1",
                2: "gdrive-doc-1",
              },
              isInternetSearch: false,
            }),
          });
        });

        await sendMessage(page, "What are our engineering priorities for Q3?");

        const aiMessage = page.getByTestId("onyx-ai-message").first();
        await expect(aiMessage).toContainText("Platform stability");
        await expect(aiMessage).toContainText("New integrations");
        await expect(aiMessage).toContainText("Performance");

        await screenshotChatContainer(
          page,
          `chat-internal-search-with-citations-${theme}`
        );
      });
    });

    test.describe("Message Interaction States", () => {
      test("hovering over user message shows action buttons", async ({
        page,
      }) => {
        await openChat(page);
        await mockChatEndpoint(page, SHORT_AI_RESPONSE);

        await sendMessage(page, SHORT_USER_MESSAGE);

        const userMessage = page.locator("#onyx-human-message").first();
        await userMessage.hover();

        const editButton = userMessage.getByTestId("HumanMessage/edit-button");
        await expect(editButton).toBeVisible({ timeout: 5000 });

        await screenshotChatContainer(
          page,
          `chat-user-message-hover-state-${theme}`
        );
      });

      test("AI message toolbar is visible after response completes", async ({
        page,
      }) => {
        await openChat(page);
        await mockChatEndpoint(page, SHORT_AI_RESPONSE);

        await sendMessage(page, SHORT_USER_MESSAGE);

        const aiMessage = page.getByTestId("onyx-ai-message").first();

        const copyButton = aiMessage.getByTestId("AgentMessage/copy-button");
        const likeButton = aiMessage.getByTestId("AgentMessage/like-button");
        const dislikeButton = aiMessage.getByTestId(
          "AgentMessage/dislike-button"
        );

        await expect(copyButton).toBeVisible({ timeout: 10000 });
        await expect(likeButton).toBeVisible();
        await expect(dislikeButton).toBeVisible();

        await screenshotChatContainer(
          page,
          `chat-ai-message-with-toolbar-${theme}`
        );
      });
    });
  });
}
@@ -1,218 +0,0 @@
import { test, expect, Page } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";

/**
 * Builds a newline-delimited JSON stream body matching the packet
 * format that useChatController expects:
 *
 * 1. MessageResponseIDInfo — identifies the user/assistant messages
 * 2. Packet-wrapped streaming objects ({placement, obj}) — the actual content
 * 3. BackendMessage — the final completed message
 *
 * Each line is a raw JSON object parsed by handleSSEStream.
 */
function buildMockStream(messageContent: string): string {
  const packets = [
    // 1. Message ID info — tells the frontend the message IDs
    JSON.stringify({
      user_message_id: 1,
      reserved_assistant_message_id: 2,
    }),
    // 2. Streaming content packets wrapped in {placement, obj}
    JSON.stringify({
      placement: { turn_index: 0 },
      obj: {
        type: "message_start",
        id: "mock-message-id",
        content: "",
        final_documents: null,
      },
    }),
    JSON.stringify({
      placement: { turn_index: 0 },
      obj: {
        type: "message_delta",
        content: messageContent,
      },
    }),
    JSON.stringify({
      placement: { turn_index: 0 },
      obj: {
        type: "message_end",
      },
    }),
    JSON.stringify({
      placement: { turn_index: 0 },
      obj: {
        type: "stop",
        stop_reason: "finished",
      },
    }),
    // 3. Final BackendMessage — the completed message record
    JSON.stringify({
      message_id: 2,
      message_type: "assistant",
      research_type: null,
      parent_message: 1,
      latest_child_message: null,
      message: messageContent,
      rephrased_query: null,
      context_docs: null,
      time_sent: new Date().toISOString(),
      citations: {},
      files: [],
      tool_call: null,
      overridden_model: null,
    }),
  ];
  return packets.join("\n") + "\n";
}
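
// Illustrative only: for buildMockStream("Hi"), the first few stream lines are
//   {"user_message_id":1,"reserved_assistant_message_id":2}
//   {"placement":{"turn_index":0},"obj":{"type":"message_start","id":"mock-message-id","content":"","final_documents":null}}
//   {"placement":{"turn_index":0},"obj":{"type":"message_delta","content":"Hi"}}
// followed by the message_end, stop, and final BackendMessage lines.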

/**
 * Sends a message while intercepting the backend response with
 * a controlled mock stream. Returns once the AI message renders.
 */
async function sendMessageWithMockResponse(
  page: Page,
  userMessage: string,
  mockResponseContent: string
) {
  const existingMessageCount = await page
    .locator('[data-testid="onyx-ai-message"]')
    .count();

  // Intercept the send-chat-message endpoint and return our mock stream
  await page.route("**/api/chat/send-chat-message", async (route) => {
    await route.fulfill({
      status: 200,
      contentType: "application/json",
      body: buildMockStream(mockResponseContent),
    });
  });

  await page.locator("#onyx-chat-input-textarea").click();
  await page.locator("#onyx-chat-input-textarea").fill(userMessage);
  await page.locator("#onyx-chat-input-send-button").click();

  // Wait for the AI message to appear
  await expect(page.locator('[data-testid="onyx-ai-message"]')).toHaveCount(
    existingMessageCount + 1,
    { timeout: 30000 }
  );

  // Unroute so future requests go through normally
  await page.unroute("**/api/chat/send-chat-message");
}

const MOCK_FILE_ID = "00000000-0000-0000-0000-000000000001";

test.describe("File preview modal from chat file links", () => {
  test.beforeEach(async ({ page }) => {
    await page.context().clearCookies();
    await loginAsRandomUser(page);
    await page.goto("/app");
    await page.waitForLoadState("networkidle");
  });

  test("clicking a text file link opens the TextViewModal", async ({
    page,
  }) => {
    const mockContent = `Here is your file: [notes.txt](/api/chat/file/${MOCK_FILE_ID})`;

    // Mock the file endpoint to return text content
    await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
      await route.fulfill({
        status: 200,
        contentType: "text/plain",
        body: "Hello from the mock file!",
      });
    });

    await sendMessageWithMockResponse(page, "Give me the file", mockContent);

    // Find the link in the AI message and click it
    const aiMessage = page.getByTestId("onyx-ai-message").last();
    const fileLink = aiMessage.locator("a").filter({ hasText: "notes.txt" });
    await expect(fileLink).toBeVisible({ timeout: 5000 });
    await fileLink.click();

    // Verify the modal opens
    const modal = page.getByRole("dialog");
    await expect(modal).toBeVisible({ timeout: 5000 });

    // Verify the file name is shown in the header
    await expect(modal.getByText("notes.txt")).toBeVisible();

    // Verify the download button exists
    await expect(modal.getByText("Download File")).toBeVisible();

    // Verify the file content is rendered
    await expect(modal.getByText("Hello from the mock file!")).toBeVisible();
  });

  test("clicking a code file link opens the CodeViewModal with syntax highlighting", async ({
    page,
  }) => {
    const mockContent = `Here is your script: [app.py](/api/chat/file/${MOCK_FILE_ID})`;
    const pythonCode = 'def hello():\n print("Hello, world!")';

    // Mock the file endpoint to return Python code
    await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
      await route.fulfill({
        status: 200,
        contentType: "application/octet-stream",
        body: pythonCode,
      });
    });

    await sendMessageWithMockResponse(page, "Give me the script", mockContent);

    // Find the link in the AI message and click it
    const aiMessage = page.getByTestId("onyx-ai-message").last();
    const fileLink = aiMessage.locator("a").filter({ hasText: "app.py" });
    await expect(fileLink).toBeVisible({ timeout: 5000 });
    await fileLink.click();

    // Verify the modal opens
    const modal = page.getByRole("dialog");
    await expect(modal).toBeVisible({ timeout: 5000 });

    // Verify the file name is shown
    await expect(modal.getByText("app.py")).toBeVisible();

    // Verify the code content is rendered
    await expect(modal.getByText("Hello, world!")).toBeVisible();

    // Verify the download button exists
    await expect(modal.getByText("Download File")).toBeVisible();
  });

  test("download button triggers file download", async ({ page }) => {
    const mockContent = `Here: [data.csv](/api/chat/file/${MOCK_FILE_ID})`;

    await page.route(`**/api/chat/file/${MOCK_FILE_ID}`, async (route) => {
      await route.fulfill({
        status: 200,
        contentType: "text/csv",
        body: "name,age\nAlice,30\nBob,25",
      });
    });

    await sendMessageWithMockResponse(page, "Give me the csv", mockContent);

    const aiMessage = page.getByTestId("onyx-ai-message").last();
    const fileLink = aiMessage.locator("a").filter({ hasText: "data.csv" });
    await expect(fileLink).toBeVisible({ timeout: 5000 });
    await fileLink.click();

    const modal = page.getByRole("dialog");
    await expect(modal).toBeVisible({ timeout: 5000 });

    // Click the download button and verify a download starts
    const downloadPromise = page.waitForEvent("download");
    await modal.getByText("Download File").last().click();
    const download = await downloadPromise;

    expect(download.suggestedFilename()).toContain("data.csv");
  });
});
248 web/tests/e2e/chat/share_chat.spec.ts Normal file
@@ -0,0 +1,248 @@
import { test, expect } from "@playwright/test";
import type { Page } from "@playwright/test";
import { loginAsRandomUser } from "../utils/auth";
import { expectElementScreenshot } from "../utils/visualRegression";

async function sendMessageAndWaitForChat(page: Page, message: string) {
  await page.locator("#onyx-chat-input-textarea").click();
  await page.locator("#onyx-chat-input-textarea").fill(message);
  await page.locator("#onyx-chat-input-send-button").click();

  await page.waitForFunction(
    () => window.location.href.includes("chatId="),
    null,
    { timeout: 15000 }
  );

  await expect(page.locator('[aria-label="share-chat-button"]')).toBeVisible({
    timeout: 10000,
  });
}

async function openShareModal(page: Page) {
  await page.locator('[aria-label="share-chat-button"]').click();
  await expect(page.getByRole("dialog")).toBeVisible({ timeout: 5000 });
}

test.describe("Share Chat Session Modal", () => {
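  // Note: these tests share one browser page and run serially (see the
  // configure call below); later tests build on the sharing state created
  // by earlier ones, so the order matters.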
  test.describe.configure({ mode: "serial" });

  let page: Page;

  test.beforeAll(async ({ browser }) => {
    page = await browser.newPage();
    await loginAsRandomUser(page);
    await sendMessageAndWaitForChat(page, "Hello for share test");
  });

  test.afterAll(async () => {
    await page.close();
  });

  test("shows Private selected by default", async () => {
    await openShareModal(page);

    const dialog = page.getByRole("dialog");
    await expect(dialog).toBeVisible();

    const privateOption = dialog.locator(
      '[aria-label="share-modal-option-private"]'
    );
    await expect(privateOption.locator("svg").last()).toBeVisible();

    const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
    await expect(submitButton).toHaveText("Done");

    const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
    await expect(cancelButton).toBeVisible();

    await expectElementScreenshot(dialog, {
      name: "share-modal-default-private",
    });

    await page.keyboard.press("Escape");
    await expect(dialog).toBeHidden({ timeout: 5000 });
  });

  test("selecting Your Organization changes submit text", async () => {
    await openShareModal(page);

    const dialog = page.getByRole("dialog");

    await dialog.locator('[aria-label="share-modal-option-public"]').click();

    const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
    await expect(submitButton).toHaveText("Create Share Link");

    const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
    await expect(cancelButton).toBeVisible();

    await expectElementScreenshot(dialog, {
      name: "share-modal-public-selected",
    });

    await page.keyboard.press("Escape");
    await expect(dialog).toBeHidden({ timeout: 5000 });
  });

  test("Cancel closes modal without API calls", async () => {
    let patchCallCount = 0;
    await page.route("**/api/chat/chat-session/*", async (route) => {
      if (route.request().method() === "PATCH") {
        patchCallCount++;
      }
      await route.continue();
    });

    await openShareModal(page);

    const dialog = page.getByRole("dialog");
    const cancelButton = dialog.locator('[aria-label="share-modal-cancel"]');
    await cancelButton.click();

    await expect(dialog).toBeHidden({ timeout: 5000 });
    expect(patchCallCount).toBe(0);

    await page.unrouteAll({ behavior: "ignoreErrors" });
  });

  test("X button closes modal without API calls", async () => {
    let patchCallCount = 0;
    await page.route("**/api/chat/chat-session/*", async (route) => {
      if (route.request().method() === "PATCH") {
        patchCallCount++;
      }
      await route.continue();
    });

    await openShareModal(page);

    const dialog = page.getByRole("dialog");
    const closeButton = dialog.locator('div[tabindex="-1"] button');
    await closeButton.click();

    await expect(dialog).toBeHidden({ timeout: 5000 });
    expect(patchCallCount).toBe(0);

    await page.unrouteAll({ behavior: "ignoreErrors" });
  });

  test("creating a share link calls API and shows link", async () => {
    await openShareModal(page);

    const dialog = page.getByRole("dialog");

    let patchBody: Record<string, unknown> | null = null;
    await page.route("**/api/chat/chat-session/*", async (route) => {
      if (route.request().method() === "PATCH") {
        patchBody = JSON.parse(route.request().postData() ?? "{}");
        await route.continue();
      } else {
        await route.continue();
      }
    });

    await dialog.locator('[aria-label="share-modal-option-public"]').click();
    const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
    await submitButton.click();

    await page.waitForResponse(
      (r) =>
        r.url().includes("/api/chat/chat-session/") &&
        r.request().method() === "PATCH",
      { timeout: 10000 }
    );

    expect(patchBody).toEqual({ sharing_status: "public" });

    const linkInput = dialog.locator('[aria-label="share-modal-link-input"]');
    await expect(linkInput).toBeVisible({ timeout: 5000 });

    const inputValue = await linkInput.locator("input").inputValue();
    expect(inputValue).toContain("/app/shared/");

    await expect(submitButton).toHaveText("Copy Link");
    await expect(dialog.getByText("Chat shared")).toBeVisible();
    await expect(
      dialog.locator('[aria-label="share-modal-cancel"]')
    ).toBeHidden();

    await expectElementScreenshot(dialog, {
      name: "share-modal-link-created",
      mask: ['[aria-label="share-modal-link-input"]'],
    });

    await page.unrouteAll({ behavior: "ignoreErrors" });

    // Wait for the toast to confirm SWR data has been refreshed
    // before closing, so the next test sees up-to-date shared_status
    await expect(
      page.getByText("Share link copied to clipboard!").first()
    ).toBeVisible({ timeout: 5000 });

    await page.keyboard.press("Escape");
    await expect(dialog).toBeHidden({ timeout: 5000 });
  });

  test("Copy Link triggers clipboard copy", async () => {
    await openShareModal(page);

    const dialog = page.getByRole("dialog");

    await expect(
      dialog.locator('[aria-label="share-modal-link-input"]')
    ).toBeVisible({ timeout: 5000 });

    const submitButton = dialog.locator('[aria-label="share-modal-submit"]');
    await expect(submitButton).toHaveText("Copy Link");

    await submitButton.click();

    await expect(
      page.getByText("Share link copied to clipboard!").first()
    ).toBeVisible({ timeout: 5000 });

    await page.keyboard.press("Escape");
    await expect(dialog).toBeHidden({ timeout: 5000 });
  });

  test("making chat private again calls API and closes modal", async () => {
    let patchBody: Record<string, unknown> | null = null;
    await page.route("**/api/chat/chat-session/*", async (route) => {
      if (route.request().method() === "PATCH") {
        patchBody = JSON.parse(route.request().postData() ?? "{}");
        await route.continue();
      } else {
        await route.continue();
      }
    });

    await openShareModal(page);

    const dialog = page.getByRole("dialog");
    const submitButton = dialog.locator('[aria-label="share-modal-submit"]');

    await dialog.locator('[aria-label="share-modal-option-private"]').click();

    await expect(submitButton).toHaveText("Make Private");

    await submitButton.click();

    await page.waitForResponse(
      (r) =>
        r.url().includes("/api/chat/chat-session/") &&
        r.request().method() === "PATCH",
      { timeout: 10000 }
    );

    expect(patchBody).toEqual({ sharing_status: "private" });

    await expect(dialog).toBeHidden({ timeout: 5000 });

    await expect(page.getByText("Chat is now private")).toBeVisible({
      timeout: 5000,
    });

    await page.unrouteAll({ behavior: "ignoreErrors" });
  });
});
18 web/tests/e2e/utils/theme.ts Normal file
@@ -0,0 +1,18 @@
import type { Page } from "@playwright/test";

export const THEMES = ["light", "dark"] as const;
export type Theme = (typeof THEMES)[number];

/**
 * Injects the given theme into localStorage via `addInitScript` so that
 * `next-themes` applies it on first render. Call this in `beforeEach`
 * **before** any `page.goto()`.
 */
export async function setThemeBeforeNavigation(
  page: Page,
  theme: Theme
): Promise<void> {
  await page.addInitScript((t: string) => {
    localStorage.setItem("theme", t);
  }, theme);
}