Compare commits

...

1 Commit

Author SHA1 Message Date
Wenxi
53494324aa chore: hotfix v2.1.0-beta.1 (#5814)
Co-authored-by: Chris Weaver <chris@onyx.app>
Co-authored-by: Raunak Bhagat <r@rabh.io>
Co-authored-by: Justin Tahara <105671973+justin-tahara@users.noreply.github.com>
Co-authored-by: Evan Lohn <evan@danswer.ai>
Co-authored-by: Nikolas Garza <90273783+nmgarza5@users.noreply.github.com>
Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Co-authored-by: cubic-dev-ai[bot] <191113872+cubic-dev-ai[bot]@users.noreply.github.com>
Co-authored-by: Nikolas Garza <nikolas@Nikolass-MacBook-Pro.attlocal.net>
2025-10-20 18:27:50 -07:00
75 changed files with 4975 additions and 876 deletions

View File

@@ -109,6 +109,15 @@ jobs:
# Needed for trivyignore
- name: Checkout
uses: actions/checkout@v4
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Download digests
uses: actions/download-artifact@v4

View File

@@ -120,6 +120,15 @@ jobs:
if: needs.precheck.outputs.should-run == 'true'
runs-on: ubuntu-latest
steps:
- name: Check if stable release version
id: check_version
run: |
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
echo "is_stable=true" >> $GITHUB_OUTPUT
else
echo "is_stable=false" >> $GITHUB_OUTPUT
fi
- name: Download digests
uses: actions/download-artifact@v4
with:

View File

@@ -35,3 +35,7 @@ jobs:
- name: Pull, Tag and Push API Server Image
run: |
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
- name: Pull, Tag and Push Model Server Image
run: |
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}

35
.github/workflows/pr-jest-tests.yml vendored Normal file
View File

@@ -0,0 +1,35 @@
name: Run Jest Tests

# Cancel any in-flight run for the same branch/workflow so only the
# newest push gets tested; falls back to run_id so pushes to the default
# branch are never cancelled against each other.
concurrency:
  group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on: push

jobs:
  jest-tests:
    name: Jest Tests
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Setup node
        uses: actions/setup-node@v4
        with:
          node-version: 22

      # npm ci: clean, lockfile-exact install (reproducible vs. npm install)
      - name: Install node dependencies
        working-directory: ./web
        run: npm ci

      # --ci keeps Jest from writing new snapshots; --maxWorkers=50% caps
      # parallelism on the shared runner
      - name: Run Jest tests
        working-directory: ./web
        run: npm test -- --ci --coverage --maxWorkers=50%

      # if: always() — upload coverage even when tests fail so failures
      # can still be inspected
      - name: Upload coverage reports
        if: always()
        uses: actions/upload-artifact@v4
        with:
          name: jest-coverage-${{ github.run_id }}
          path: ./web/coverage
          retention-days: 7

View File

@@ -50,6 +50,25 @@ def get_empty_chat_messages_entries__paginated(
if message.message_type != MessageType.USER:
continue
# Get user email
user_email = chat_session.user.email if chat_session.user else None
# Get assistant name (from session persona, or alternate if specified)
assistant_name = None
if message.alternate_assistant_id:
# If there's an alternate assistant, we need to fetch it
from onyx.db.models import Persona
alternate_persona = (
db_session.query(Persona)
.filter(Persona.id == message.alternate_assistant_id)
.first()
)
if alternate_persona:
assistant_name = alternate_persona.name
elif chat_session.persona:
assistant_name = chat_session.persona.name
message_skeletons.append(
ChatMessageSkeleton(
message_id=message.id,
@@ -57,6 +76,9 @@ def get_empty_chat_messages_entries__paginated(
user_id=str(chat_session.user_id) if chat_session.user_id else None,
flow_type=flow_type,
time_sent=message.time_sent,
assistant_name=assistant_name,
user_email=user_email,
number_of_tokens=message.token_count,
)
)
if len(chat_sessions) == 0:

View File

@@ -48,7 +48,17 @@ def generate_chat_messages_report(
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
) as temp_file:
csvwriter = csv.writer(temp_file, delimiter=",")
csvwriter.writerow(["session_id", "user_id", "flow_type", "time_sent"])
csvwriter.writerow(
[
"session_id",
"user_id",
"flow_type",
"time_sent",
"assistant_name",
"user_email",
"number_of_tokens",
]
)
for chat_message_skeleton_batch in get_all_empty_chat_message_entries(
db_session, period
):
@@ -59,6 +69,9 @@ def generate_chat_messages_report(
chat_message_skeleton.user_id,
chat_message_skeleton.flow_type,
chat_message_skeleton.time_sent.isoformat(),
chat_message_skeleton.assistant_name,
chat_message_skeleton.user_email,
chat_message_skeleton.number_of_tokens,
]
)

View File

@@ -16,6 +16,9 @@ class ChatMessageSkeleton(BaseModel):
user_id: str | None
flow_type: FlowType
time_sent: datetime
assistant_name: str | None
user_email: str | None
number_of_tokens: int
class UserSkeleton(BaseModel):

View File

@@ -37,6 +37,8 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.llm.interfaces import ToolChoiceOptions
from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
from onyx.llm.utils import model_is_reasoning_model
from onyx.server.utils import mask_string
from onyx.utils.logger import setup_logger
@@ -49,8 +51,6 @@ if TYPE_CHECKING:
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_LOCATION_KWARG = "vertex_location"
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"

View File

@@ -12,6 +12,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPENROUTER_PROVIDER_NAME
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
from onyx.llm.utils import model_supports_image_input
@@ -26,19 +27,24 @@ logger = setup_logger()
def _build_provider_extra_headers(
provider: str, custom_config: dict[str, str] | None
) -> dict[str, str]:
if provider != OLLAMA_PROVIDER_NAME or not custom_config:
return {}
# Ollama Cloud: allow passing Bearer token via custom config for cloud instances
if provider == OLLAMA_PROVIDER_NAME and custom_config:
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
api_key = raw_api_key.strip() if raw_api_key else None
if not api_key:
return {}
if not api_key.lower().startswith("bearer "):
api_key = f"Bearer {api_key}"
return {"Authorization": api_key}
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
# Passing these will put Onyx on the OpenRouter leaderboard
elif provider == OPENROUTER_PROVIDER_NAME:
return {
"HTTP-Referer": "https://onyx.app",
"X-Title": "Onyx",
}
api_key = raw_api_key.strip() if raw_api_key else None
if not api_key:
return {}
if not api_key.lower().startswith("bearer "):
api_key = f"Bearer {api_key}"
return {"Authorization": api_key}
return {}
def get_main_llm_from_tuple(

View File

@@ -2,8 +2,6 @@ from enum import Enum
from pydantic import BaseModel
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
from onyx.llm.chat_llm import VERTEX_LOCATION_KWARG
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import ModelConfigurationView
@@ -137,18 +135,8 @@ BEDROCK_REGION_OPTIONS = _build_bedrock_region_options()
OLLAMA_PROVIDER_NAME = "ollama"
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
def get_bedrock_model_names() -> list[str]:
import litellm
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
# litellm has split them into two lists :(
return [
model
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
if "/" not in model and "embed" not in model
][::-1]
# OpenRouter
OPENROUTER_PROVIDER_NAME = "openrouter"
IGNORABLE_ANTHROPIC_MODELS = [
"claude-2",
@@ -157,17 +145,6 @@ IGNORABLE_ANTHROPIC_MODELS = [
]
ANTHROPIC_PROVIDER_NAME = "anthropic"
def get_anthropic_model_names() -> list[str]:
import litellm
return [
model
for model in litellm.anthropic_models
if model not in IGNORABLE_ANTHROPIC_MODELS
][::-1]
ANTHROPIC_VISIBLE_MODEL_NAMES = [
"claude-sonnet-4-5-20250929",
"claude-sonnet-4-20250514",
@@ -177,6 +154,8 @@ AZURE_PROVIDER_NAME = "azure"
VERTEXAI_PROVIDER_NAME = "vertex_ai"
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
VERTEX_LOCATION_KWARG = "vertex_location"
VERTEXAI_DEFAULT_MODEL = "gemini-2.0-flash"
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.0-flash-lite"
VERTEXAI_MODEL_NAMES = [
@@ -223,15 +202,39 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
OLLAMA_PROVIDER_NAME: [],
OPENROUTER_PROVIDER_NAME: [],
}
def get_bedrock_model_names() -> list[str]:
    """Return Bedrock model names known to litellm, in reverse of litellm's
    listing order, excluding namespaced ("/") and embedding models."""
    import litellm

    # bedrock_converse_models are just extensions of the bedrock_models, not sure why
    # litellm has split them into two lists :(
    combined = litellm.bedrock_models.union(litellm.bedrock_converse_models)

    names: list[str] = []
    for candidate in list(combined):
        if "/" in candidate or "embed" in candidate:
            continue
        names.append(candidate)
    names.reverse()
    return names
def get_anthropic_model_names() -> list[str]:
    """Return Anthropic model names known to litellm, skipping the legacy
    entries in IGNORABLE_ANTHROPIC_MODELS, in reverse of litellm's order."""
    import litellm

    kept = [
        name
        for name in litellm.anthropic_models
        if name not in IGNORABLE_ANTHROPIC_MODELS
    ]
    kept.reverse()
    return kept
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
BEDROCK_PROVIDER_NAME: [],
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
OLLAMA_PROVIDER_NAME: [],
OPENROUTER_PROVIDER_NAME: [],
}
@@ -372,6 +375,20 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
default_model=VERTEXAI_DEFAULT_MODEL,
default_fast_model=VERTEXAI_DEFAULT_MODEL,
),
WellKnownLLMProviderDescriptor(
name=OPENROUTER_PROVIDER_NAME,
display_name="OpenRouter",
api_key_required=True,
api_base_required=True,
api_version_required=False,
custom_config_keys=[],
model_configurations=fetch_model_configurations_for_provider(
OPENROUTER_PROVIDER_NAME
),
default_model=None,
default_fast_model=None,
default_api_base="https://openrouter.ai/api/v1",
),
]

View File

@@ -644,41 +644,37 @@ def get_max_input_tokens_from_llm_provider(
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
# TODO: Add support to check model config for any provider
# TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here
if model_provider == "ollama":
try:
with get_session_with_current_tenant() as db_session:
model_config = db_session.scalar(
select(ModelConfiguration)
.join(
LLMProvider,
ModelConfiguration.llm_provider_id == LLMProvider.id,
)
.where(
ModelConfiguration.name == model_name,
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
)
model_map = get_model_map()
# First, try to read an explicit configuration from the model_configuration table
try:
model_obj = find_model_obj(
model_map,
model_provider,
model_name,
)
if not model_obj:
raise RuntimeError(
f"No litellm entry found for {model_provider}/{model_name}"
with get_session_with_current_tenant() as db_session:
model_config = db_session.scalar(
select(ModelConfiguration)
.join(
LLMProvider,
ModelConfiguration.llm_provider_id == LLMProvider.id,
)
.where(
ModelConfiguration.name == model_name,
LLMProvider.provider == model_provider,
)
)
if model_config and model_config.supports_image_input is not None:
return model_config.supports_image_input
except Exception as e:
logger.warning(
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
)
# Fallback to looking up the model in the litellm model_cost dict
try:
model_obj = find_model_obj(get_model_map(), model_provider, model_name)
if not model_obj:
logger.warning(
f"No litellm entry found for {model_provider}/{model_name}, "
"this model may or may not support image input."
)
return False
return model_obj.get("supports_vision", False)
except Exception:
logger.exception(

View File

@@ -46,6 +46,9 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
from onyx.server.manage.llm.models import OllamaFinalModelResponse
from onyx.server.manage.llm.models import OllamaModelDetails
from onyx.server.manage.llm.models import OllamaModelsRequest
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
from onyx.server.manage.llm.models import OpenRouterModelDetails
from onyx.server.manage.llm.models import OpenRouterModelsRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
@@ -577,3 +580,75 @@ def get_ollama_available_models(
)
return all_models_with_context_size_and_vision
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
    """Perform GET to OpenRouter /models and return parsed JSON.

    Any failure (network, non-2xx status, JSON decode) is surfaced to the
    caller as an HTTP 400 so the admin UI gets a readable error message.
    """
    base = api_base.strip().rstrip("/")
    request_headers = {
        "Authorization": f"Bearer {api_key}",
        # Optional headers recommended by OpenRouter for attribution
        "HTTP-Referer": "https://onyx.app",
        "X-Title": "Onyx",
    }
    try:
        resp = httpx.get(f"{base}/models", headers=request_headers, timeout=10.0)
        resp.raise_for_status()
        return resp.json()
    except Exception as e:
        raise HTTPException(
            status_code=400,
            detail=f"Failed to fetch OpenRouter models: {e}",
        )
@admin_router.post("/openrouter/available-models")
def get_openrouter_available_models(
    request: OpenRouterModelsRequest,
    _: User | None = Depends(current_admin_user),
) -> list[OpenRouterFinalModelResponse]:
    """Fetch available models from OpenRouter `/models` endpoint.

    Parses id, context_length, and architecture.input_modalities to infer vision support.
    Unparseable entries are logged and skipped rather than failing the request.
    """
    payload = _get_openrouter_models_response(
        api_base=request.api_base, api_key=request.api_key
    )

    raw_models = payload.get("data", [])
    if not isinstance(raw_models, list) or len(raw_models) == 0:
        raise HTTPException(
            status_code=400,
            detail="No models found from your OpenRouter endpoint",
        )

    parsed: list[OpenRouterFinalModelResponse] = []
    for entry in raw_models:
        try:
            details = OpenRouterModelDetails.model_validate(entry)
            # NOTE: This should be removed if we ever support dynamically fetching embedding models.
            if details.is_embedding_model:
                continue
            parsed.append(
                OpenRouterFinalModelResponse(
                    name=details.id,
                    max_input_tokens=details.context_length,
                    supports_image_input=details.supports_image_input,
                )
            )
        except Exception as e:
            # Skip the malformed entry; truncate the payload in the log.
            logger.warning(
                "Failed to parse OpenRouter model entry",
                extra={"error": str(e), "item": str(entry)[:1000]},
            )

    if not parsed:
        raise HTTPException(
            status_code=400, detail="No compatible models found from OpenRouter"
        )

    return sorted(parsed, key=lambda m: m.name.lower())

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
from pydantic import BaseModel
from pydantic import Field
from pydantic import field_validator
from onyx.llm.utils import get_max_input_tokens
from onyx.llm.utils import model_supports_image_input
@@ -34,6 +35,12 @@ class TestLLMRequest(BaseModel):
# if try and use the existing API key
api_key_changed: bool
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: str) -> str:
"""Normalize provider name by stripping whitespace and lowercasing."""
return value.strip().lower()
class LLMProviderDescriptor(BaseModel):
"""A descriptor for an LLM provider that can be safely viewed by
@@ -91,6 +98,12 @@ class LLMProviderUpsertRequest(LLMProvider):
api_key_changed: bool = False
model_configurations: list["ModelConfigurationUpsertRequest"] = []
@field_validator("provider", mode="before")
@classmethod
def normalize_provider(cls, value: str) -> str:
"""Normalize provider name by stripping whitespace and lowercasing."""
return value.strip().lower()
class LLMProviderView(LLMProvider):
"""Stripped down representation of LLMProvider for display / limited access info only"""
@@ -224,3 +237,36 @@ class OllamaModelDetails(BaseModel):
def supports_image_input(self) -> bool:
"""Check if this model supports image input"""
return "vision" in self.capabilities
# OpenRouter dynamic models fetch
class OpenRouterModelsRequest(BaseModel):
    # Request body for the admin endpoint that lists OpenRouter models.
    api_base: str  # base URL; caller strips trailing slashes before appending /models
    api_key: str  # OpenRouter API key, sent as a Bearer token
class OpenRouterModelDetails(BaseModel):
    """Response model for OpenRouter /api/v1/models endpoint"""

    # This is used to ignore any extra fields that are returned from the API
    model_config = {"extra": "ignore"}

    id: str  # OpenRouter model identifier
    context_length: int  # maximum context window, in tokens
    architecture: dict[str, Any]  # Contains 'input_modalities' key

    @property
    def supports_image_input(self) -> bool:
        # Vision support is signaled by "image" appearing in
        # architecture.input_modalities; defensively tolerate a missing
        # or non-list value.
        input_modalities = self.architecture.get("input_modalities", [])
        return isinstance(input_modalities, list) and "image" in input_modalities

    @property
    def is_embedding_model(self) -> bool:
        # Embedding models advertise "embeddings" in
        # architecture.output_modalities — presumably per the OpenRouter
        # schema; TODO confirm against their API docs.
        output_modalities = self.architecture.get("output_modalities", [])
        return isinstance(output_modalities, list) and "embeddings" in output_modalities
class OpenRouterFinalModelResponse(BaseModel):
    # Slimmed-down model entry returned to the admin UI.
    name: str  # OpenRouter model id
    max_input_tokens: int  # taken from context_length
    supports_image_input: bool  # inferred from architecture.input_modalities

View File

@@ -33,6 +33,7 @@ from onyx.server.manage.models import FullModelVersionResponse
from onyx.server.models import IdReturn
from onyx.utils.logger import setup_logger
from shared_configs.configs import ALT_INDEX_SUFFIX
from shared_configs.configs import MULTI_TENANT
router = APIRouter(prefix="/search-settings")
logger = setup_logger()
@@ -50,6 +51,13 @@ def set_new_search_settings(
if search_settings_new.index_name:
logger.warning("Index name was specified by request, this is not suggested")
# Disallow contextual RAG for cloud deployments
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
# Validate cloud provider exists or create new LiteLLM provider
if search_settings_new.provider_type is not None:
cloud_provider = get_embedding_provider_from_provider_type(
@@ -217,6 +225,13 @@ def update_saved_search_settings(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
# Disallow contextual RAG for cloud deployments
if MULTI_TENANT and search_settings.enable_contextual_rag:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Contextual RAG disabled in Onyx Cloud",
)
update_current_search_settings(
search_settings=search_settings, db_session=db_session
)

View File

@@ -6,6 +6,7 @@ Full tenant analysis script that:
3. Analyzes the collected data
"""
import argparse
import csv
import json
import os
@@ -361,17 +362,35 @@ def find_recent_tenant_data() -> tuple[list[dict[str, Any]] | None, str | None]:
def main() -> None:
# Parse command-line arguments
parser = argparse.ArgumentParser(
description="Analyze tenant data and identify gated tenants with no recent queries"
)
parser.add_argument(
"--skip-cache",
action="store_true",
help="Skip cached tenant data and collect fresh data from pod",
)
args = parser.parse_args()
try:
# Step 0: Collect control plane data
control_plane_data = collect_control_plane_data()
# Step 1: Check for recent tenant data (< 7 days old)
tenant_data, cached_file = find_recent_tenant_data()
# Step 1: Check for recent tenant data (< 7 days old) unless --skip-cache is set
tenant_data = None
cached_file = None
if not args.skip_cache:
tenant_data, cached_file = find_recent_tenant_data()
if tenant_data:
print(f"Using cached tenant data from: {cached_file}")
print(f"Total tenants in cache: {len(tenant_data)}")
else:
if args.skip_cache:
print("\n⚠ Skipping cache (--skip-cache flag set)")
# Step 2a: Find the heavy worker pod
pod_name = find_worker_pod()

View File

@@ -21,10 +21,12 @@ Examples:
python backend/scripts/cleanup_tenant.py --csv gated_tenants_no_query_3mo.csv --force
"""
import csv
import json
import signal
import subprocess
import sys
from datetime import datetime
from pathlib import Path
from scripts.tenant_cleanup.cleanup_utils import confirm_step
@@ -420,13 +422,17 @@ def cleanup_control_plane(tenant_id: str, force: bool = False) -> None:
raise
def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
"""
Main cleanup function that orchestrates all cleanup steps.
Args:
tenant_id: The tenant ID to clean up
pod_name: The Kubernetes pod name to execute operations on
force: If True, skip all confirmation prompts
Returns:
True if cleanup was performed, False if skipped
"""
print(f"Starting cleanup for tenant: {tenant_id}")
@@ -445,44 +451,52 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
)
print(f"{'=' * 80}\n")
if force:
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
return False
# Always ask for confirmation if not gated, even in force mode
response = input(
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
)
if response.lower() != "yes":
print("Cleanup aborted - tenant is not GATED_ACCESS")
return
return False
elif tenant_status == "GATED_ACCESS":
print("✓ Tenant status is GATED_ACCESS - safe to proceed with cleanup")
elif tenant_status is None:
print("⚠️ WARNING: Could not determine tenant status!")
if force:
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
return False
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Cleanup aborted - could not verify tenant status")
return
return False
except Exception as e:
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
if force:
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
return False
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Cleanup aborted - could not verify tenant status")
return
return False
print(f"{'=' * 80}\n")
# Find heavy worker pod for Vespa and schema operations
try:
pod_name = find_worker_pod()
except Exception as e:
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
print("Cannot proceed with Vespa and schema cleanup")
return
# Fetch tenant users for informational purposes (non-blocking)
print(f"\n{'=' * 80}")
try:
get_tenant_users(pod_name, tenant_id)
except Exception as e:
print(f"⚠ Could not fetch tenant users: {e}")
print(f"{'=' * 80}\n")
# Skip in force mode as it's only informational
if not force:
print(f"\n{'=' * 80}")
try:
get_tenant_users(pod_name, tenant_id)
except Exception as e:
print(f"⚠ Could not fetch tenant users: {e}")
print(f"{'=' * 80}\n")
# Step 1: Make sure all documents are deleted
print(f"\n{'=' * 80}")
@@ -498,7 +512,7 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
print(
"You may need to mark connectors for deletion and wait for cleanup to complete."
)
return
return False
print(f"{'=' * 80}\n")
# Step 2: Drop data plane schema
@@ -514,7 +528,7 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
response = input("Continue with control plane cleanup? (y/n): ")
if response.lower() != "y":
print("Cleanup aborted by user")
return
return False
else:
print("[FORCE MODE] Continuing despite schema cleanup failure")
else:
@@ -535,11 +549,12 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
print("[FORCE MODE] Control plane cleanup failed but continuing")
else:
print("Step 3 skipped by user")
return
return False
print(f"\n{'=' * 80}")
print(f"✓ Cleanup completed for tenant: {tenant_id}")
print(f"{'=' * 80}")
return True
def main() -> None:
@@ -643,43 +658,87 @@ def main() -> None:
f"⚠ FORCE MODE: Running cleanup for {len(tenant_ids)} tenants without confirmations"
)
# Find heavy worker pod once for all tenants
try:
pod_name = find_worker_pod()
print(f"✓ Found worker pod: {pod_name}\n")
except Exception as e:
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
print("Cannot proceed with cleanup")
sys.exit(1)
# Run cleanup for each tenant
failed_tenants = []
successful_tenants = []
skipped_tenants = []
for idx, tenant_id in enumerate(tenant_ids, 1):
if len(tenant_ids) > 1:
print(f"\n{'=' * 80}")
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
print(f"{'=' * 80}")
# Open CSV file for writing successful cleanups in real-time
csv_output_path = "cleaned_tenants.csv"
with open(csv_output_path, "w", newline="") as csv_file:
csv_writer = csv.writer(csv_file)
csv_writer.writerow(["tenant_id", "cleaned_at"])
csv_file.flush() # Ensure header is written immediately
try:
cleanup_tenant(tenant_id, force)
successful_tenants.append(tenant_id)
except Exception as e:
print(f"✗ Cleanup failed for tenant {tenant_id}: {e}", file=sys.stderr)
failed_tenants.append((tenant_id, str(e)))
print(f"Writing successful cleanups to: {csv_output_path}\n")
# If not in force mode and there are more tenants, ask if we should continue
if not force and idx < len(tenant_ids):
response = input(
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
)
if response.lower() != "y":
print("Cleanup aborted by user")
break
for idx, tenant_id in enumerate(tenant_ids, 1):
if len(tenant_ids) > 1:
print(f"\n{'=' * 80}")
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
print(f"{'=' * 80}")
# Print summary if multiple tenants
if len(tenant_ids) > 1:
try:
was_cleaned = cleanup_tenant(tenant_id, pod_name, force)
if was_cleaned:
# Only record if actually cleaned up (not skipped)
successful_tenants.append(tenant_id)
# Write to CSV immediately after successful cleanup
timestamp = datetime.utcnow().isoformat()
csv_writer.writerow([tenant_id, timestamp])
csv_file.flush() # Ensure real-time write
print(f"✓ Recorded cleanup in {csv_output_path}")
else:
skipped_tenants.append(tenant_id)
print(f"⚠ Tenant {tenant_id} was skipped (not recorded in CSV)")
except Exception as e:
print(f"✗ Cleanup failed for tenant {tenant_id}: {e}", file=sys.stderr)
failed_tenants.append((tenant_id, str(e)))
# If not in force mode and there are more tenants, ask if we should continue
if not force and idx < len(tenant_ids):
response = input(
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
)
if response.lower() != "y":
print("Cleanup aborted by user")
break
# Print summary
if len(tenant_ids) == 1:
if successful_tenants:
print(f"\n✓ Successfully cleaned tenant written to: {csv_output_path}")
elif skipped_tenants:
print("\n⚠ Tenant was skipped")
elif len(tenant_ids) > 1:
print(f"\n{'=' * 80}")
print("CLEANUP SUMMARY")
print(f"{'=' * 80}")
print(f"Total tenants: {len(tenant_ids)}")
print(f"Successful: {len(successful_tenants)}")
print(f"Skipped: {len(skipped_tenants)}")
print(f"Failed: {len(failed_tenants)}")
print(f"\nSuccessfully cleaned tenants written to: {csv_output_path}")
if skipped_tenants:
print(f"\nSkipped tenants ({len(skipped_tenants)}):")
for tenant_id in skipped_tenants:
print(f" - {tenant_id}")
if failed_tenants:
print("\nFailed tenants:")
print(f"\nFailed tenants ({len(failed_tenants)}):")
for tenant_id, error in failed_tenants:
print(f" - {tenant_id}: {error}")

View File

@@ -7,137 +7,65 @@ Mark connectors for deletion script that:
4. Triggers the cleanup task
Usage:
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force]
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force]
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force] [--concurrency N]
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force] [--concurrency N]
Arguments:
tenant_id The tenant ID to process (required if not using --csv)
--csv PATH Path to CSV file containing tenant IDs to process
--force Skip all confirmation prompts (optional)
--concurrency N Process N tenants concurrently (default: 1)
Examples:
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789 --force
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv --force
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py \
--csv gated_tenants_no_query_3mo.csv --force --concurrency 16
"""
import json
import subprocess
import sys
from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
from threading import Lock
from typing import Any
from scripts.tenant_cleanup.cleanup_utils import confirm_step
from scripts.tenant_cleanup.cleanup_utils import find_worker_pod
from scripts.tenant_cleanup.cleanup_utils import get_tenant_status
from scripts.tenant_cleanup.cleanup_utils import read_tenant_ids_from_csv
def get_tenant_connectors(pod_name: str, tenant_id: str) -> list[dict]:
    """Get list of connector credential pairs for the tenant.

    Works by copying a helper script onto the worker pod with `kubectl cp`
    and executing it there with `kubectl exec`; the helper writes a JSON
    result to stdout and progress text to stderr. All failures are
    non-fatal and reported as an empty list (best-effort lookup).

    Args:
        pod_name: The Kubernetes pod name to execute on
        tenant_id: The tenant ID to query

    Returns:
        List of connector credential pair dicts with id, connector_id, credential_id, name, status
    """
    print(f"Fetching connector credential pairs for tenant: {tenant_id}")

    # Get the path to the script
    script_dir = Path(__file__).parent
    get_connectors_script = script_dir / "on_pod_scripts" / "get_tenant_connectors.py"

    if not get_connectors_script.exists():
        raise FileNotFoundError(
            f"get_tenant_connectors.py not found at {get_connectors_script}"
        )

    try:
        # Copy script to pod
        print(" Copying script to pod...")
        subprocess.run(
            [
                "kubectl",
                "cp",
                str(get_connectors_script),
                f"{pod_name}:/tmp/get_tenant_connectors.py",
            ],
            check=True,  # raise CalledProcessError on kubectl failure
            capture_output=True,
        )

        # Execute script on pod
        print(" Executing script on pod...")
        result = subprocess.run(
            [
                "kubectl",
                "exec",
                pod_name,
                "--",
                "python",
                "/tmp/get_tenant_connectors.py",
                tenant_id,
            ],
            capture_output=True,
            text=True,
            check=True,
        )

        # Show progress messages from stderr
        if result.stderr:
            print(f" {result.stderr}", end="")

        # Parse JSON result from stdout
        result_data = json.loads(result.stdout)
        status = result_data.get("status")

        if status == "success":
            connectors = result_data.get("connectors", [])
            if connectors:
                print(f"✓ Found {len(connectors)} connector credential pair(s):")
                for cc in connectors:
                    print(
                        f" - CC Pair ID: {cc['id']}, Name: {cc['name']}, Status: {cc['status']}"
                    )
            else:
                print(" No connector credential pairs found for tenant")
            return connectors
        else:
            # Helper reported an error; surface its message but don't raise.
            message = result_data.get("message", "Unknown error")
            print(f"⚠ Could not fetch connectors: {message}")
            return []

    except subprocess.CalledProcessError as e:
        # kubectl cp/exec returned non-zero — report and fall back to empty.
        print(f"⚠ Failed to get connectors for tenant {tenant_id}: {e}")
        if e.stderr:
            print(f" Error details: {e.stderr}")
        return []
    except Exception as e:
        # Any other failure (e.g. JSON decode of stdout) is also non-fatal.
        print(f"⚠ Failed to get connectors for tenant {tenant_id}: {e}")
        return []
# Global lock for thread-safe printing
_print_lock: Lock = Lock()
def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int) -> None:
"""Mark a connector credential pair for deletion.
def safe_print(*args: Any, **kwargs: Any) -> None:
"""Thread-safe print function."""
with _print_lock:
print(*args, **kwargs)
def run_connector_deletion(pod_name: str, tenant_id: str) -> None:
"""Mark all connector credential pairs for deletion.
Args:
pod_name: The Kubernetes pod name to execute on
tenant_id: The tenant ID
cc_pair_id: The connector credential pair ID to mark for deletion
"""
print(f" Marking CC pair {cc_pair_id} for deletion...")
safe_print(" Marking all connector credential pairs for deletion...")
# Get the path to the script
script_dir = Path(__file__).parent
mark_deletion_script = (
script_dir / "on_pod_scripts" / "mark_connector_for_deletion.py"
script_dir / "on_pod_scripts" / "execute_connector_deletion.py"
)
if not mark_deletion_script.exists():
raise FileNotFoundError(
f"mark_connector_for_deletion.py not found at {mark_deletion_script}"
f"execute_connector_deletion.py not found at {mark_deletion_script}"
)
try:
@@ -147,7 +75,7 @@ def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int)
"kubectl",
"cp",
str(mark_deletion_script),
f"{pod_name}:/tmp/mark_connector_for_deletion.py",
f"{pod_name}:/tmp/execute_connector_deletion.py",
],
check=True,
capture_output=True,
@@ -161,166 +89,119 @@ def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int)
pod_name,
"--",
"python",
"/tmp/mark_connector_for_deletion.py",
"/tmp/execute_connector_deletion.py",
tenant_id,
str(cc_pair_id),
"--all",
],
capture_output=True,
text=True,
check=True,
)
# Show progress messages from stderr
if result.stderr:
print(f" {result.stderr}", end="")
# Parse JSON result from stdout
result_data = json.loads(result.stdout)
status = result_data.get("status")
message = result_data.get("message")
if status == "success":
print(f"{message}")
else:
print(f"{message}", file=sys.stderr)
raise RuntimeError(message)
if result.returncode != 0:
raise RuntimeError(result.stderr)
except subprocess.CalledProcessError as e:
print(
f" ✗ Failed to mark CC pair {cc_pair_id} for deletion: {e}",
safe_print(
f" ✗ Failed to mark all connector credential pairs for deletion: {e}",
file=sys.stderr,
)
if e.stderr:
print(f" Error details: {e.stderr}", file=sys.stderr)
safe_print(f" Error details: {e.stderr}", file=sys.stderr)
raise
except Exception as e:
print(
f" ✗ Failed to mark CC pair {cc_pair_id} for deletion: {e}",
safe_print(
f" ✗ Failed to mark all connector credential pairs for deletion: {e}",
file=sys.stderr,
)
raise
def mark_tenant_connectors_for_deletion(tenant_id: str, force: bool = False) -> None:
def mark_tenant_connectors_for_deletion(
tenant_id: str, pod_name: str, force: bool = False
) -> None:
"""
Main function to mark all connectors for a tenant for deletion.
Args:
tenant_id: The tenant ID to process
pod_name: The Kubernetes pod name to execute on
force: If True, skip all confirmation prompts
"""
print(f"Processing connectors for tenant: {tenant_id}")
safe_print(f"Processing connectors for tenant: {tenant_id}")
# Check tenant status first
print(f"\n{'=' * 80}")
safe_print(f"\n{'=' * 80}")
try:
tenant_status = get_tenant_status(tenant_id)
# If tenant is not GATED_ACCESS, require explicit confirmation even in force mode
if tenant_status and tenant_status != "GATED_ACCESS":
print(
safe_print(
f"\n⚠️ WARNING: Tenant status is '{tenant_status}', not 'GATED_ACCESS'!"
)
print(
safe_print(
"This tenant may be active and should not have connectors deleted without careful review."
)
print(f"{'=' * 80}\n")
safe_print(f"{'=' * 80}\n")
# Always ask for confirmation if not gated, even in force mode
response = input(
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
)
if response.lower() != "yes":
print("Operation aborted - tenant is not GATED_ACCESS")
return
# Note: In parallel mode with force, this will still block
if not force:
response = input(
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
)
if response.lower() != "yes":
safe_print("Operation aborted - tenant is not GATED_ACCESS")
raise RuntimeError(f"Tenant {tenant_id} is not GATED_ACCESS")
else:
raise RuntimeError(f"Tenant {tenant_id} is not GATED_ACCESS")
elif tenant_status == "GATED_ACCESS":
print("✓ Tenant status is GATED_ACCESS - safe to proceed")
safe_print("✓ Tenant status is GATED_ACCESS - safe to proceed")
elif tenant_status is None:
print("⚠️ WARNING: Could not determine tenant status!")
safe_print("⚠️ WARNING: Could not determine tenant status!")
if not force:
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
safe_print("Operation aborted - could not verify tenant status")
raise RuntimeError(
f"Could not verify tenant status for {tenant_id}"
)
else:
raise RuntimeError(f"Could not verify tenant status for {tenant_id}")
except Exception as e:
safe_print(f"⚠️ WARNING: Failed to check tenant status: {e}")
if not force:
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Operation aborted - could not verify tenant status")
return
except Exception as e:
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
response = input("Continue anyway? Type 'yes' to confirm: ")
if response.lower() != "yes":
print("Operation aborted - could not verify tenant status")
return
print(f"{'=' * 80}\n")
safe_print("Operation aborted - could not verify tenant status")
raise
else:
raise RuntimeError(f"Failed to check tenant status for {tenant_id}")
safe_print(f"{'=' * 80}\n")
# Find heavy worker pod for operations
try:
pod_name = find_worker_pod()
except Exception as e:
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
print("Cannot proceed with marking connectors for deletion")
return
# Fetch connectors
print(f"\n{'=' * 80}")
try:
connectors = get_tenant_connectors(pod_name, tenant_id)
except Exception as e:
print(f"✗ Failed to fetch connectors: {e}", file=sys.stderr)
return
print(f"{'=' * 80}\n")
if not connectors:
print(f"No connectors found for tenant {tenant_id}, nothing to do.")
return
# Confirm before proceeding
# Confirm before proceeding (only in non-force mode)
if not confirm_step(
f"Mark {len(connectors)} connector credential pair(s) for deletion?",
f"Mark all connector credential pairs for deletion for tenant {tenant_id}?",
force,
):
print("Operation cancelled by user")
return
safe_print("Operation cancelled by user")
raise ValueError("Operation cancelled by user")
# Mark each connector for deletion
failed_connectors = []
successful_connectors = []
for cc in connectors:
cc_pair_id = cc["id"]
cc_name = cc["name"]
cc_status = cc["status"]
# Skip if already marked for deletion
if cc_status == "DELETING":
print(
f" Skipping CC pair {cc_pair_id} ({cc_name}) - already marked for deletion"
)
continue
try:
mark_connector_for_deletion(pod_name, tenant_id, cc_pair_id)
successful_connectors.append(cc_pair_id)
except Exception as e:
print(
f" ✗ Failed to mark CC pair {cc_pair_id} ({cc_name}) for deletion: {e}",
file=sys.stderr,
)
failed_connectors.append((cc_pair_id, cc_name, str(e)))
run_connector_deletion(pod_name, tenant_id)
# Print summary
print(f"\n{'=' * 80}")
print(f"✓ Marked {len(successful_connectors)} connector(s) for deletion")
if failed_connectors:
print(f"✗ Failed to mark {len(failed_connectors)} connector(s):")
for cc_id, cc_name, error in failed_connectors:
print(f" - CC Pair {cc_id} ({cc_name}): {error}")
print(f"{'=' * 80}")
safe_print(
f"✓ Marked all connector credential pairs for deletion for tenant {tenant_id}"
)
def main() -> None:
if len(sys.argv) < 2:
print(
"Usage: python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force]"
"Usage: python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force] "
"[--concurrency N]"
)
print(
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force]"
" [--concurrency N]"
)
print("\nArguments:")
print(
@@ -328,6 +209,7 @@ def main() -> None:
)
print(" --csv PATH Path to CSV file containing tenant IDs to process")
print(" --force Skip all confirmation prompts (optional)")
print(" --concurrency N Process N tenants concurrently (default: 1)")
print("\nExamples:")
print(
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789"
@@ -339,23 +221,48 @@ def main() -> None:
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv"
)
print(
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv --force"
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv "
"--force --concurrency 16"
)
sys.exit(1)
# Parse arguments
force = "--force" in sys.argv
tenant_ids = []
tenant_ids: list[str] = []
# Parse concurrency
concurrency: int = 1
if "--concurrency" in sys.argv:
try:
concurrency_index = sys.argv.index("--concurrency")
if concurrency_index + 1 >= len(sys.argv):
print("Error: --concurrency flag requires a number", file=sys.stderr)
sys.exit(1)
concurrency = int(sys.argv[concurrency_index + 1])
if concurrency < 1:
print("Error: concurrency must be at least 1", file=sys.stderr)
sys.exit(1)
except ValueError:
print("Error: --concurrency value must be an integer", file=sys.stderr)
sys.exit(1)
# Validate: concurrency > 1 requires --force
if concurrency > 1 and not force:
print(
"Error: --concurrency > 1 requires --force flag (interactive mode not supported with parallel processing)",
file=sys.stderr,
)
sys.exit(1)
# Check for CSV mode
if "--csv" in sys.argv:
try:
csv_index = sys.argv.index("--csv")
csv_index: int = sys.argv.index("--csv")
if csv_index + 1 >= len(sys.argv):
print("Error: --csv flag requires a file path", file=sys.stderr)
sys.exit(1)
csv_path = sys.argv[csv_index + 1]
csv_path: str = sys.argv[csv_index + 1]
tenant_ids = read_tenant_ids_from_csv(csv_path)
if not tenant_ids:
@@ -371,6 +278,16 @@ def main() -> None:
# Single tenant mode
tenant_ids = [sys.argv[1]]
# Find heavy worker pod once before processing
try:
print("Finding worker pod...")
pod_name: str = find_worker_pod()
print(f"✓ Using worker pod: {pod_name}")
except Exception as e:
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
print("Cannot proceed with marking connectors for deletion")
sys.exit(1)
# Initial confirmation (unless --force is used)
if not force:
print(f"\n{'=' * 80}")
@@ -387,6 +304,7 @@ def main() -> None:
print(
f"Mode: {'FORCE (no confirmations)' if force else 'Interactive (will ask for confirmation at each step)'}"
)
print(f"Concurrency: {concurrency} tenant(s) at a time")
print("\nThis will:")
print(" 1. Fetch all connector credential pairs for each tenant")
print(" 2. Cancel any scheduled indexing attempts for each connector")
@@ -409,37 +327,78 @@ def main() -> None:
)
else:
print(
f"⚠ FORCE MODE: Marking connectors for deletion for {len(tenant_ids)} tenants without confirmations"
f"⚠ FORCE MODE: Marking connectors for deletion for {len(tenant_ids)} tenants "
f"(concurrency: {concurrency}) without confirmations"
)
# Process each tenant
failed_tenants = []
successful_tenants = []
# Process tenants (in parallel if concurrency > 1)
failed_tenants: list[tuple[str, str]] = []
successful_tenants: list[str] = []
for idx, tenant_id in enumerate(tenant_ids, 1):
if len(tenant_ids) > 1:
print(f"\n{'=' * 80}")
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
print(f"{'=' * 80}")
if concurrency == 1:
# Sequential processing
for idx, tenant_id in enumerate(tenant_ids, 1):
if len(tenant_ids) > 1:
print(f"\n{'=' * 80}")
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
print(f"{'=' * 80}")
try:
mark_tenant_connectors_for_deletion(tenant_id, force)
successful_tenants.append(tenant_id)
except Exception as e:
print(
f"✗ Failed to process tenant {tenant_id}: {e}",
file=sys.stderr,
)
failed_tenants.append((tenant_id, str(e)))
# If not in force mode and there are more tenants, ask if we should continue
if not force and idx < len(tenant_ids):
response = input(
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
try:
mark_tenant_connectors_for_deletion(tenant_id, pod_name, force)
successful_tenants.append(tenant_id)
except Exception as e:
print(
f"✗ Failed to process tenant {tenant_id}: {e}",
file=sys.stderr,
)
if response.lower() != "y":
print("Operation aborted by user")
break
failed_tenants.append((tenant_id, str(e)))
# If not in force mode and there are more tenants, ask if we should continue
if not force and idx < len(tenant_ids):
response = input(
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
)
if response.lower() != "y":
print("Operation aborted by user")
break
else:
# Parallel processing
print(
f"\nProcessing {len(tenant_ids)} tenant(s) with concurrency={concurrency}"
)
def process_tenant(tenant_id: str) -> tuple[str, bool, str | None]:
"""Process a single tenant. Returns (tenant_id, success, error_message)."""
try:
mark_tenant_connectors_for_deletion(tenant_id, pod_name, force)
return (tenant_id, True, None)
except Exception as e:
return (tenant_id, False, str(e))
with ThreadPoolExecutor(max_workers=concurrency) as executor:
# Submit all tasks
future_to_tenant = {
executor.submit(process_tenant, tenant_id): tenant_id
for tenant_id in tenant_ids
}
# Process results as they complete
completed: int = 0
for future in as_completed(future_to_tenant):
completed += 1
tenant_id, success, error = future.result()
if success:
successful_tenants.append(tenant_id)
safe_print(
f"[{completed}/{len(tenant_ids)}] ✓ Successfully processed {tenant_id}"
)
else:
failed_tenants.append((tenant_id, error or "Unknown error"))
safe_print(
f"[{completed}/{len(tenant_ids)}] ✗ Failed to process {tenant_id}: {error}",
file=sys.stderr,
)
# Print summary if multiple tenants
if len(tenant_ids) > 1:

View File

@@ -51,13 +51,14 @@ def check_documents_deleted(tenant_id: str) -> dict:
cc_count = cc_count or 0
doc_count = doc_count or 0
# If any records remain, return error status
if cc_count > 0 or doc_count > 0:
# If any records remain beyond acceptable thresholds, return error status
if cc_count > 0 or doc_count > 5:
return {
"status": "error",
"message": (
f"Found {cc_count} ConnectorCredentialPair(s) and {doc_count} Document(s) "
"still remaining. All documents must be deleted before cleanup."
"still remaining. Must have 0 ConnectorCredentialPairs and no more than "
"5 Documents before cleanup."
),
"connector_credential_pair_count": cc_count,
"document_count": doc_count,

View File

@@ -0,0 +1,348 @@
#!/usr/bin/env python3
"""
Script to mark connector credential pairs for deletion.
Runs on a Kubernetes pod with access to the data plane database.
Usage:
# Mark a specific connector for deletion
python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>
# Mark all connectors for deletion
python mark_connector_for_deletion.py <tenant_id> --all
Output:
JSON to stdout with structure:
{
"status": "success" | "error",
"message": str,
"deleted_count": int (when using --all),
"timing": {
"total_seconds": float,
"per_connector": [...]
}
}
"""
import json
import sys
import time
from typing import Any
from sqlalchemy.orm import Session
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_connector_credential_pairs
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
def mark_connector_for_deletion(
    tenant_id: str, cc_pair_id: int, db_session: Session | None = None
) -> dict[str, Any]:
    """Mark a connector credential pair for deletion.

    Cancels any scheduled indexing attempts for the pair, flips its status
    to DELETING, commits, and — only on success — triggers the Celery
    deletion-check task so the background worker performs the deletion.

    Args:
        tenant_id: The tenant ID
        cc_pair_id: The connector credential pair ID
        db_session: Optional database session (if None, creates a new one).
            NOTE: when a session is supplied, this function commits it.

    Returns:
        Dict with status ("success" | "error"), message, and timing
    """
    timing: dict[str, float] = {}
    start_time: float = time.time()

    try:
        # Progress goes to stderr; stdout is reserved for the JSON result.
        print(
            f"Marking connector credential pair {cc_pair_id} for deletion",
            file=sys.stderr,
        )

        def _mark_deletion(db_sess: Session) -> dict[str, Any]:
            # Fetch the CC pair; report an error result if it no longer exists.
            fetch_start: float = time.time()
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_sess,
                cc_pair_id=cc_pair_id,
            )
            timing["fetch_cc_pair_seconds"] = time.time() - fetch_start

            if not cc_pair:
                return {
                    "status": "error",
                    "message": f"Connector credential pair {cc_pair_id} not found",
                    "timing": timing,
                }

            # Cancel any scheduled indexing attempts so nothing re-indexes a
            # pair that is about to be deleted.
            print(
                f"Canceling indexing attempts for CC pair {cc_pair_id}",
                file=sys.stderr,
            )
            cancel_start: float = time.time()
            cancel_indexing_attempts_for_ccpair(
                cc_pair_id=cc_pair.id,
                db_session=db_sess,
                include_secondary_index=True,
            )
            timing["cancel_indexing_seconds"] = time.time() - cancel_start

            # Mark as deleting
            print(
                f"Updating CC pair {cc_pair_id} status to DELETING",
                file=sys.stderr,
            )
            update_start: float = time.time()
            update_connector_credential_pair_from_id(
                db_session=db_sess,
                cc_pair_id=cc_pair.id,
                status=ConnectorCredentialPairStatus.DELETING,
            )
            timing["update_status_seconds"] = time.time() - update_start

            commit_start: float = time.time()
            db_sess.commit()
            timing["commit_seconds"] = time.time() - commit_start

            return {
                "status": "success",
                "message": f"Marked connector credential pair {cc_pair_id} for deletion",
                "timing": timing,
            }

        result: dict[str, Any]
        if db_session:
            result = _mark_deletion(db_session)
        else:
            with get_session_with_tenant(tenant_id=tenant_id) as db_sess:
                result = _mark_deletion(db_sess)

        # Only kick off the deletion-check task when the pair was actually
        # marked; previously this fired even on "not found" errors.
        if result.get("status") == "success":
            print(
                "Triggering connector deletion check task",
                file=sys.stderr,
            )
            task_start: float = time.time()
            client_app.send_task(
                OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
                priority=OnyxCeleryPriority.HIGH,
                kwargs={"tenant_id": tenant_id},
            )
            timing["send_task_seconds"] = time.time() - task_start

        timing["total_seconds"] = time.time() - start_time
        result["timing"] = timing
        return result

    except Exception as e:
        print(
            f"Error marking connector for deletion: {e}",
            file=sys.stderr,
        )
        timing["total_seconds"] = time.time() - start_time
        return {
            "status": "error",
            "message": str(e),
            "timing": timing,
        }
def mark_all_connectors_for_deletion(tenant_id: str) -> dict[str, Any]:
    """Mark all connector credential pairs for a tenant for deletion.

    Iterates every CC pair for the tenant, cancels its scheduled indexing
    attempts, flips its status to DELETING, commits once at the end, then
    triggers the Celery deletion-check task.

    Args:
        tenant_id: The tenant ID

    Returns:
        Dict with status, message, deleted_count, timing, and — when any
        pair failed — an "errors" list.
    """
    overall_start: float = time.time()
    per_connector_timing: list[dict[str, Any]] = []

    try:
        print(
            f"Marking all connector credential pairs for tenant {tenant_id} for deletion",
            file=sys.stderr,
        )

        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # Get all connector credential pairs
            fetch_all_start: float = time.time()
            cc_pairs = get_connector_credential_pairs(db_session=db_session)
            fetch_all_time: float = time.time() - fetch_all_start

            print(
                f"Found {len(cc_pairs)} connector credential pairs to delete",
                file=sys.stderr,
            )

            if not cc_pairs:
                return {
                    "status": "success",
                    "message": "No connector credential pairs found for tenant",
                    "deleted_count": 0,
                    "timing": {
                        "fetch_all_seconds": fetch_all_time,
                        "total_seconds": time.time() - overall_start,
                    },
                }

            deleted_count: int = 0
            errors: list[str] = []

            for idx, cc_pair in enumerate(cc_pairs, start=1):
                connector_start: float = time.time()
                # Progress uses the loop index; the previous counter was
                # derived from deleted_count and stalled whenever a pair
                # failed to update.
                print(
                    f"Processing CC pair {cc_pair.id} ({idx}/{len(cc_pairs)})",
                    file=sys.stderr,
                )

                cancel_time: float = 0.0
                update_time: float = 0.0
                try:
                    # Cancel any scheduled indexing attempts. Kept inside the
                    # per-pair try so one bad pair is recorded as an error
                    # instead of aborting the entire batch.
                    cancel_start: float = time.time()
                    cancel_indexing_attempts_for_ccpair(
                        cc_pair_id=cc_pair.id,
                        db_session=db_session,
                        include_secondary_index=True,
                    )
                    cancel_time = time.time() - cancel_start

                    # Mark as deleting
                    update_start: float = time.time()
                    update_connector_credential_pair_from_id(
                        db_session=db_session,
                        cc_pair_id=cc_pair.id,
                        status=ConnectorCredentialPairStatus.DELETING,
                    )
                    update_time = time.time() - update_start
                    deleted_count += 1
                except Exception as e:
                    errors.append(f"CC pair {cc_pair.id}: {str(e)}")
                    print(
                        f"Error updating CC pair {cc_pair.id}: {e}",
                        file=sys.stderr,
                    )

                per_connector_timing.append(
                    {
                        "cc_pair_id": cc_pair.id,
                        "cancel_indexing_seconds": cancel_time,
                        "update_status_seconds": update_time,
                        "total_seconds": time.time() - connector_start,
                    }
                )

            # Commit all status changes in a single transaction
            commit_start: float = time.time()
            db_session.commit()
            commit_time: float = time.time() - commit_start

        # Trigger the deletion check task (session no longer needed)
        print(
            "Triggering connector deletion check task",
            file=sys.stderr,
        )
        task_start: float = time.time()
        client_app.send_task(
            OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
            priority=OnyxCeleryPriority.HIGH,
            kwargs={"tenant_id": tenant_id},
        )
        task_time: float = time.time() - task_start

        result: dict[str, Any] = {
            "status": "success",
            "message": f"Marked {deleted_count} connector credential pairs for deletion",
            "deleted_count": deleted_count,
            "timing": {
                "fetch_all_seconds": fetch_all_time,
                "commit_seconds": commit_time,
                "send_task_seconds": task_time,
                "total_seconds": time.time() - overall_start,
                "per_connector": per_connector_timing,
            },
        }
        if errors:
            result["errors"] = errors
        return result

    except Exception as e:
        print(
            f"Error marking all connectors for deletion: {e}",
            file=sys.stderr,
        )
        return {
            "status": "error",
            "message": str(e),
            "timing": {
                "total_seconds": time.time() - overall_start,
                "per_connector": per_connector_timing,
            },
        }
def main() -> None:
    """CLI entry point.

    Usage:
        python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>
        python mark_connector_for_deletion.py <tenant_id> --all

    Emits a JSON result to stdout; exits 1 on any usage error.
    """
    usage_error: str = json.dumps(
        {
            "status": "error",
            "message": "Usage: python mark_connector_for_deletion.py <tenant_id> [<cc_pair_id>|--all]",
        }
    )

    # Validate ALL arguments before touching the database engine — the
    # previous version initialized the SQL engine even when only a
    # tenant_id was supplied, then errored out afterwards.
    if len(sys.argv) != 3:
        print(usage_error)
        sys.exit(1)

    tenant_id: str = sys.argv[1]
    second_arg: str = sys.argv[2]

    # None means "--all"; otherwise a specific CC pair id.
    cc_pair_id: int | None = None
    if second_arg != "--all":
        try:
            cc_pair_id = int(second_arg)
        except ValueError:
            print(
                json.dumps(
                    {
                        "status": "error",
                        "message": "cc_pair_id must be an integer or use --all",
                    }
                )
            )
            sys.exit(1)

    SqlEngine.init_engine(pool_size=5, max_overflow=2)

    result: dict[str, Any]
    if cc_pair_id is None:
        result = mark_all_connectors_for_deletion(tenant_id)
    else:
        result = mark_connector_for_deletion(tenant_id, cc_pair_id)

    print(json.dumps(result, indent=2))


if __name__ == "__main__":
    main()

View File

@@ -1,144 +0,0 @@
#!/usr/bin/env python3
"""
Script to mark a connector credential pair for deletion.
Runs on a Kubernetes pod with access to the data plane database.
Usage:
python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>
Output:
JSON to stdout with structure:
{
"status": "success" | "error",
"message": str
}
"""
import json
import sys
from onyx.background.celery.versioned_apps.client import app as client_app
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryTask
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.db.engine.sql_engine import SqlEngine
from onyx.db.enums import ConnectorCredentialPairStatus
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
def mark_connector_for_deletion(tenant_id: str, cc_pair_id: int) -> dict:
    """Mark a connector credential pair for deletion.

    Cancels any scheduled indexing attempts for the pair, sets its status
    to DELETING, commits, then enqueues the Celery deletion-check task so
    the background worker performs the actual deletion.

    Args:
        tenant_id: The tenant ID
        cc_pair_id: The connector credential pair ID

    Returns:
        Dict with status ("success" | "error") and message
    """
    try:
        # Progress/log output goes to stderr; stdout is reserved for the
        # JSON result consumed by the calling script.
        print(
            f"Marking connector credential pair {cc_pair_id} for deletion",
            file=sys.stderr,
        )

        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # Get the connector credential pair
            cc_pair = get_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair_id,
            )

            if not cc_pair:
                return {
                    "status": "error",
                    "message": f"Connector credential pair {cc_pair_id} not found",
                }

            # Cancel any scheduled indexing attempts
            print(
                f"Canceling indexing attempts for CC pair {cc_pair_id}",
                file=sys.stderr,
            )
            cancel_indexing_attempts_for_ccpair(
                cc_pair_id=cc_pair.id,
                db_session=db_session,
                include_secondary_index=True,
            )

            # Mark as deleting
            print(
                f"Updating CC pair {cc_pair_id} status to DELETING",
                file=sys.stderr,
            )
            update_connector_credential_pair_from_id(
                db_session=db_session,
                cc_pair_id=cc_pair.id,
                status=ConnectorCredentialPairStatus.DELETING,
            )

            db_session.commit()

            # Trigger the deletion check task
            # NOTE(review): dispatched after the commit above so the worker
            # observes the DELETING status — confirm the intended placement
            # relative to the session block (indentation lost in this view).
            print(
                "Triggering connector deletion check task",
                file=sys.stderr,
            )
            client_app.send_task(
                OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
                priority=OnyxCeleryPriority.HIGH,
                kwargs={"tenant_id": tenant_id},
            )

            return {
                "status": "success",
                "message": f"Marked connector credential pair {cc_pair_id} for deletion",
            }

    except Exception as e:
        print(
            f"Error marking connector for deletion: {e}",
            file=sys.stderr,
        )
        return {
            "status": "error",
            "message": str(e),
        }
def main() -> None:
    """Parse CLI arguments, mark the given CC pair for deletion, and print
    the JSON result to stdout (exit code 1 on any usage error)."""
    cli_args = sys.argv[1:]

    if len(cli_args) != 2:
        error_payload = {
            "status": "error",
            "message": "Usage: python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>",
        }
        print(json.dumps(error_payload))
        sys.exit(1)

    tenant_id = cli_args[0]

    try:
        cc_pair_id = int(cli_args[1])
    except ValueError:
        error_payload = {
            "status": "error",
            "message": "cc_pair_id must be an integer",
        }
        print(json.dumps(error_payload))
        sys.exit(1)

    # Engine must be initialized before any DB session is opened.
    SqlEngine.init_engine(pool_size=5, max_overflow=2)

    outcome = mark_connector_for_deletion(tenant_id, cc_pair_id)
    print(json.dumps(outcome))


if __name__ == "__main__":
    main()

View File

@@ -1,9 +1,11 @@
import csv
import os
import time
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import BytesIO
from io import StringIO
from zipfile import ZipFile
import pytest
@@ -257,6 +259,48 @@ class TestUsageExportAPI:
assert "chat_messages.csv" in file_names
assert "users.csv" in file_names
# Verify chat_messages.csv has the expected columns
with zip_file.open("chat_messages.csv") as csv_file:
csv_content = csv_file.read().decode("utf-8")
csv_reader = csv.DictReader(StringIO(csv_content))
# Check that all expected columns are present
expected_columns = {
"session_id",
"user_id",
"flow_type",
"time_sent",
"assistant_name",
"user_email",
"number_of_tokens",
}
actual_columns = set(csv_reader.fieldnames or [])
assert (
expected_columns == actual_columns
), f"Expected columns {expected_columns}, but got {actual_columns}"
# Verify there's at least one row of data
rows = list(csv_reader)
assert len(rows) > 0, "Expected at least one message in the report"
# Verify the first row has non-empty values for all columns
first_row = rows[0]
for column in expected_columns:
assert column in first_row, f"Column {column} not found in row"
assert first_row[
column
], f"Column {column} has empty value in first row"
# Verify specific new fields have appropriate values
assert first_row["assistant_name"], "assistant_name should not be empty"
assert first_row["user_email"], "user_email should not be empty"
assert first_row[
"number_of_tokens"
].isdigit(), "number_of_tokens should be a numeric value"
assert (
int(first_row["number_of_tokens"]) >= 0
), "number_of_tokens should be non-negative"
def test_read_nonexistent_report(self, reset: None, admin_user: DATestUser) -> None:
# Try to download a report that doesn't exist
response = requests.get(

1
web/.gitignore vendored
View File

@@ -43,3 +43,4 @@ next-env.d.ts
# generated clients ... in particular, the API to the Onyx backend itself!
/src/lib/generated
.jest-cache

View File

@@ -1,11 +1,100 @@
module.exports = {
/**
* Jest configuration with separate projects for different test environments.
*
* We use two separate projects:
* 1. "unit" - Node environment for pure unit tests (no DOM needed)
* 2. "integration" - jsdom environment for React integration tests
*
* This allows us to run tests with the correct environment automatically
* without needing @jest-environment comments in every test file.
*/
// Shared configuration
const sharedConfig = {
preset: "ts-jest",
testEnvironment: "node",
setupFilesAfterEnv: ["<rootDir>/tests/setup/jest.setup.ts"],
// Performance: Use 50% of CPU cores for parallel execution
maxWorkers: "50%",
moduleNameMapper: {
// Mock react-markdown and related packages
"^react-markdown$": "<rootDir>/tests/setup/__mocks__/react-markdown.tsx",
"^remark-gfm$": "<rootDir>/tests/setup/__mocks__/remark-gfm.ts",
// Mock UserProvider
"^@/components/user/UserProvider$":
"<rootDir>/tests/setup/__mocks__/@/components/user/UserProvider.tsx",
// Path aliases (must come after specific mocks)
"^@/(.*)$": "<rootDir>/src/$1",
"^@tests/(.*)$": "<rootDir>/tests/$1",
// Mock CSS imports
"\\.(css|less|scss|sass)$": "identity-obj-proxy",
// Mock static file imports
"\\.(jpg|jpeg|png|gif|svg|woff|woff2|ttf|eot)$":
"<rootDir>/tests/setup/fileMock.js",
},
testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/"],
testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/", "/.next/"],
transformIgnorePatterns: [
"/node_modules/(?!(jose|@radix-ui|@headlessui|@phosphor-icons|msw|until-async|react-markdown|remark-gfm|remark-parse|unified|bail|is-plain-obj|trough|vfile|unist-.*|mdast-.*|micromark.*|decode-named-character-reference|character-entities)/)",
],
transform: {
"^.+\\.tsx?$": "ts-jest",
"^.+\\.tsx?$": [
"ts-jest",
{
// Performance: Disable type-checking in tests (types are checked by tsc)
isolatedModules: true,
tsconfig: {
jsx: "react-jsx",
},
},
],
},
// Performance: Cache results between runs
cache: true,
cacheDirectory: "<rootDir>/.jest-cache",
collectCoverageFrom: [
"src/**/*.{ts,tsx}",
"!src/**/*.d.ts",
"!src/**/*.stories.tsx",
],
coveragePathIgnorePatterns: ["/node_modules/", "/tests/", "/.next/"],
// Performance: Clear mocks automatically between tests
clearMocks: true,
resetMocks: false,
restoreMocks: false,
};
module.exports = {
projects: [
{
displayName: "unit",
...sharedConfig,
testEnvironment: "node",
testMatch: [
// Pure unit tests that don't need DOM
"**/src/**/codeUtils.test.ts",
"**/src/lib/**/*.test.ts",
// Add more patterns here as you add more unit tests
],
},
{
displayName: "integration",
...sharedConfig,
testEnvironment: "jsdom",
testMatch: [
// React component integration tests
"**/src/app/**/*.test.tsx",
"**/src/components/**/*.test.tsx",
"**/src/lib/**/*.test.tsx",
// Add more patterns here as you add more integration tests
],
},
],
};

938
web/package-lock.json generated

File diff suppressed because it is too large Load Diff

View File

@@ -11,7 +11,13 @@
"lint:unused": "eslint --ext .js,.jsx,.ts,.tsx --rule 'unused-imports/no-unused-imports: error' --quiet --fix=false src/",
"lint:fix-unused": "eslint --ext .js,.jsx,.ts,.tsx --rule 'unused-imports/no-unused-imports: error' --quiet --fix src/",
"lint:fix-unused-vars": "eslint --ext .js,.jsx,.ts,.tsx --fix --quiet src/",
"test": "jest"
"test": "jest",
"test:watch": "jest --watch",
"test:coverage": "jest --coverage",
"test:verbose": "jest --verbose",
"test:ci": "jest --ci --maxWorkers=2 --silent --bail",
"test:changed": "jest --onlyChanged",
"test:debug": "node --inspect-brk node_modules/.bin/jest --runInBand"
},
"dependencies": {
"@dnd-kit/core": "^6.1.0",
@@ -96,6 +102,9 @@
"@chromatic-com/playwright": "^0.10.2",
"@playwright/test": "^1.39.0",
"@tailwindcss/typography": "^0.5.10",
"@testing-library/jest-dom": "^6.9.1",
"@testing-library/react": "^14.3.1",
"@testing-library/user-event": "^14.6.1",
"@types/chrome": "^0.0.287",
"@types/jest": "^29.5.14",
"@types/js-cookie": "^3.0.6",
@@ -109,10 +118,13 @@
"eslint": "^8.57.1",
"eslint-config-next": "^14.1.0",
"eslint-plugin-unused-imports": "^4.1.4",
"identity-obj-proxy": "^3.0.0",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0",
"prettier": "3.1.0",
"ts-jest": "^29.2.5",
"ts-unused-exports": "^11.0.1"
"ts-unused-exports": "^11.0.1",
"whatwg-fetch": "^3.6.20"
},
"overrides": {
"react-is": "^19.0.0-rc-69d4b800-20241021"

View File

@@ -23,6 +23,8 @@ export default defineConfig({
// },
// ],
],
// Only run Playwright tests from tests/e2e directory (ignore Jest tests in src/)
testMatch: /.*\/tests\/e2e\/.*\.spec\.ts/,
projects: [
{
name: "admin",
@@ -31,7 +33,6 @@ export default defineConfig({
viewport: { width: 1280, height: 720 },
storageState: "admin_auth.json",
},
testIgnore: ["**/codeUtils.test.ts"],
},
{
name: "no-auth",
@@ -39,7 +40,6 @@ export default defineConfig({
...devices["Desktop Chrome"],
viewport: { width: 1280, height: 720 },
},
testIgnore: ["**/codeUtils.test.ts"],
},
],
});

View File

@@ -8,7 +8,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
import Text from "@/refresh-components/texts/Text";
import useSWR, { mutate } from "swr";
import { ErrorCallout } from "@/components/ErrorCallout";
import { OnyxSparkleIcon } from "@/components/icons/icons";
import OnyxLogo from "@/icons/onyx-logo";
import { usePopup } from "@/components/admin/connectors/Popup";
import { useAgentsContext } from "@/refresh-components/contexts/AgentsContext";
import { Switch } from "@/components/ui/switch";
@@ -312,7 +312,9 @@ export default function Page() {
<div className="mx-auto max-w-4xl w-full">
<AdminPageTitle
title="Default Assistant"
icon={<OnyxSparkleIcon size={32} className="my-auto" />}
icon={
<OnyxLogo className="my-auto w-[1.5rem] h-[1.5rem] stroke-text-04" />
}
/>
<DefaultAssistantConfig />
</div>

View File

@@ -9,7 +9,7 @@ import { mutate } from "swr";
import { Badge } from "@/components/ui/badge";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/refresh-components/texts/Text";
import { isSubset } from "@/lib/utils";
import { cn, isSubset } from "@/lib/utils";
function LLMProviderUpdateModal({
llmProviderDescriptor,
@@ -70,6 +70,29 @@ function LLMProviderDisplay({
const [formIsVisible, setFormIsVisible] = useState(false);
const { popup, setPopup } = usePopup();
// Promotes this provider to the org-wide default via the admin API, then
// revalidates the cached provider list and reports the outcome via popup.
async function handleSetAsDefault(): Promise<void> {
  const endpoint = `${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`;
  const res = await fetch(endpoint, { method: "POST" });
  if (res.ok) {
    // Invalidate the SWR cache so the provider list re-renders with the new default.
    await mutate(LLM_PROVIDERS_ADMIN_URL);
    setPopup({
      type: "success",
      message: "Provider set as default successfully!",
    });
    return;
  }
  const payload = await res.json();
  setPopup({
    type: "error",
    message: `Failed to set provider as default: ${payload.detail}`,
  });
}
const providerName =
existingLlmProvider?.name ||
llmProviderDescriptor?.display_name ||
@@ -87,29 +110,8 @@ function LLMProviderDisplay({
</Text>
{!existingLlmProvider.is_default_provider && (
<Text
className="text-action-link-05"
onClick={async () => {
const response = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
{
method: "POST",
}
);
if (!response.ok) {
const errorMsg = (await response.json()).detail;
setPopup({
type: "error",
message: `Failed to set provider as default: ${errorMsg}`,
});
return;
}
mutate(LLM_PROVIDERS_ADMIN_URL);
setPopup({
type: "success",
message: "Provider set as default successfully!",
});
}}
className={cn("text-action-link-05", "cursor-pointer")}
onClick={handleSetAsDefault}
>
Set as default
</Text>

View File

@@ -0,0 +1,464 @@
/**
* Integration Test: Custom LLM Provider Configuration Workflow
*
* Tests the complete user journey for configuring a custom LLM provider.
* This tests the full workflow: form fill → test config → save → set as default
*/
import React from "react";
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
// Mock SWR's mutate function so tests can assert cache invalidation.
// NOTE: jest.mock calls are hoisted above imports; variables referenced inside
// the factory must be prefixed with "mock" for Jest to allow the reference.
const mockMutate = jest.fn();
jest.mock("swr", () => ({
  ...jest.requireActual("swr"),
  useSWRConfig: () => ({ mutate: mockMutate }),
}));
// Mock usePaidEnterpriseFeaturesEnabled — force the non-EE code path so the
// form does not render enterprise-only fields (groups / is_public).
jest.mock("@/components/settings/usePaidEnterpriseFeaturesEnabled", () => ({
  usePaidEnterpriseFeaturesEnabled: () => false,
}));
describe("Custom LLM Provider Configuration Workflow", () => {
  // Spy on global fetch; each test queues responses with mockResolvedValueOnce,
  // so the ORDER of queued mocks must match the form's request order
  // (test endpoint first, then create/update, then optional set-default).
  let fetchSpy: jest.SpyInstance;
  const mockOnClose = jest.fn();
  const mockSetPopup = jest.fn();
  beforeEach(() => {
    jest.clearAllMocks();
    fetchSpy = jest.spyOn(global, "fetch");
  });
  afterEach(() => {
    fetchSpy.mockRestore();
  });
  // Happy path: fill form -> test config succeeds -> create succeeds ->
  // success popup, onClose, and SWR cache invalidation.
  test("creates a new custom LLM provider successfully", async () => {
    const user = setupUser();
    // Mock POST /api/admin/llm/test
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock PUT /api/admin/llm/provider?is_creation=true
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({
        id: 1,
        name: "My Custom Provider",
        provider: "openai",
        api_key: "test-key",
        default_model_name: "gpt-4",
      }),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        setPopup={mockSetPopup}
      />
    );
    // Fill in the form
    const nameInput = screen.getByPlaceholderText(/display name/i);
    const providerInput = screen.getByPlaceholderText(
      /name of the custom provider/i
    );
    const apiKeyInput = screen.getByPlaceholderText(/api key/i);
    await user.type(nameInput, "My Custom Provider");
    await user.type(providerInput, "openai");
    await user.type(apiKeyInput, "test-key-123");
    // Fill in model configuration (use placeholder to find input)
    const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
    await user.type(modelNameInput, "gpt-4");
    // Set default model (there are 2 inputs with this placeholder - default and fast)
    // We want the first one (Default Model)
    const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
    await user.type(defaultModelInputs[0], "gpt-4");
    // Submit the form
    const submitButton = screen.getByRole("button", { name: /enable/i });
    await user.click(submitButton);
    // Verify test API was called first
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/test",
        expect.objectContaining({
          method: "POST",
          headers: { "Content-Type": "application/json" },
        })
      );
    });
    // Verify create API was called
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/provider?is_creation=true",
        expect.objectContaining({
          method: "PUT",
          headers: { "Content-Type": "application/json" },
        })
      );
    });
    // Verify success popup
    await waitFor(() => {
      expect(mockSetPopup).toHaveBeenCalledWith({
        type: "success",
        message: "Provider enabled successfully!",
      });
    });
    // Verify onClose was called
    expect(mockOnClose).toHaveBeenCalled();
    // Verify SWR cache was invalidated
    expect(mockMutate).toHaveBeenCalledWith("/api/admin/llm/provider");
  });
  // The form must short-circuit when the connectivity test fails:
  // show the error inline and never issue the create request.
  test("shows error when test configuration fails", async () => {
    const user = setupUser();
    // Mock POST /api/admin/llm/test (failure)
    fetchSpy.mockResolvedValueOnce({
      ok: false,
      status: 400,
      json: async () => ({ detail: "Invalid API key" }),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        setPopup={mockSetPopup}
      />
    );
    // Fill in the form with invalid credentials
    const nameInput = screen.getByPlaceholderText(/display name/i);
    const providerInput = screen.getByPlaceholderText(
      /name of the custom provider/i
    );
    const apiKeyInput = screen.getByPlaceholderText(/api key/i);
    await user.type(nameInput, "Bad Provider");
    await user.type(providerInput, "openai");
    await user.type(apiKeyInput, "invalid-key");
    // Fill in model configuration
    const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
    await user.type(modelNameInput, "gpt-4");
    // Set default model (there are 2 inputs with this placeholder - default and fast)
    const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
    await user.type(defaultModelInputs[0], "gpt-4");
    // Submit the form
    const submitButton = screen.getByRole("button", { name: /enable/i });
    await user.click(submitButton);
    // Verify test API was called
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/test",
        expect.objectContaining({
          method: "POST",
        })
      );
    });
    // Verify error is displayed (form should NOT proceed to create)
    await waitFor(() => {
      expect(screen.getByText(/invalid api key/i)).toBeInTheDocument();
    });
    // Verify create API was NOT called
    expect(
      fetchSpy.mock.calls.find((call) =>
        call[0].includes("/api/admin/llm/provider")
      )
    ).toBeUndefined();
  });
  // Edit path: passing existingLlmProvider switches the form to update mode
  // (PUT without ?is_creation and an "updated" success message).
  test("updates an existing LLM provider", async () => {
    const user = setupUser();
    const existingProvider = {
      id: 1,
      name: "Existing Provider",
      provider: "anthropic",
      api_key: "old-key",
      api_base: "",
      api_version: "",
      default_model_name: "claude-3-opus",
      fast_default_model_name: null,
      model_configurations: [
        { name: "claude-3-opus", is_visible: true, max_input_tokens: null },
      ],
      custom_config: {},
      is_public: true,
      groups: [],
      deployment_name: null,
    };
    // Mock POST /api/admin/llm/test
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock PUT /api/admin/llm/provider (update, no is_creation param)
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({ ...existingProvider, api_key: "new-key" }),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        existingLlmProvider={existingProvider}
        setPopup={mockSetPopup}
      />
    );
    // Update the API key
    const apiKeyInput = screen.getByPlaceholderText(/api key/i);
    await user.clear(apiKeyInput);
    await user.type(apiKeyInput, "new-key-456");
    // Submit
    const submitButton = screen.getByRole("button", { name: /update/i });
    await user.click(submitButton);
    // Verify test was called
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/test",
        expect.any(Object)
      );
    });
    // Verify update API was called (without is_creation param)
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/provider",
        expect.objectContaining({
          method: "PUT",
        })
      );
    });
    // Verify success message says "updated"
    await waitFor(() => {
      expect(mockSetPopup).toHaveBeenCalledWith({
        type: "success",
        message: "Provider updated successfully!",
      });
    });
  });
  // With shouldMarkAsDefault, a successful create is followed by a third
  // request: POST /provider/{id}/default using the id from the create response.
  test("sets provider as default when shouldMarkAsDefault is true", async () => {
    const user = setupUser();
    // Mock POST /api/admin/llm/test
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock PUT /api/admin/llm/provider?is_creation=true
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({
        id: 5,
        name: "New Default Provider",
        provider: "openai",
      }),
    } as Response);
    // Mock POST /api/admin/llm/provider/5/default
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        setPopup={mockSetPopup}
        shouldMarkAsDefault={true}
      />
    );
    // Fill form
    const nameInput = screen.getByPlaceholderText(/display name/i);
    await user.type(nameInput, "New Default Provider");
    const providerInput = screen.getByPlaceholderText(
      /name of the custom provider/i
    );
    await user.type(providerInput, "openai");
    // Fill in model configuration
    const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
    await user.type(modelNameInput, "gpt-4");
    // Set default model (there are 2 inputs with this placeholder - default and fast)
    const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
    await user.type(defaultModelInputs[0], "gpt-4");
    // Submit
    const submitButton = screen.getByRole("button", { name: /enable/i });
    await user.click(submitButton);
    // Verify set as default API was called
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/admin/llm/provider/5/default",
        expect.objectContaining({
          method: "POST",
        })
      );
    });
  });
  // Create-failure path: error popup with the server's detail message,
  // and the modal stays open (onClose not called).
  test("shows error when provider creation fails", async () => {
    const user = setupUser();
    // Mock POST /api/admin/llm/test
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock PUT /api/admin/llm/provider?is_creation=true (failure)
    fetchSpy.mockResolvedValueOnce({
      ok: false,
      status: 500,
      json: async () => ({ detail: "Database error" }),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        setPopup={mockSetPopup}
      />
    );
    // Fill form
    const nameInput = screen.getByPlaceholderText(/display name/i);
    await user.type(nameInput, "Test Provider");
    const providerInput = screen.getByPlaceholderText(
      /name of the custom provider/i
    );
    await user.type(providerInput, "openai");
    // Fill in model configuration
    const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
    await user.type(modelNameInput, "gpt-4");
    // Set default model (there are 2 inputs with this placeholder - default and fast)
    const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
    await user.type(defaultModelInputs[0], "gpt-4");
    // Submit
    const submitButton = screen.getByRole("button", { name: /enable/i });
    await user.click(submitButton);
    // Verify error popup
    await waitFor(() => {
      expect(mockSetPopup).toHaveBeenCalledWith({
        type: "error",
        message: "Failed to enable provider: Database error",
      });
    });
    // Verify onClose was NOT called
    expect(mockOnClose).not.toHaveBeenCalled();
  });
  // Custom config rows are keyed by Formik array field names
  // (custom_config_list[i][0] = key, [i][1] = value) and must be serialized
  // into the request's custom_config object.
  test("adds custom configuration key-value pairs", async () => {
    const user = setupUser();
    // Mock POST /api/admin/llm/test
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock PUT /api/admin/llm/provider?is_creation=true
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({ id: 1, name: "Provider with Custom Config" }),
    } as Response);
    render(
      <CustomLLMProviderUpdateForm
        onClose={mockOnClose}
        setPopup={mockSetPopup}
      />
    );
    // Fill basic fields
    const nameInput = screen.getByPlaceholderText(/display name/i);
    await user.type(nameInput, "Cloudflare Provider");
    const providerInput = screen.getByPlaceholderText(
      /name of the custom provider/i
    );
    await user.type(providerInput, "cloudflare");
    // Click "Add New" button for custom config (there are 2 "Add New" buttons - one for custom config, one for models)
    // The custom config "Add New" appears first
    const addNewButtons = screen.getAllByRole("button", { name: /add new/i });
    const customConfigAddButton = addNewButtons[0]; // First "Add New" is for custom config
    await user.click(customConfigAddButton);
    // Fill in custom config key-value pair
    const customConfigInputs = screen.getAllByRole("textbox");
    const keyInput = customConfigInputs.find(
      (input) => input.getAttribute("name") === "custom_config_list[0][0]"
    );
    const valueInput = customConfigInputs.find(
      (input) => input.getAttribute("name") === "custom_config_list[0][1]"
    );
    expect(keyInput).toBeDefined();
    expect(valueInput).toBeDefined();
    await user.type(keyInput!, "CLOUDFLARE_ACCOUNT_ID");
    await user.type(valueInput!, "my-account-id-123");
    // Fill in model configuration
    const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
    await user.type(modelNameInput, "@cf/meta/llama-2-7b-chat-int8");
    // Set default model (there are 2 inputs with this placeholder - default and fast)
    const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
    await user.type(defaultModelInputs[0], "@cf/meta/llama-2-7b-chat-int8");
    // Submit
    const submitButton = screen.getByRole("button", { name: /enable/i });
    await user.click(submitButton);
    // Verify the custom config was included in the request
    await waitFor(() => {
      const createCall = fetchSpy.mock.calls.find((call) =>
        call[0].includes("/api/admin/llm/provider")
      );
      expect(createCall).toBeDefined();
      const requestBody = JSON.parse(createCall![1].body);
      expect(requestBody.custom_config).toEqual({
        CLOUDFLARE_ACCOUNT_ID: "my-account-id-123",
      });
    });
  });
});

View File

@@ -88,7 +88,13 @@ export function CustomLLMProviderUpdateForm({
Yup.object({
name: Yup.string().required("Model name is required"),
is_visible: Yup.boolean().required("Visibility is required"),
max_input_tokens: Yup.number().nullable().optional(),
// Coerce empty string from input field into null so it's optional
max_input_tokens: Yup.number()
.transform((value, originalValue) =>
originalValue === "" || originalValue === undefined ? null : value
)
.nullable()
.optional(),
})
),
default_model_name: Yup.string().required("Model name is required"),

View File

@@ -106,20 +106,25 @@ function AddCustomLLMProvider({
existingLlmProviders: LLMProviderView[];
}) {
const [formIsVisible, setFormIsVisible] = useState(false);
const { popup, setPopup } = usePopup();
if (formIsVisible) {
return (
<Modal
title={`Setup Custom LLM Provider`}
onOutsideClick={() => setFormIsVisible(false)}
>
<div className="max-h-[70vh] overflow-y-auto px-4">
<CustomLLMProviderUpdateForm
onClose={() => setFormIsVisible(false)}
shouldMarkAsDefault={existingLlmProviders.length === 0}
/>
</div>
</Modal>
<>
{popup}
<Modal
title={`Setup Custom LLM Provider`}
onOutsideClick={() => setFormIsVisible(false)}
>
<div className="max-h-[70vh] overflow-y-auto px-4">
<CustomLLMProviderUpdateForm
onClose={() => setFormIsVisible(false)}
shouldMarkAsDefault={existingLlmProviders.length === 0}
setPopup={setPopup}
/>
</div>
</Modal>
</>
);
}

View File

@@ -132,12 +132,14 @@ export function LLMProviderUpdateForm({
llmProviderDescriptor.default_api_base ??
"",
api_version: existingLlmProvider?.api_version ?? "",
// For Azure OpenAI, combine api_base and api_version into target_uri
// For Azure OpenAI, combine api_base, deployment_name, and api_version into target_uri
target_uri:
llmProviderDescriptor.name === "azure" &&
existingLlmProvider?.api_base &&
existingLlmProvider?.api_version
? `${existingLlmProvider.api_base}/openai/deployments/your-deployment?api-version=${existingLlmProvider.api_version}`
? `${existingLlmProvider.api_base}/openai/deployments/${
existingLlmProvider.deployment_name || "your-deployment"
}/chat/completions?api-version=${existingLlmProvider.api_version}`
: "",
default_model_name:
existingLlmProvider?.default_model_name ??
@@ -201,20 +203,22 @@ export function LLMProviderUpdateForm({
.required("Target URI is required")
.test(
"valid-target-uri",
"Target URI must be a valid URL with exactly one query parameter (api-version)",
"Target URI must be a valid URL with api-version query parameter and the deployment name in the path",
(value) => {
if (!value) return false;
try {
const url = new URL(value);
const params = new URLSearchParams(url.search);
const paramKeys = Array.from(params.keys());
const hasApiVersion = !!url.searchParams
.get("api-version")
?.trim();
// Check if there's exactly one parameter and it's api-version
return (
paramKeys.length === 1 &&
paramKeys[0] === "api-version" &&
!!params.get("api-version")
// Check if the path contains a deployment name
const pathMatch = url.pathname.match(
/\/openai\/deployments\/([^\/]+)/
);
const hasDeploymentName = pathMatch && pathMatch[1];
return hasApiVersion && !!hasDeploymentName;
} catch {
return false;
}
@@ -240,9 +244,11 @@ export function LLMProviderUpdateForm({
),
}
: {}),
deployment_name: llmProviderDescriptor.deployment_name_required
? Yup.string().required("Deployment Name is required")
: Yup.string().nullable(),
deployment_name:
llmProviderDescriptor.deployment_name_required &&
llmProviderDescriptor.name !== "azure"
? Yup.string().required("Deployment Name is required")
: Yup.string().nullable(),
default_model_name: Yup.string().required("Model name is required"),
fast_default_model_name: Yup.string().nullable(),
// EE Only
@@ -276,15 +282,24 @@ export function LLMProviderUpdateForm({
...rest
} = values;
// For Azure OpenAI, parse target_uri to extract api_base and api_version
// For Azure OpenAI, parse target_uri to extract api_base, api_version, and deployment_name
let finalApiBase = rest.api_base;
let finalApiVersion = rest.api_version;
let finalDeploymentName = rest.deployment_name;
if (llmProviderDescriptor.name === "azure" && target_uri) {
try {
const url = new URL(target_uri);
finalApiBase = url.origin; // Only use origin (protocol + hostname + port)
finalApiVersion = url.searchParams.get("api-version") || "";
// Extract deployment name from path: /openai/deployments/{deployment-name}/...
const pathMatch = url.pathname.match(
/\/openai\/deployments\/([^\/]+)/
);
if (pathMatch && pathMatch[1]) {
finalDeploymentName = pathMatch[1];
}
} catch (error) {
// This should not happen due to validation, but handle gracefully
console.error("Failed to parse target_uri:", error);
@@ -296,6 +311,7 @@ export function LLMProviderUpdateForm({
...rest,
api_base: finalApiBase,
api_version: finalApiVersion,
deployment_name: finalDeploymentName,
api_key_changed: values.api_key !== initialValues.api_key,
model_configurations: getCurrentModelConfigurations(values).map(
(modelConfiguration): ModelConfiguration => ({
@@ -451,7 +467,7 @@ export function LLMProviderUpdateForm({
name="target_uri"
label="Target URI"
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
subtext="The complete Azure OpenAI endpoint URL including the API version as a query parameter"
subtext="The complete target URI for your deployment from the Azure AI portal."
/>
) : (
<>
@@ -612,13 +628,14 @@ export function LLMProviderUpdateForm({
/>
)}
{llmProviderDescriptor.deployment_name_required && (
<TextFormField
name="deployment_name"
label="Deployment Name"
placeholder="Deployment Name"
/>
)}
{llmProviderDescriptor.deployment_name_required &&
llmProviderDescriptor.name !== "azure" && (
<TextFormField
name="deployment_name"
label="Deployment Name"
placeholder="Deployment Name"
/>
)}
{!llmProviderDescriptor.single_model_supported &&
(currentModelConfigurations.length > 0 ? (
@@ -726,22 +743,24 @@ export function LLMProviderUpdateForm({
}
// If the deleted provider was the default, set the first remaining provider as default
const remainingProvidersResponse = await fetch(
LLM_PROVIDERS_ADMIN_URL
);
if (remainingProvidersResponse.ok) {
const remainingProviders =
await remainingProvidersResponse.json();
if (existingLlmProvider.is_default_provider) {
const remainingProvidersResponse = await fetch(
LLM_PROVIDERS_ADMIN_URL
);
if (remainingProvidersResponse.ok) {
const remainingProviders =
await remainingProvidersResponse.json();
if (remainingProviders.length > 0) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
{
method: "POST",
if (remainingProviders.length > 0) {
const setDefaultResponse = await fetch(
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
{
method: "POST",
}
);
if (!setDefaultResponse.ok) {
console.error("Failed to set new default provider");
}
);
if (!setDefaultResponse.ok) {
console.error("Failed to set new default provider");
}
}
}

View File

@@ -150,7 +150,8 @@ export function ModelConfigurationField({
arrayHelpers.push({
name: "",
is_visible: true,
max_input_tokens: "",
// Use null so Yup.number().nullable() accepts empty inputs
max_input_tokens: null,
});
}}
className="mt-3"

View File

@@ -1,6 +1,7 @@
import {
AnthropicIcon,
AmazonIcon,
AzureIcon,
CPUIcon,
MicrosoftIconSVG,
MistralIcon,
@@ -38,6 +39,8 @@ export const getProviderIcon = (
claude: AnthropicIcon,
anthropic: AnthropicIcon,
openai: OpenAISVG,
// Azure OpenAI should display the Azure logo
azure: AzureIcon,
microsoft: MicrosoftIconSVG,
meta: MetaIcon,
google: GeminiIcon,
@@ -126,6 +129,31 @@ export const dynamicProviderConfigs: Record<
successMessage: (count: number) =>
`Successfully fetched ${count} models from Ollama.`,
},
openrouter: {
endpoint: "/api/admin/llm/openrouter/available-models",
isDisabled: (values) => !values.api_base || !values.api_key,
disabledReason:
"API Base and API Key are required to fetch OpenRouter models",
buildRequestBody: ({ values }) => ({
api_base: values.api_base,
api_key: values.api_key,
}),
processResponse: (data: OllamaModelResponse[], llmProviderDescriptor) =>
data.map((modelData) => {
const existingConfig = llmProviderDescriptor.model_configurations.find(
(config) => config.name === modelData.name
);
return {
name: modelData.name,
is_visible: existingConfig?.is_visible ?? true,
max_input_tokens: modelData.max_input_tokens,
supports_image_input: modelData.supports_image_input,
};
}),
getModelNames: (data: OllamaModelResponse[]) => data.map((m) => m.name),
successMessage: (count: number) =>
`Successfully fetched ${count} models from OpenRouter.`,
},
};
export const fetchModels = async (

View File

@@ -19,14 +19,8 @@ import { localizeAndPrettify } from "@/lib/time";
import { getDocsProcessedPerMinute } from "@/lib/indexAttempt";
import { InfoIcon } from "@/components/icons/icons";
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import { FaBarsProgress } from "react-icons/fa6";
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
import SvgClock from "@/icons/clock";
export interface IndexingAttemptsTableProps {
ccPair: CCPairFullInfo;
@@ -96,6 +90,14 @@ export function IndexAttemptsTable({
{indexAttempts.map((indexAttempt) => {
const docsPerMinute =
getDocsProcessedPerMinute(indexAttempt)?.toFixed(2);
const isReindexInProgress =
indexAttempt.status === "in_progress" ||
indexAttempt.status === "not_started";
const reindexTooltip = `This index attempt ${
isReindexInProgress ? "is" : "was"
} a full re-index. All documents from the source ${
isReindexInProgress ? "are being" : "were"
} synced into the system.`;
return (
<TableRow key={indexAttempt.id}>
<TableCell>
@@ -136,28 +138,11 @@ export function IndexAttemptsTable({
<div className="flex items-center">
{indexAttempt.total_docs_indexed}
{indexAttempt.from_beginning && (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<span className="cursor-help flex items-center">
<FaBarsProgress className="ml-2 h-3.5 w-3.5" />
</span>
</TooltipTrigger>
<TooltipContent>
This index attempt{" "}
{indexAttempt.status === "in_progress" ||
indexAttempt.status === "not_started"
? "is"
: "was"}{" "}
a full re-index. All documents from the source{" "}
{indexAttempt.status === "in_progress" ||
indexAttempt.status === "not_started"
? "are being "
: "were "}
synced into the system.
</TooltipContent>
</Tooltip>
</TooltipProvider>
<SimpleTooltip side="top" tooltip={reindexTooltip}>
<span className="cursor-help flex items-center">
<SvgClock className="ml-2 h-3.5 w-3.5 stroke-current" />
</span>
</SimpleTooltip>
)}
</div>
</TableCell>

View File

@@ -178,7 +178,6 @@ export const DocumentSetCreationForm = ({
name="name"
label="Name:"
placeholder="A name for the document set"
disabled={isUpdate}
autoCompleteDisabled={true}
/>
<TextFormField

View File

@@ -112,7 +112,7 @@ const EditRow = ({
if (!isEditable) {
return (
<div className="text-text-darkerfont-medium my-auto p-1">
<div className="text-text-darker font-medium my-auto p-1">
{documentSet.name}
</div>
);
@@ -125,7 +125,7 @@ const EditRow = ({
<TooltipTrigger asChild>
<div
className={`
text-text-darkerfont-medium my-auto p-1 hover:bg-accent-background flex items-center select-none
text-text-darker font-medium my-auto p-1 hover:bg-accent-background flex items-center select-none
${documentSet.is_up_to_date ? "cursor-pointer" : "cursor-default"}
`}
style={{ wordBreak: "normal", overflowWrap: "break-word" }}

View File

@@ -21,6 +21,7 @@ import { getDisplayNameForModel } from "@/lib/hooks";
import { errorHandlingFetcher } from "@/lib/fetcher";
import Button from "@/refresh-components/buttons/Button";
import SvgPlusCircle from "@/icons/plus-circle";
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
// Number of tokens to show cost calculation for
const COST_CALCULATION_TOKENS = 1_000_000;
@@ -281,10 +282,15 @@ const AdvancedEmbeddingFormPage = forwardRef<
name="disable_rerank_for_streaming"
/>
<BooleanFormField
subtext="Enable contextual RAG for all chunk sizes."
subtext={
NEXT_PUBLIC_CLOUD_ENABLED
? "Contextual RAG disabled in Onyx Cloud"
: "Enable contextual RAG for all chunk sizes."
}
optional
label="Contextual RAG"
name="enable_contextual_rag"
disabled={NEXT_PUBLIC_CLOUD_ENABLED}
/>
<div>
<SelectorFormField

View File

@@ -197,20 +197,24 @@ export function SettingsForm() {
}
function handleCompanyNameBlur() {
updateSettingField([
{ fieldName: "company_name", newValue: companyName || null },
]);
const originalValue = settings?.company_name || "";
if (companyName !== originalValue) {
updateSettingField([
{ fieldName: "company_name", newValue: companyName || null },
]);
}
}
// NOTE: at the moment there's a small bug where if you click another admin panel page after typing
// the field doesn't update correctly
function handleCompanyDescriptionBlur() {
updateSettingField([
{
fieldName: "company_description",
newValue: companyDescription || null,
},
]);
const originalValue = settings?.company_description || "";
if (companyDescription !== originalValue) {
updateSettingField([
{
fieldName: "company_description",
newValue: companyDescription || null,
},
]);
}
}
return (

View File

@@ -0,0 +1,256 @@
/**
* Integration Test: Email/Password Authentication Workflow
*
* Tests the complete user journey for logging in.
* This tests the full workflow: form → validation → API call → redirect
*/
import React from "react";
import { render, screen, waitFor, setupUser } from "@tests/setup/test-utils";
import EmailPasswordForm from "./EmailPasswordForm";
// Mock next/navigation (not used by this component, but required by dependencies)
// jest.mock is hoisted above imports, so this applies before EmailPasswordForm
// (or anything it imports) is loaded.
jest.mock("next/navigation", () => ({
  useRouter: () => ({
    push: jest.fn(),
    refresh: jest.fn(),
  }),
}));
describe("Email/Password Login Workflow", () => {
  // Spy on global fetch; each test queues responses in request order.
  let fetchSpy: jest.SpyInstance;
  beforeEach(() => {
    jest.clearAllMocks();
    fetchSpy = jest.spyOn(global, "fetch");
    // Mock window.location.href for redirect testing
    // (jsdom's location is not writable, so replace it with a plain object).
    delete (window as any).location;
    window.location = { href: "" } as any;
  });
  afterEach(() => {
    fetchSpy.mockRestore();
  });
  // Happy path: valid credentials -> POST /api/auth/login -> redirect to /chat.
  test("allows user to login with valid credentials", async () => {
    const user = setupUser();
    // Mock POST /api/auth/login
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    render(<EmailPasswordForm isSignup={false} />);
    // User fills out the form using placeholder text
    const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
    const passwordInput = screen.getByPlaceholderText(/\*/);
    await user.type(emailInput, "test@example.com");
    await user.type(passwordInput, "password123");
    // User submits the form
    const loginButton = screen.getByRole("button", { name: /log in/i });
    await user.click(loginButton);
    // After successful login, user should be redirected to /chat
    await waitFor(() => {
      expect(window.location.href).toBe("/chat");
    });
    // Verify API was called with correct credentials
    // (OAuth2-style form-encoded body, not JSON).
    expect(fetchSpy).toHaveBeenCalledWith(
      "/api/auth/login",
      expect.objectContaining({
        method: "POST",
        headers: {
          "Content-Type": "application/x-www-form-urlencoded",
        },
      })
    );
    // Verify the request body contains email and password
    // ("@" is percent-encoded as %40 in the form body).
    const callArgs = fetchSpy.mock.calls[0];
    const body = callArgs[1].body;
    expect(body.toString()).toContain("username=test%40example.com");
    expect(body.toString()).toContain("password=password123");
  });
  // Failure path: 401 with LOGIN_BAD_CREDENTIALS renders a user-facing error.
  test("shows error message when login fails", async () => {
    const user = setupUser();
    // Mock POST /api/auth/login (failure)
    fetchSpy.mockResolvedValueOnce({
      ok: false,
      status: 401,
      json: async () => ({ detail: "LOGIN_BAD_CREDENTIALS" }),
    } as Response);
    render(<EmailPasswordForm isSignup={false} />);
    // User fills out form with invalid credentials
    const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
    const passwordInput = screen.getByPlaceholderText(/\*/);
    await user.type(emailInput, "wrong@example.com");
    await user.type(passwordInput, "wrongpassword");
    // User submits
    const loginButton = screen.getByRole("button", { name: /log in/i });
    await user.click(loginButton);
    // Verify error message is displayed
    await waitFor(() => {
      expect(
        screen.getByText(/invalid email or password/i)
      ).toBeInTheDocument();
    });
  });
});
describe("Email/Password Signup Workflow", () => {
let fetchSpy: jest.SpyInstance;
beforeEach(() => {
jest.clearAllMocks();
fetchSpy = jest.spyOn(global, "fetch");
// Mock window.location.href
delete (window as any).location;
window.location = { href: "" } as any;
});
afterEach(() => {
fetchSpy.mockRestore();
});
  // Happy path: POST /api/auth/register succeeds, then the form immediately
  // logs the new account in via POST /api/auth/login and shows a success
  // message. The two mocks below are consumed in call order: register first,
  // then login.
  test("allows user to sign up and login with valid credentials", async () => {
    const user = setupUser();
    // Mock POST /api/auth/register
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    // Mock POST /api/auth/login (after successful signup)
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    render(<EmailPasswordForm isSignup={true} />);
    // User fills out the signup form
    const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
    const passwordInput = screen.getByPlaceholderText(/\*/);
    await user.type(emailInput, "newuser@example.com");
    await user.type(passwordInput, "securepassword123");
    // User submits the signup form
    const signupButton = screen.getByRole("button", { name: /sign up/i });
    await user.click(signupButton);
    // Verify signup API was called
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/auth/register",
        expect.objectContaining({
          method: "POST",
          headers: {
            "Content-Type": "application/json",
          },
        })
      );
    });
    // Verify signup request body: the first spy call is the register request;
    // the form mirrors the email into the username field.
    const signupCallArgs = fetchSpy.mock.calls[0];
    const signupBody = JSON.parse(signupCallArgs[1].body);
    expect(signupBody).toEqual({
      email: "newuser@example.com",
      username: "newuser@example.com",
      password: "securepassword123",
      referral_source: undefined,
    });
    // Verify login API was called after successful signup
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/auth/login",
        expect.objectContaining({
          method: "POST",
        })
      );
    });
    // Verify success message is shown
    await waitFor(() => {
      expect(
        screen.getByText(/account created successfully/i)
      ).toBeInTheDocument();
    });
  });
test("shows error when email already exists", async () => {
const user = setupUser();
// Mock POST /api/auth/register (failure - user exists)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 400,
json: async () => ({ detail: "REGISTER_USER_ALREADY_EXISTS" }),
} as Response);
render(<EmailPasswordForm isSignup={true} />);
// User fills out form with existing email
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
await user.type(emailInput, "existing@example.com");
await user.type(passwordInput, "password123");
// User submits
const signupButton = screen.getByRole("button", { name: /sign up/i });
await user.click(signupButton);
// Verify error message is displayed
await waitFor(() => {
expect(
screen.getByText(/an account already exists with the specified email/i)
).toBeInTheDocument();
});
});
test("shows rate limit error when too many requests", async () => {
const user = setupUser();
// Mock POST /api/auth/register (failure - rate limit)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 429,
json: async () => ({ detail: "Too many requests" }),
} as Response);
render(<EmailPasswordForm isSignup={true} />);
// User fills out form
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
const passwordInput = screen.getByPlaceholderText(/\*/);
await user.type(emailInput, "user@example.com");
await user.type(passwordInput, "password123");
// User submits
const signupButton = screen.getByRole("button", { name: /sign up/i });
await user.click(signupButton);
// Verify rate limit message is displayed
await waitFor(() => {
expect(
screen.getByText(/too many requests\. please try again later/i)
).toBeInTheDocument();
});
});
});

View File

@@ -43,6 +43,8 @@ export default function EmailPasswordForm({
email: defaultEmail ? defaultEmail.toLowerCase() : "",
password: "",
}}
validateOnChange={false}
validateOnBlur={true}
validationSchema={Yup.object().shape({
email: Yup.string()
.email()

View File

@@ -397,18 +397,16 @@ function ChatInputBarInner({
<div className="w-full h-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16">
{currentMessageFiles.length > 0 && (
<div className="p-spacing-inline bg-background-neutral-01 rounded-t-16">
<div className="flex flex-wrap gap-2">
{currentMessageFiles.map((file) => (
<FileCard
key={file.id}
file={file}
removeFile={handleRemoveMessageFile}
hideProcessingState={hideProcessingState}
onFileClick={handleFileClick}
/>
))}
</div>
<div className="p-spacing-inline rounded-t-16 flex flex-wrap gap-spacing-interline">
{currentMessageFiles.map((file) => (
<FileCard
key={file.id}
file={file}
removeFile={handleRemoveMessageFile}
hideProcessingState={hideProcessingState}
onFileClick={handleFileClick}
/>
))}
</div>
)}

View File

@@ -0,0 +1,393 @@
/**
* Integration Test: Input Prompts CRUD Workflow
*
* Tests the complete user journey for managing prompt shortcuts.
* This tests the full workflow: fetch → create → edit → delete
*/
import React from "react";
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
import InputPrompts from "./InputPrompts";
// Mock next/navigation for BackButton
// Stub next/navigation so components using useRouter (e.g. BackButton) can
// render without a Next.js app context; each navigation method is an inert
// jest mock. The factory must be self-contained (jest.mock is hoisted).
jest.mock("next/navigation", () => {
  return {
    useRouter() {
      return {
        push: jest.fn(),
        back: jest.fn(),
        refresh: jest.fn(),
      };
    },
  };
});
describe("Input Prompts CRUD Workflow", () => {
let fetchSpy: jest.SpyInstance;
beforeEach(() => {
jest.clearAllMocks();
fetchSpy = jest.spyOn(global, "fetch");
});
afterEach(() => {
fetchSpy.mockRestore();
});
test("fetches and displays existing prompts on load", async () => {
// Mock GET /api/input_prompt
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => [
{
id: 1,
prompt: "Summarize",
content: "Summarize the uploaded document and highlight key points.",
is_public: false,
},
{
id: 2,
prompt: "Explain",
content: "Explain this concept in simple terms.",
is_public: true,
},
],
} as Response);
render(<InputPrompts />);
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
});
await waitFor(() => {
expect(screen.getByText("Summarize")).toBeInTheDocument();
expect(screen.getByText("Explain")).toBeInTheDocument();
expect(
screen.getByText(
/Summarize the uploaded document and highlight key points/i
)
).toBeInTheDocument();
});
});
  // Create flow: open the creation form, fill both fields, submit, then
  // verify the POST payload plus the success toast and the new list entry.
  // Mocks are consumed in call order: GET (empty list) first, POST second.
  test("creates a new prompt successfully", async () => {
    const user = setupUser();
    // Mock GET /api/input_prompt
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => [],
    } as Response);
    // Mock POST /api/input_prompt
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({
        id: 3,
        prompt: "Review",
        content: "Review this code for potential improvements.",
        is_public: false,
      }),
    } as Response);
    render(<InputPrompts />);
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
    });
    const createButton = screen.getByRole("button", {
      name: /create new prompt/i,
    });
    await user.click(createButton);
    // The creation form exposes two inputs: shortcut name and prompt body.
    expect(
      await screen.findByPlaceholderText(/prompt shortcut/i)
    ).toBeInTheDocument();
    expect(screen.getByPlaceholderText(/actual prompt/i)).toBeInTheDocument();
    const shortcutInput = screen.getByPlaceholderText(/prompt shortcut/i);
    const promptInput = screen.getByPlaceholderText(/actual prompt/i);
    await user.type(shortcutInput, "Review");
    await user.type(
      promptInput,
      "Review this code for potential improvements."
    );
    const submitButton = screen.getByRole("button", { name: /^create$/i });
    await user.click(submitButton);
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/input_prompt",
        expect.objectContaining({
          method: "POST",
          headers: { "Content-Type": "application/json" },
        })
      );
    });
    const createCallArgs = fetchSpy.mock.calls[1]; // Second call (first was GET)
    const createBody = JSON.parse(createCallArgs[1].body);
    // New prompts are submitted as private (is_public: false) by default.
    expect(createBody).toEqual({
      prompt: "Review",
      content: "Review this code for potential improvements.",
      is_public: false,
    });
    await waitFor(() => {
      expect(
        screen.getByText(/prompt created successfully/i)
      ).toBeInTheDocument();
    });
    expect(await screen.findByText("Review")).toBeInTheDocument();
  });
  // Edit flow for a user-owned (non-public) prompt: open the row's dropdown,
  // choose Edit, rewrite the content textarea, save, then verify the PATCH
  // payload and the success toast.
  test("edits an existing user-created prompt", async () => {
    const user = setupUser();
    // Mock GET /api/input_prompt
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => [
        {
          id: 1,
          prompt: "Summarize",
          content: "Summarize the document.",
          is_public: false,
        },
      ],
    } as Response);
    // Mock PATCH /api/input_prompt/1
    fetchSpy.mockResolvedValueOnce({
      ok: true,
      json: async () => ({}),
    } as Response);
    render(<InputPrompts />);
    await waitFor(() => {
      expect(screen.getByText("Summarize")).toBeInTheDocument();
    });
    // The "more actions" trigger is an icon-only button: no text, has an svg.
    const dropdownButtons = screen.getAllByRole("button");
    const moreButton = dropdownButtons.find(
      (btn) => btn.textContent === "" && btn.querySelector("svg")
    );
    expect(moreButton).toBeDefined();
    await user.click(moreButton!);
    const editOption = await screen.findByRole("menuitem", { name: /edit/i });
    await user.click(editOption);
    // Wait until the edit form is populated before interacting with it;
    // textareas[0] holds the shortcut name, textareas[1] the prompt content.
    let textareas: HTMLElement[];
    await waitFor(() => {
      textareas = screen.getAllByRole("textbox");
      expect(textareas[0]).toHaveValue("Summarize");
      expect(textareas[1]).toHaveValue("Summarize the document.");
    });
    await user.clear(textareas![1]);
    await user.type(
      textareas![1],
      "Summarize the document and provide key insights."
    );
    const saveButton = screen.getByRole("button", { name: /save/i });
    await user.click(saveButton);
    await waitFor(() => {
      expect(fetchSpy).toHaveBeenCalledWith(
        "/api/input_prompt/1",
        expect.objectContaining({
          method: "PATCH",
          headers: { "Content-Type": "application/json" },
        })
      );
    });
    // Second spy call is the PATCH; the body keeps the original shortcut name.
    const patchCallArgs = fetchSpy.mock.calls[1];
    const patchBody = JSON.parse(patchCallArgs[1].body);
    expect(patchBody).toEqual({
      prompt: "Summarize",
      content: "Summarize the document and provide key insights.",
      active: true,
    });
    await waitFor(() => {
      expect(
        screen.getByText(/prompt updated successfully/i)
      ).toBeInTheDocument();
    });
  });
test("deletes a user-created prompt", async () => {
const user = setupUser();
// Mock GET /api/input_prompt
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => [
{
id: 1,
prompt: "Summarize",
content: "Summarize the document.",
is_public: false,
},
],
} as Response);
// Mock DELETE /api/input_prompt/1
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
render(<InputPrompts />);
await waitFor(() => {
expect(screen.getByText("Summarize")).toBeInTheDocument();
});
const dropdownButtons = screen.getAllByRole("button");
const moreButton = dropdownButtons.find(
(btn) => btn.textContent === "" && btn.querySelector("svg")
);
await user.click(moreButton!);
const deleteOption = await screen.findByRole("menuitem", {
name: /delete/i,
});
await user.click(deleteOption);
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt/1", {
method: "DELETE",
});
});
await waitFor(() => {
expect(
screen.getByText(/prompt deleted successfully/i)
).toBeInTheDocument();
});
await waitFor(() => {
expect(screen.queryByText("Summarize")).not.toBeInTheDocument();
});
});
test("hides a public prompt instead of deleting it", async () => {
const user = setupUser();
// Mock GET /api/input_prompt
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => [
{
id: 2,
prompt: "Explain",
content: "Explain this concept.",
is_public: true,
},
],
} as Response);
// Mock POST /api/input_prompt/2/hide
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
render(<InputPrompts />);
await waitFor(() => {
expect(screen.getByText("Explain")).toBeInTheDocument();
expect(screen.getByText("Built-in")).toBeInTheDocument();
});
const dropdownButtons = screen.getAllByRole("button");
const moreButton = dropdownButtons.find(
(btn) => btn.textContent === "" && btn.querySelector("svg")
);
await user.click(moreButton!);
// Edit option should NOT be shown for public prompts
await waitFor(() => {
expect(
screen.getByRole("menuitem", { name: /delete/i })
).toBeInTheDocument();
});
expect(
screen.queryByRole("menuitem", { name: /edit/i })
).not.toBeInTheDocument();
// Public prompts use the hide endpoint instead of delete
const deleteOption = screen.getByRole("menuitem", { name: /delete/i });
await user.click(deleteOption);
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt/2/hide", {
method: "POST",
});
});
await waitFor(() => {
expect(
screen.getByText(/prompt hidden successfully/i)
).toBeInTheDocument();
});
});
test("shows error when fetch fails", async () => {
// Mock GET /api/input_prompt (failure)
fetchSpy.mockRejectedValueOnce(new Error("Network error"));
render(<InputPrompts />);
await waitFor(() => {
expect(
screen.getByText(/failed to fetch prompt shortcuts/i)
).toBeInTheDocument();
});
});
test("shows error when create fails", async () => {
const user = setupUser();
// Mock GET /api/input_prompt
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => [],
} as Response);
// Mock POST /api/input_prompt (failure)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 500,
} as Response);
render(<InputPrompts />);
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
});
const createButton = screen.getByRole("button", {
name: /create new prompt/i,
});
await user.click(createButton);
const shortcutInput =
await screen.findByPlaceholderText(/prompt shortcut/i);
const promptInput = screen.getByPlaceholderText(/actual prompt/i);
await user.type(shortcutInput, "Test");
await user.type(promptInput, "Test content");
const submitButton = screen.getByRole("button", { name: /^create$/i });
await user.click(submitButton);
await waitFor(() => {
expect(screen.getByText(/failed to create prompt/i)).toBeInTheDocument();
});
});
});

View File

@@ -121,10 +121,10 @@ export const preprocessLaTeX = (content: string) => {
(_, equation) => `$$${equation}$$`
);
// Replace inline LaTeX delimiters \( \) with $$ $$
// Replace inline LaTeX delimiters \( \) with $ $
const inlineProcessed = blockProcessed.replace(
/\\\(([\s\S]*?)\\\)/g,
(_, equation) => `$$${equation}$$`
(_, equation) => `$${equation}$`
);
// Restore original dollar signs in code contexts

View File

@@ -223,14 +223,31 @@ export const ProjectsProvider: React.FC<ProjectsProviderProps> = ({
const renameProject = useCallback(
async (projectId: number, name: string): Promise<Project> => {
// Optimistically update the UI immediately
setProjects((prev) =>
prev.map((p) => (p.id === projectId ? { ...p, name } : p))
);
if (currentProjectId === projectId) {
setCurrentProjectDetails((prev) =>
prev ? { ...prev, project: { ...prev.project, name } } : prev
);
}
try {
const updated = await svcRenameProject(projectId, name);
// Refresh to get canonical state from server
await fetchProjects();
if (currentProjectId === projectId) {
await refreshCurrentProjectDetails();
}
return updated;
} catch (err) {
// Rollback optimistic update on failure
await fetchProjects();
if (currentProjectId === projectId) {
await refreshCurrentProjectDetails();
}
const message =
err instanceof Error ? err.message : "Failed to rename project";
throw err;

View File

@@ -12,7 +12,7 @@ import {
SubLabel,
TextFormField,
} from "@/components/Field";
import { Button } from "@/components/ui/button";
import Button from "@/refresh-components/buttons/Button";
import Text from "@/components/ui/text";
import { ImageUpload } from "./ImageUpload";
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
@@ -158,8 +158,7 @@ export function WhitelabelingForm() {
/>
<Button
variant="destructive"
size="sm"
danger
type="button"
className="mb-8"
onClick={async () => {
@@ -304,8 +303,7 @@ export function WhitelabelingForm() {
/>
<Button
variant="destructive"
size="sm"
danger
type="button"
className="mb-8"
onClick={async () => {

View File

@@ -42,8 +42,8 @@
--alpha-grey-100-20: #00000033;
--alpha-grey-100-15: #00000026;
--alpha-grey-100-10: #0000001a;
--alpha-grey-100-5: #0000000d;
--alpha-grey-100-0: #00000000;
--alpha-grey-100-05: #0000000d;
--alpha-grey-100-00: #00000000;
/* Alpha Grey 00 (White with opacity) */
--alpha-grey-00-95: #fffffff2;
@@ -64,8 +64,8 @@
--alpha-grey-00-20: #ffffff33;
--alpha-grey-00-15: #ffffff26;
--alpha-grey-00-10: #ffffff1a;
--alpha-grey-00-5: #ffffff0d;
--alpha-grey-00-0: #ffffff00;
--alpha-grey-00-05: #ffffff0d;
--alpha-grey-00-00: #ffffff00;
/* Blue Scale */
--blue-95: #040e25;

View File

@@ -580,6 +580,28 @@ code[class*="language-"] {
margin-bottom: 1rem;
}
/* MASKING */
/*
TODO: add correct class
.mask-01 {
background: linear-gradient(to top, var(--mask-01), transparent);
-webkit-mask-image: -webkit-gradient(
linear,
left top,
left bottom,
from(var(--mask-01)),
to(transparent)
);
mask-image: -webkit-gradient(
linear,
left top,
left bottom,
from(var(--mask-01)),
to(transparent)
);
} */
/* DEBUGGING UTILITIES
If you ever want to highlight a component for debugging purposes, just type in `className="dbg-red ..."`, and a red box should appear around it.

View File

@@ -1,12 +1,14 @@
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
import React, { useState, useEffect } from "react";
import { FormikProps, FieldArray, ArrayHelpers, ErrorMessage } from "formik";
import Text from "@/components/ui/text";
import { FiUsers } from "react-icons/fi";
import Text from "@/refresh-components/texts/Text";
import Button from "@/refresh-components/buttons/Button";
import { UserGroup, UserRole } from "@/lib/types";
import { useUserGroups } from "@/lib/hooks";
import { BooleanFormField } from "@/components/Field";
import { useUser } from "./user/UserProvider";
import SvgUsers from "@/icons/users";
import { cn } from "@/lib/utils";
export type IsPublicGroupSelectorFormType = {
is_public: boolean;
@@ -112,59 +114,48 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
userGroups &&
userGroups?.length > 0 && (
<>
<div className="flex mt-4 gap-x-2 items-center">
<div className="block font-medium text-base">
<div className="flex flex-col gap-3 pt-4">
<Text mainUiAction text05>
Assign group access for this {objectName}
</div>
</div>
{userGroupsIsLoading ? (
<div className="animate-pulse bg-background-200 h-8 w-32 rounded"></div>
) : (
<Text className="mb-3">
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
<>
Curators must select one or more groups to give access to
this {objectName}
</>
)}
</Text>
)}
{userGroupsIsLoading ? (
<div className="animate-pulse bg-background-200 h-8 w-32 rounded" />
) : (
<Text mainUiMuted text03>
{isAdmin || !enforceGroupSelection ? (
<>
This {objectName} will be visible/accessible by the groups
selected below
</>
) : (
<>
Curators must select one or more groups to give access to
this {objectName}
</>
)}
</Text>
)}
</div>
<FieldArray
name="groups"
render={(arrayHelpers: ArrayHelpers) => (
<div className="flex gap-2 flex-wrap mb-4">
<div className="flex flex-wrap gap-2 py-4">
{userGroupsIsLoading ? (
<div className="animate-pulse bg-background-200 h-8 w-32 rounded"></div>
<div className="animate-pulse bg-background-200 h-8 w-32 rounded" />
) : (
userGroups &&
userGroups.map((userGroup: UserGroup) => {
const ind = formikProps.values.groups.indexOf(
userGroup.id
);
let isSelected = ind !== -1;
const isSelected = ind !== -1;
return (
<div
<Button
key={userGroup.id}
className={`
px-3
py-1
rounded-lg
border
border-border
w-fit
flex
cursor-pointer
${
isSelected
? "bg-background-200"
: "hover:bg-accent-background-hovered"
}
`}
primary
action={isSelected}
type="button"
leftIcon={SvgUsers}
onClick={() => {
if (isSelected) {
arrayHelpers.remove(ind);
@@ -173,11 +164,8 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
}
}}
>
<div className="my-auto flex">
<FiUsers className="my-auto mr-2" />{" "}
{userGroup.name}
</div>
</div>
{userGroup.name}
</Button>
);
})
)}

View File

@@ -1,8 +1,8 @@
"use client";
import ErrorPageLayout from "./ErrorPageLayout";
import { fetchCustomerPortal } from "@/lib/billing/utils";
import { useState } from "react";
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
import { fetchCustomerPortal } from "@/lib/billing/utils";
import { useRouter } from "next/navigation";
import Button from "@/refresh-components/buttons/Button";
import { logout } from "@/lib/user";
@@ -86,59 +86,63 @@ export default function AccessRestricted() {
return (
<ErrorPageLayout>
<div className="flex items-center gap-2 mb-4">
<div className="flex items-center gap-spacing-interline">
<Text headingH2>Access Restricted</Text>
<SvgLock className="stroke-status-error-05 w-[1.5rem] h-[1.5rem]" />
</div>
<div className="space-y-4">
<Text text03>
We regret to inform you that your access to Onyx has been temporarily
suspended due to a lapse in your subscription.
</Text>
<Text text03>
To reinstate your access and continue benefiting from Onyx&apos;s
powerful features, please update your payment information.
</Text>
<Text text03>
If you&apos;re an admin, you can manage your subscription by clicking
the button below. For other users, please reach out to your
administrator to address this matter.
</Text>
<div className="flex flex-row gap-spacing-interline">
<Button onClick={handleResubscribe} disabled={isLoading}>
{isLoading ? "Loading..." : "Resubscribe"}
</Button>
<Button
secondary
onClick={handleManageSubscription}
disabled={isLoading}
>
Manage Existing Subscription
</Button>
<Button
secondary
onClick={async () => {
await logout();
window.location.reload();
}}
>
Log out
</Button>
</div>
{error && <Text className="text-status-error-05">{error}</Text>}
<Text text03>
Need help? Join our{" "}
<a
className="text-action-link-05 hover:text-action-link-06"
href="https://discord.gg/4NA5SbzrWb"
target="_blank"
rel="noopener noreferrer"
>
Discord community
</a>{" "}
for support.
</Text>
<Text text03>
We regret to inform you that your access to Onyx has been temporarily
suspended due to a lapse in your subscription.
</Text>
<Text text03>
To reinstate your access and continue benefiting from Onyx&apos;s
powerful features, please update your payment information.
</Text>
<Text text03>
If you&apos;re an admin, you can manage your subscription by clicking
the button below. For other users, please reach out to your
administrator to address this matter.
</Text>
<div className="flex flex-row gap-spacing-interline">
<Button onClick={handleResubscribe} disabled={isLoading}>
{isLoading ? "Loading..." : "Resubscribe"}
</Button>
<Button
secondary
onClick={handleManageSubscription}
disabled={isLoading}
>
Manage Existing Subscription
</Button>
<Button
secondary
onClick={async () => {
await logout();
window.location.reload();
}}
>
Log out
</Button>
</div>
{error && <Text className="text-status-error-05">{error}</Text>}
<Text text03>
Need help? Join our{" "}
<a
className="text-action-link-05 hover:text-action-link-06"
href="https://discord.gg/4NA5SbzrWb"
target="_blank"
rel="noopener noreferrer"
>
Discord community
</a>{" "}
for support.
</Text>
</ErrorPageLayout>
);
}

View File

@@ -1,22 +1,20 @@
"use client";
import ErrorPageLayout from "./ErrorPageLayout";
import Text from "@/refresh-components/texts/Text";
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
export default function CloudError() {
return (
<ErrorPageLayout>
<h1 className="text-2xl font-semibold mb-4 text-gray-800 dark:text-gray-200">
Maintenance in Progress
</h1>
<div className="space-y-4 text-gray-600 dark:text-gray-300">
<p>
Onyx is currently in a maintenance window. Please check back in a
couple of minutes.
</p>
<p>
We apologize for any inconvenience this may cause and appreciate your
patience.
</p>
</div>
<Text headingH2>Maintenance in Progress</Text>
<Text text03>
Onyx is currently in a maintenance window. Please check back in a couple
of minutes.
</Text>
<Text text03>
We apologize for any inconvenience this may cause and appreciate your
patience.
</Text>
</ErrorPageLayout>
);
}

View File

@@ -1,47 +1,46 @@
import ErrorPageLayout from "./ErrorPageLayout";
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
import Text from "@/refresh-components/texts/Text";
import SvgAlertCircle from "@/icons/alert-circle";
export default function Error() {
return (
<ErrorPageLayout>
<div className="flex items-center gap-2 mb-4 ">
<Text headingH2 inverted>
We encountered an issue
</Text>
<SvgAlertCircle className="w-[1.5rem] h-[1.5rem] stroke-text-inverted-04" />
</div>
<div className="space-y-4 text-gray-600 dark:text-gray-300">
<Text inverted>
It seems there was a problem loading your Onyx settings. This could be
due to a configuration issue or incomplete setup.
</Text>
<Text inverted>
If you&apos;re an admin, please review our{" "}
<a
className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
href="https://docs.onyx.app/?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
target="_blank"
rel="noopener noreferrer"
>
documentation
</a>{" "}
for proper configuration steps. If you&apos;re a user, please contact
your admin for assistance.
</Text>
<Text inverted>
Need help? Join our{" "}
<a
className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
href="https://discord.gg/4NA5SbzrWb"
target="_blank"
rel="noopener noreferrer"
>
Discord community
</a>{" "}
for support.
</Text>
<div className="flex flex-row items-center gap-spacing-interline">
<Text headingH2>We encountered an issue</Text>
<SvgAlertCircle className="w-[1.5rem] h-[1.5rem] stroke-text-04" />
</div>
<Text text03>
It seems there was a problem loading your Onyx settings. This could be
due to a configuration issue or incomplete setup.
</Text>
<Text text03>
If you&apos;re an admin, please review our{" "}
<a
className="text-action-link-05"
href="https://docs.onyx.app/?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
target="_blank"
rel="noopener noreferrer"
>
documentation
</a>{" "}
for proper configuration steps. If you&apos;re a user, please contact
your admin for assistance.
</Text>
<Text text03>
Need help? Join our{" "}
<a
className="text-action-link-05"
href="https://discord.gg/4NA5SbzrWb"
target="_blank"
rel="noopener noreferrer"
>
Discord community
</a>{" "}
for support.
</Text>
</ErrorPageLayout>
);
}

View File

@@ -1,5 +1,5 @@
import React from "react";
import { LogoType } from "../logo/Logo";
import { OnyxLogoTypeIcon } from "@/components/icons/icons";
interface ErrorPageLayoutProps {
children: React.ReactNode;
@@ -7,12 +7,10 @@ interface ErrorPageLayoutProps {
export default function ErrorPageLayout({ children }: ErrorPageLayoutProps) {
return (
<div className="flex flex-col items-center justify-center min-h-screen">
<div className="mb-4 flex items-center max-w-[220px]">
<LogoType size="large" />
</div>
<div className="max-w-xl border border-border w-full bg-white shadow-md rounded-lg overflow-hidden">
<div className="p-6 sm:p-8">{children}</div>
<div className="flex flex-col items-center justify-center w-full h-screen gap-spacing-paragraph">
<OnyxLogoTypeIcon size={120} className="" />
<div className="max-w-[40rem] w-full border bg-background-neutral-00 shadow-02 rounded-16 p-padding-content flex flex-col gap-spacing-paragraph">
{children}
</div>
</div>
);

View File

@@ -54,12 +54,9 @@ const TooltipContent = React.forwardRef<
sideOffset={sideOffset}
side={side}
className={cn(
`z-[100] rounded-08 ${
backgroundColor || "bg-background-neutral-inverted-03"
}
${width}
text-wrap
p-padding-button shadow-lg animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2`,
"z-[100] rounded-08 text-text-inverted-05 text-wrap p-padding-button shadow-lg animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
backgroundColor || "bg-background-neutral-inverted-03",
width,
className
)}
{...props}

View File

@@ -0,0 +1,45 @@
import * as React from "react";
import type { SVGProps } from "react";
const OnyxLogo = (props: SVGProps<SVGSVGElement>) => (
<svg
width="16"
height="16"
viewBox="0 0 16 16"
fill="none"
xmlns="http://www.w3.org/2000/svg"
{...props}
>
<g clipPath="url(#clip0_586_577)">
<path
d="M8 4.00001L4.5 2.50002L8 1.00002L11.5 2.50002L8 4.00001Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M8 12L11.5 13.5L8 15L4.5 13.5L8 12Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M4 8L2.5 11.5L1 8L2.5 4.50002L4 8Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
<path
d="M12 8.00002L13.5 4.50002L15 8.00001L13.5 11.5L12 8.00002Z"
strokeWidth={1.5}
strokeLinecap="round"
strokeLinejoin="round"
/>
</g>
<defs>
<clipPath id="clip0_586_577">
<rect width={16} height={16} fill="white" />
</clipPath>
</defs>
</svg>
);
export default OnyxLogo;

View File

@@ -2,11 +2,7 @@ import React from "react";
import crypto from "crypto";
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
import { buildImgUrl } from "@/app/chat/components/files/images/utils";
import {
ArtAsistantIcon,
GeneralAssistantIcon,
OnyxIcon,
} from "@/components/icons/icons";
import { ArtAsistantIcon, OnyxIcon } from "@/components/icons/icons";
import {
Tooltip,
TooltipContent,
@@ -15,6 +11,7 @@ import {
} from "@/components/ui/tooltip";
import { cn } from "@/lib/utils";
import Text from "@/refresh-components/texts/Text";
import { useSettingsContext } from "@/components/settings/SettingsProvider";
function md5ToBits(str: string): number[] {
const md5hex = crypto.createHash("md5").update(str).digest("hex");
@@ -94,17 +91,35 @@ export interface AgentIconProps {
}
export function AgentIcon({ agent, size = 24 }: AgentIconProps) {
const settings = useSettingsContext();
// Check if whitelabeling is enabled for the default assistant
const shouldUseWhitelabelLogo =
agent.id === 0 && settings?.enterpriseSettings?.use_custom_logo === true;
return (
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div className="text-text-04">
{agent.id == -3 ? (
{agent.id === -3 ? (
<ArtAsistantIcon size={size} />
) : agent.id == 0 ? (
<OnyxIcon size={size} />
) : agent.id == -1 ? (
<GeneralAssistantIcon size={size} />
) : agent.id === 0 ? (
shouldUseWhitelabelLogo ? (
<img
alt="Logo"
src="/api/enterprise-settings/logo"
loading="lazy"
className={cn(
"rounded-full object-cover object-center transition-opacity duration-300"
)}
width={size}
height={size}
style={{ objectFit: "contain" }}
/>
) : (
<OnyxIcon size={size} />
)
) : agent.uploaded_image_id ? (
<img
alt={agent.name}

View File

@@ -0,0 +1,46 @@
import { OnyxIcon, OnyxLogoTypeIcon } from "@/components/icons/icons";
import { useSettingsContext } from "@/components/settings/SettingsProvider";
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
import { cn } from "@/lib/utils";
import Text from "@/refresh-components/texts/Text";
// Pixel size of the square logo mark (and of the custom-logo image).
const FOLDED_SIZE = 24;

export interface LogoProps {
  // When true, render only the square mark (no application name / logotype).
  folded?: boolean;
  className?: string;
}
// Renders the application logo. Uses the tenant's uploaded logo and
// application name when whitelabeling settings are present, otherwise the
// stock Onyx mark / logotype.
export default function Logo({ folded, className }: LogoProps) {
  const settings = useSettingsContext();
  const enterprise = settings.enterpriseSettings;

  // Square mark: the tenant's uploaded logo when whitelabeling is enabled,
  // otherwise the stock Onyx icon.
  const mark = enterprise?.use_custom_logo ? (
    <img
      src="/api/enterprise-settings/logo"
      alt="Logo"
      style={{ objectFit: "contain", height: FOLDED_SIZE, width: FOLDED_SIZE }}
    />
  ) : (
    <OnyxIcon size={FOLDED_SIZE} className={cn("flex-shrink-0", className)} />
  );

  // Folded (e.g. collapsed sidebar): just the square mark.
  if (folded) {
    return mark;
  }

  // No custom application name configured: show the full Onyx logotype.
  if (!enterprise?.application_name) {
    return <OnyxLogoTypeIcon size={88} className={className} />;
  }

  return (
    <div className="flex flex-col">
      <div className="flex flex-row items-center gap-spacing-interline">
        {mark}
        <Text headingH3 className="break-all line-clamp-2">
          {enterprise.application_name}
        </Text>
      </div>
      {!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
        <Text secondaryBody text03 className="ml-[33px]">
          Powered by Onyx
        </Text>
      )}
    </div>
  );
}

View File

@@ -58,10 +58,17 @@ export default function VerticalShadowScroller({
{showBottomShadow && (
<div
className="absolute bottom-0 left-0 right-0 h-[2rem] pointer-events-none z-[100] dbg-rd"
style={{
background: "linear-gradient(to top, var(--mask-01), transparent)",
}}
className={cn(
"absolute",
"bottom-0",
"left-0",
"right-0",
"h-[2rem]",
"pointer-events-none",
"z-[100]"
// TODO: add masking to match mocks
// "mask-01"
)}
/>
)}
</div>

View File

@@ -7,6 +7,7 @@ export default function CreateButton({
href,
onClick,
children,
type = "button",
...props
}: ButtonProps) {
return (
@@ -15,6 +16,7 @@ export default function CreateButton({
onClick={onClick}
leftIcon={SvgPlusCircle}
href={href}
type={type}
{...props}
>
{children || "Create"}

View File

@@ -1,3 +1,5 @@
import type { HTMLAttributes } from "react";
import { cn } from "@/lib/utils";
const fonts = {
@@ -46,7 +48,7 @@ const colors = {
},
};
export interface TextProps extends React.HTMLAttributes<HTMLElement> {
export interface TextProps extends HTMLAttributes<HTMLParagraphElement> {
nowrap?: boolean;
// Fonts
@@ -106,6 +108,7 @@ export default function Text({
inverted,
children,
className,
...rest
}: TextProps) {
const font = headingH1
? "headingH1"
@@ -159,6 +162,7 @@ export default function Text({
return (
<p
{...rest}
className={cn(
fonts[font],
inverted ? colors.inverted[color] : colors[color],

View File

@@ -37,12 +37,13 @@ import {
SearchIcon,
DocumentIcon2,
BrainIcon,
OnyxSparkleIcon,
} from "@/components/icons/icons";
import OnyxLogo from "@/icons/onyx-logo";
import { CombinedSettings } from "@/app/admin/settings/interfaces";
import { FiActivity, FiBarChart2 } from "react-icons/fi";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import VerticalShadowScroller from "@/refresh-components/VerticalShadowScroller";
import { cn } from "@/lib/utils";
const connectors_items = () => [
{
@@ -160,7 +161,7 @@ const collections = (
items: [
{
name: "Default Assistant",
icon: OnyxSparkleIcon,
icon: OnyxLogo,
link: "/admin/configuration/default-assistant",
},
{
@@ -168,12 +169,16 @@ const collections = (
icon: CpuIconSkeleton,
link: "/admin/configuration/llm",
},
{
error: settings?.settings.needs_reindexing,
name: "Search Settings",
icon: SearchIcon,
link: "/admin/configuration/search",
},
...(!enableCloud
? [
{
error: settings?.settings.needs_reindexing,
name: "Search Settings",
icon: SearchIcon,
link: "/admin/configuration/search",
},
]
: []),
{
name: "Document Processing",
icon: DocumentIcon2,
@@ -357,7 +362,14 @@ export default function AdminSidebar({
))}
</VerticalShadowScroller>
<div className="flex flex-col px-spacing-interline gap-spacing-interline">
<div
className={cn(
"flex flex-col",
"px-spacing-interline",
"pt-spacing-interline",
"gap-spacing-interline"
)}
>
{combinedSettings.webVersion && (
<Text text02 secondaryBody className="px-spacing-interline">
{`Onyx version: ${combinedSettings.webVersion}`}

View File

@@ -349,7 +349,15 @@ function AppSidebarInner() {
)}
<SidebarWrapper folded={folded} setFolded={setFolded}>
<div className="flex flex-col px-spacing-interline gap-spacing-interline">
<div
className={cn(
"flex flex-col",
"px-spacing-interline",
"gap-spacing-interline",
"pt-spacing-paragraph",
"pb-spacing-paragraph"
)}
>
<div data-testid="AppSidebar/new-session">
<SidebarTab
leftIcon={SvgEditBig}
@@ -462,7 +470,7 @@ function AppSidebarInner() {
)}
</VerticalShadowScroller>
<div className="px-spacing-interline">
<div className="px-spacing-interline pt-spacing-interline-mini">
<Settings folded={folded} />
</div>
</SidebarWrapper>

View File

@@ -28,13 +28,15 @@ export default function ButtonRenaming({
return;
}
// Close immediately for instant feedback
onClose();
// Proceed with the rename operation after closing
try {
await onRename(newName);
} catch (error) {
console.error("Failed to rename:", error);
}
onClose();
}
return (

View File

@@ -42,7 +42,6 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
useState(false);
const { renameProject, deleteProject } = useProjectsContext();
const [isEditing, setIsEditing] = useState(false);
const [name, setName] = useState(project.name);
const [popoverOpen, setPopoverOpen] = useState(false);
const [isHoveringIcon, setIsHoveringIcon] = useState(false);
@@ -79,7 +78,6 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
async function handleRename(newName: string) {
await renameProject(project.id, newName);
setName(newName);
}
const popoverItems = [
@@ -180,7 +178,7 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
onClose={() => setIsEditing(false)}
/>
) : (
name
project.name
)}
</SidebarTab>
</PopoverAnchor>

View File

@@ -1,13 +1,13 @@
import React, { Dispatch, SetStateAction } from "react";
import React, { Dispatch, SetStateAction, useMemo } from "react";
import { cn } from "@/lib/utils";
import { OnyxIcon, OnyxLogoTypeIcon } from "@/components/icons/icons";
import IconButton from "@/refresh-components/buttons/IconButton";
import SvgSidebar from "@/icons/sidebar";
import Logo from "@/refresh-components/Logo";
interface SidebarWrapperProps {
export interface SidebarWrapperProps {
folded?: boolean;
setFolded?: Dispatch<SetStateAction<boolean>>;
children: React.ReactNode;
children?: React.ReactNode;
}
export default function SidebarWrapper({
@@ -15,25 +15,47 @@ export default function SidebarWrapper({
setFolded,
children,
}: SidebarWrapperProps) {
const logo = useMemo(
() => (
<Logo
folded={folded}
className={cn(folded && "visible group-hover/SidebarWrapper:hidden")}
/>
),
[folded]
);
return (
// This extra `div` wrapping needs to be present (for some reason).
// Without, the widths of the sidebars don't properly get set to the explicitly declared widths (i.e., `4rem` folded and `15rem` unfolded).
<div>
<div
className={cn(
"h-screen flex flex-col bg-background-tint-02 py-spacing-interline justify-between gap-padding-content group/SidebarWrapper",
"h-screen",
"flex flex-col",
"py-spacing-interline",
"bg-background-tint-02",
"justify-between",
"group/SidebarWrapper",
folded ? "w-[4rem]" : "w-[15rem]"
)}
>
<div
className={cn(
"flex flex-row items-center px-spacing-paragraph py-spacing-inline flex-shrink-0",
"flex",
"flex-row",
"items-center",
"px-spacing-paragraph",
"pt-spacing-interline-mini",
"pb-spacing-paragraph",
"flex-shrink-0",
folded ? "justify-center" : "justify-between"
)}
>
{folded ? (
<div className="h-[2rem] flex flex-col justify-center items-center">
<>
{logo}
<IconButton
icon={SvgSidebar}
tertiary
@@ -41,15 +63,11 @@ export default function SidebarWrapper({
className="hidden group-hover/SidebarWrapper:flex"
tooltip="Close Sidebar"
/>
<OnyxIcon
size={24}
className="visible group-hover/SidebarWrapper:hidden"
/>
</>
</div>
) : (
<>
<OnyxLogoTypeIcon size={88} />
{logo}
<IconButton
icon={SvgSidebar}
tertiary

View File

@@ -1,10 +1,9 @@
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { useContext, ReactNode } from "react";
import { LogoComponent } from "@/components/logo/FixedLogo";
import { ReactNode } from "react";
import { SvgProps } from "@/icons";
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
import SidebarWrapper from "@/sections/sidebar/SidebarWrapper";
interface StepSidebarProps {
export interface StepSidebarProps {
children: ReactNode;
buttonName: string;
buttonIcon: React.FunctionComponent<SvgProps>;
@@ -17,30 +16,19 @@ export default function StepSidebar({
buttonIcon,
buttonHref,
}: StepSidebarProps) {
const combinedSettings = useContext(SettingsContext);
if (!combinedSettings) {
return null;
}
const enterpriseSettings = combinedSettings.enterpriseSettings;
return (
<div className="fixed left-0 top-0 flex flex-col h-screen w-[15rem] bg-background-tint-02 py-padding-content px-padding-button gap-padding-content z-10">
<div className="flex flex-col items-start justify-center">
<LogoComponent enterpriseSettings={enterpriseSettings} />
<SidebarWrapper>
<div className="px-spacing-interline">
<SidebarTab
leftIcon={buttonIcon}
className="bg-background-tint-00"
href={buttonHref}
>
{buttonName}
</SidebarTab>
</div>
<SidebarTab
leftIcon={buttonIcon}
className="bg-background-tint-00"
href={buttonHref}
>
{buttonName}
</SidebarTab>
<div className="h-full flex">
<div className="w-full px-2">{children}</div>
</div>
</div>
<div className="h-full w-full px-spacing-paragraph">{children}</div>
</SidebarWrapper>
);
}

View File

@@ -219,10 +219,11 @@ module.exports = {
"token-function": "var(--token-function)",
"token-regex": "var(--token-regex)",
"token-attr-name": "var(--token-attr-name)",
// "non-selectable": "var(--non-selectable)",
},
boxShadow: {
"01": "0px 2px 8px 0px var(--shadow-02)",
"01": "0px 2px 8px 0px var(--shadow-01)",
"02": "0px 2px 8px 0px var(--shadow-02)",
"03": "0px 2px 8px 0px var(--shadow-03)",
// light
"tremor-input": "0 1px 2px 0 rgb(0 0 0 / 0.05)",

788
web/tests/README.md Normal file
View File

@@ -0,0 +1,788 @@
# React Integration Testing Guide
Comprehensive guide for writing integration tests in the Onyx web application using Jest and React Testing Library.
## Table of Contents
- [Running Tests](#running-tests)
- [Core Concepts](#core-concepts)
- [Writing Tests](#writing-tests)
- [Query Selectors](#query-selectors)
- [User Interactions](#user-interactions)
- [Async Operations](#async-operations)
- [Mocking](#mocking)
- [Common Patterns](#common-patterns)
- [Testing Philosophy](#testing-philosophy)
- [Troubleshooting](#troubleshooting)
## Running Tests
```bash
# Run all tests
npm test
# Run specific test file
npm test -- EmailPasswordForm.test
# Run tests matching pattern
npm test -- --testPathPattern="auth"
# Run without coverage
npm test -- --no-coverage
# Run in watch mode
npm test -- --watch
# Run with verbose output
npm test -- --verbose
```
## Core Concepts
### Test Structure
Tests are **co-located** with source files for easy discovery and maintenance:
```
src/app/auth/login/
├── EmailPasswordForm.tsx
└── EmailPasswordForm.test.tsx
```
### Test Anatomy
Every test follows this structure:
```typescript
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
import MyComponent from "./MyComponent";
test("descriptive test name explaining user behavior", async () => {
// 1. Setup - Create user, mock APIs
const user = setupUser();
const fetchSpy = jest.spyOn(global, "fetch");
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ data: "value" }),
} as Response);
// 2. Render - Display the component
render(<MyComponent />);
// 3. Act - Simulate user interactions
await user.type(screen.getByRole("textbox"), "test input");
await user.click(screen.getByRole("button", { name: /submit/i }));
// 4. Assert - Verify expected outcomes
await waitFor(() => {
expect(screen.getByText(/success/i)).toBeInTheDocument();
});
// 5. Cleanup - Restore mocks
fetchSpy.mockRestore();
});
```
### setupUser() - Automatic act() Wrapping
**ALWAYS use `setupUser()` instead of `userEvent.setup()`**
```typescript
// ✅ Correct - Automatic act() wrapping
const user = setupUser();
await user.click(button);
await user.type(input, "text");
// ❌ Wrong - Manual act() required, verbose
const user = userEvent.setup();
await act(async () => {
await user.click(button);
});
```
The `setupUser()` helper automatically wraps all user interactions in React's `act()` to prevent warnings and ensure proper state updates.
## Writing Tests
### Query Selectors
Use queries in this priority order (most accessible first):
#### 1. Role Queries (Preferred)
```typescript
// Buttons
screen.getByRole("button", { name: /submit/i })
screen.getByRole("button", { name: /cancel/i })
// Text inputs
screen.getByRole("textbox", { name: /email/i })
// Checkboxes
screen.getByRole("checkbox", { name: /remember me/i })
// Links
screen.getByRole("link", { name: /learn more/i })
// Headings
screen.getByRole("heading", { name: /welcome/i })
```
#### 2. Label Queries
```typescript
// For form inputs with labels
screen.getByLabelText(/password/i)
screen.getByLabelText(/email address/i)
```
#### 3. Placeholder Queries
```typescript
// When no label exists
screen.getByPlaceholderText(/enter email/i)
```
#### 4. Text Queries
```typescript
// For non-interactive text
screen.getByText(/welcome back/i)
screen.getByText(/error occurred/i)
```
#### Query Variants
```typescript
// getBy - Throws error if not found (immediate)
screen.getByRole("button")
// queryBy - Returns null if not found (checking absence)
expect(screen.queryByText(/error/i)).not.toBeInTheDocument()
// findBy - Returns promise, waits for element (async)
expect(await screen.findByText(/success/i)).toBeInTheDocument()
// getAllBy - Returns array of all matches
const inputs = screen.getAllByRole("textbox")
```
### Query Selectors: The Wrong Way
**❌ Avoid these anti-patterns:**
```typescript
// DON'T query by test IDs
screen.getByTestId("submit-button")
// DON'T query by class names
container.querySelector(".submit-btn")
// DON'T query by element types
container.querySelector("button")
```
## User Interactions
### Basic Interactions
```typescript
const user = setupUser();
// Click
await user.click(screen.getByRole("button", { name: /submit/i }));
// Type text
await user.type(screen.getByRole("textbox"), "test input");
// Clear and type
await user.clear(input);
await user.type(input, "new value");
// Check/uncheck checkbox
await user.click(screen.getByRole("checkbox"));
// Select from dropdown
await user.selectOptions(
screen.getByRole("combobox"),
"option-value"
);
// Upload file
const file = new File(["content"], "test.txt", { type: "text/plain" });
const input = screen.getByLabelText(/upload/i);
await user.upload(input, file);
```
### Form Interactions
```typescript
test("user can fill and submit form", async () => {
const user = setupUser();
render(<ContactForm />);
await user.type(screen.getByLabelText(/name/i), "John Doe");
await user.type(screen.getByLabelText(/email/i), "john@example.com");
await user.type(screen.getByLabelText(/message/i), "Hello!");
await user.click(screen.getByRole("button", { name: /send/i }));
await waitFor(() => {
expect(screen.getByText(/message sent/i)).toBeInTheDocument();
});
});
```
## Async Operations
### Handling Async State Updates
**Rule**: After triggering state changes, always wait for UI updates before asserting.
#### Pattern 1: findBy Queries (Simplest)
```typescript
// Element appears after async operation
await user.click(createButton);
expect(await screen.findByRole("textbox")).toBeInTheDocument();
```
#### Pattern 2: waitFor (Complex Assertions)
```typescript
await user.click(submitButton);
await waitFor(() => {
expect(screen.getByText("Success")).toBeInTheDocument();
expect(screen.getByText("Count: 5")).toBeInTheDocument();
});
```
#### Pattern 3: waitForElementToBeRemoved
```typescript
await user.click(deleteButton);
await waitForElementToBeRemoved(() => screen.queryByText(/item name/i));
```
### Common Async Mistakes
```typescript
// ❌ Wrong - getBy immediately after state change
await user.click(button);
expect(screen.getByText("Updated")).toBeInTheDocument(); // May fail!
// ✅ Correct - Wait for state update
await user.click(button);
expect(await screen.findByText("Updated")).toBeInTheDocument();
// ❌ Wrong - Multiple getBy calls without waiting
await user.click(button);
expect(screen.getByText("Success")).toBeInTheDocument();
expect(screen.getByText("Data loaded")).toBeInTheDocument();
// ✅ Correct - Single waitFor with multiple assertions
await user.click(button);
await waitFor(() => {
expect(screen.getByText("Success")).toBeInTheDocument();
expect(screen.getByText("Data loaded")).toBeInTheDocument();
});
```
## Mocking
### Mocking fetch API
**IMPORTANT**: Always use a comment to document which endpoint each mock corresponds to.
```typescript
let fetchSpy: jest.SpyInstance;
beforeEach(() => {
fetchSpy = jest.spyOn(global, "fetch");
});
afterEach(() => {
fetchSpy.mockRestore();
});
test("fetches data successfully", async () => {
// Mock GET /api/data
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ data: [1, 2, 3] }),
} as Response);
render(<MyComponent />);
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/data");
});
});
```
**Why comment the endpoint?** Sequential mocks can be confusing. Comments make it clear which API call each mock corresponds to, making tests easier to understand and maintain.
### Multiple API Calls
**Pattern**: Document each endpoint with a comment, then verify it was called correctly.
```typescript
test("handles multiple API calls", async () => {
const user = setupUser();
// Mock GET /api/items
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ items: [] }),
} as Response);
// Mock POST /api/items
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ id: 1, name: "New Item" }),
} as Response);
render(<MyComponent />);
// Verify GET was called
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith("/api/items");
});
await user.click(screen.getByRole("button", { name: /create/i }));
// Verify POST was called
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith(
"/api/items",
expect.objectContaining({ method: "POST" })
);
});
});
```
**Three API calls example:**
```typescript
test("test, create, and set as default", async () => {
const user = setupUser();
// Mock POST /api/llm/test
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
// Mock PUT /api/llm/provider?is_creation=true
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ id: 5, name: "New Provider" }),
} as Response);
// Mock POST /api/llm/provider/5/default
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
render(<MyForm />);
await user.type(screen.getByLabelText(/name/i), "New Provider");
await user.click(screen.getByRole("button", { name: /create/i }));
// Verify all three endpoints were called
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalledWith(
"/api/llm/test",
expect.objectContaining({ method: "POST" })
);
expect(fetchSpy).toHaveBeenCalledWith(
"/api/llm/provider",
expect.objectContaining({ method: "PUT" })
);
expect(fetchSpy).toHaveBeenCalledWith(
"/api/llm/provider/5/default",
expect.objectContaining({ method: "POST" })
);
});
});
```
### Verifying Request Body
```typescript
test("sends correct data", async () => {
const user = setupUser();
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
render(<MyForm />);
await user.type(screen.getByLabelText(/name/i), "Test");
await user.click(screen.getByRole("button", { name: /submit/i }));
await waitFor(() => {
expect(fetchSpy).toHaveBeenCalled();
});
const callArgs = fetchSpy.mock.calls[0];
const requestBody = JSON.parse(callArgs[1].body);
expect(requestBody).toEqual({
name: "Test",
active: true,
});
});
```
### Mocking Errors
```typescript
test("displays error message on failure", async () => {
// Mock GET /api/data (network error)
fetchSpy.mockRejectedValueOnce(new Error("Network error"));
render(<MyComponent />);
await waitFor(() => {
expect(screen.getByText(/failed to load/i)).toBeInTheDocument();
});
});
test("handles API error response", async () => {
// Mock POST /api/items (server error)
fetchSpy.mockResolvedValueOnce({
ok: false,
status: 500,
} as Response);
render(<MyComponent />);
await waitFor(() => {
expect(screen.getByText(/something went wrong/i)).toBeInTheDocument();
});
});
```
### Mocking Next.js Router
```typescript
// At top of test file
jest.mock("next/navigation", () => ({
useRouter: () => ({
push: jest.fn(),
back: jest.fn(),
refresh: jest.fn(),
}),
usePathname: () => "/current-path",
}));
```
## Common Patterns
### Testing CRUD Operations
```typescript
describe("User Management", () => {
test("creates new user", async () => {
const user = setupUser();
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ id: 1, name: "New User" }),
} as Response);
render(<UserForm />);
await user.type(screen.getByLabelText(/name/i), "New User");
await user.click(screen.getByRole("button", { name: /create/i }));
await waitFor(() => {
expect(screen.getByText(/user created/i)).toBeInTheDocument();
});
});
test("edits existing user", async () => {
const user = setupUser();
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({ id: 1, name: "Updated User" }),
} as Response);
render(<UserForm initialData={{ id: 1, name: "Old Name" }} />);
await user.clear(screen.getByLabelText(/name/i));
await user.type(screen.getByLabelText(/name/i), "Updated User");
await user.click(screen.getByRole("button", { name: /save/i }));
await waitFor(() => {
expect(screen.getByText(/user updated/i)).toBeInTheDocument();
});
});
test("deletes user", async () => {
const user = setupUser();
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({}),
} as Response);
render(<UserList />);
await waitFor(() => {
expect(screen.getByText("John Doe")).toBeInTheDocument();
});
await user.click(screen.getByRole("button", { name: /delete/i }));
await waitFor(() => {
expect(screen.queryByText("John Doe")).not.toBeInTheDocument();
});
});
});
```
### Testing Conditional Rendering
```typescript
test("shows edit form when edit button clicked", async () => {
const user = setupUser();
render(<MyComponent />);
expect(screen.queryByRole("textbox")).not.toBeInTheDocument();
await user.click(screen.getByRole("button", { name: /edit/i }));
expect(await screen.findByRole("textbox")).toBeInTheDocument();
});
test("toggles between states", async () => {
const user = setupUser();
render(<Toggle />);
const button = screen.getByRole("button", { name: /show details/i });
await user.click(button);
expect(await screen.findByText(/details content/i)).toBeInTheDocument();
await user.click(button);
expect(screen.queryByText(/details content/i)).not.toBeInTheDocument();
});
```
### Testing Lists and Tables
```typescript
test("displays list of items", async () => {
fetchSpy.mockResolvedValueOnce({
ok: true,
json: async () => ({
items: [
{ id: 1, name: "Item 1" },
{ id: 2, name: "Item 2" },
{ id: 3, name: "Item 3" },
],
}),
} as Response);
render(<ItemList />);
await waitFor(() => {
expect(screen.getByText("Item 1")).toBeInTheDocument();
expect(screen.getByText("Item 2")).toBeInTheDocument();
expect(screen.getByText("Item 3")).toBeInTheDocument();
});
});
test("filters items", async () => {
const user = setupUser();
render(<FilterableList items={mockItems} />);
await user.type(screen.getByRole("searchbox"), "specific");
await waitFor(() => {
expect(screen.getByText("Specific Item")).toBeInTheDocument();
expect(screen.queryByText("Other Item")).not.toBeInTheDocument();
});
});
```
### Testing Validation
```typescript
test("shows validation errors", async () => {
const user = setupUser();
render(<LoginForm />);
await user.click(screen.getByRole("button", { name: /submit/i }));
await waitFor(() => {
expect(screen.getByText(/email is required/i)).toBeInTheDocument();
expect(screen.getByText(/password is required/i)).toBeInTheDocument();
});
});
test("clears validation on valid input", async () => {
const user = setupUser();
render(<LoginForm />);
await user.click(screen.getByRole("button", { name: /submit/i }));
await waitFor(() => {
expect(screen.getByText(/email is required/i)).toBeInTheDocument();
});
await user.type(screen.getByLabelText(/email/i), "valid@email.com");
await waitFor(() => {
expect(screen.queryByText(/email is required/i)).not.toBeInTheDocument();
});
});
```
## Testing Philosophy
### What to Test
**✅ Test user-visible behavior:**
- Forms can be filled and submitted
- Buttons trigger expected actions
- Success/error messages appear
- Navigation works correctly
- Data is displayed after loading
- Validation errors show and clear appropriately
**✅ Test integration points:**
- API calls are made with correct parameters
- Responses are handled properly
- Error states are handled
- Loading states appear
**❌ Don't test implementation details:**
- Internal state values
- Component lifecycle methods
- CSS class names
- Specific React hooks being used
### Test Naming
Write test names that describe user behavior:
```typescript
// ✅ Good - Describes what user can do
test("user can create new prompt", async () => {})
test("shows error when API call fails", async () => {})
test("filters items by search term", async () => {})
// ❌ Bad - Implementation-focused
test("handleSubmit is called", async () => {})
test("state updates correctly", async () => {})
test("renders without crashing", async () => {})
```
### Minimal Mocking
Only mock external dependencies:
```typescript
// ✅ Mock external APIs
jest.spyOn(global, "fetch")
// ✅ Mock Next.js router
jest.mock("next/navigation")
// ✅ Mock problematic packages
// (configured in tests/setup/__mocks__)
// ❌ Don't mock application code
// ❌ Don't mock component internals
// ❌ Don't mock utility functions
```
## Troubleshooting
### "Not wrapped in act()" Warning
**Solution**: Always use `setupUser()` instead of `userEvent.setup()`
```typescript
// ✅ Correct
const user = setupUser();
// ❌ Wrong
const user = userEvent.setup();
```
### "Unable to find element" Error
**Solution**: Element hasn't appeared yet, use `findBy` or `waitFor`
```typescript
// ❌ Wrong - getBy doesn't wait
await user.click(button);
expect(screen.getByText("Success")).toBeInTheDocument();
// ✅ Correct - findBy waits
await user.click(button);
expect(await screen.findByText("Success")).toBeInTheDocument();
```
### "Multiple elements found" Error
**Solution**: Be more specific with your query
```typescript
// ❌ Too broad
screen.getByRole("button")
// ✅ Specific
screen.getByRole("button", { name: /submit/i })
```
### Test Times Out
**Causes**:
1. Async operation never completes
2. Waiting for element that never appears
3. Missing mock for API call
**Solutions**:
```typescript
// Check fetch is mocked
expect(fetchSpy).toHaveBeenCalled()
// Use queryBy to check if element exists
expect(screen.queryByText("Text")).toBeInTheDocument()
// Verify mock is set up before render
fetchSpy.mockResolvedValueOnce(...)
render(<Component />)
```
## Examples
See comprehensive test examples:
- `src/app/auth/login/EmailPasswordForm.test.tsx` - Login/signup workflows, validation
- `src/app/chat/input-prompts/InputPrompts.test.tsx` - CRUD operations, conditional rendering
- `src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.test.tsx` - Complex forms, multi-step workflows
## Built-in Mocks
Only essential mocks in `tests/setup/__mocks__/`:
- `UserProvider` - Removes auth requirement for tests
- `react-markdown` / `remark-gfm` - ESM compatibility
See `tests/setup/__mocks__/README.md` for details.

View File

@@ -0,0 +1,66 @@
/**
* Mock for @/components/user/UserProvider
*
* Why this mock exists:
* The real UserProvider requires complex props (authTypeMetadata, settings, user)
* that are not relevant for most component integration tests. This mock provides
* a simple useUser() hook with safe default values.
*
* Usage:
* Automatically applied via jest.config.js moduleNameMapper.
* Any component that imports from "@/components/user/UserProvider" will get this mock.
*
* To customize user values in a specific test:
* You would need to either:
* 1. Pass props to the real UserProvider (requires disabling this mock for that test)
* 2. Extend this mock to accept custom values via a setup function
*/
import React, { createContext, useContext } from "react";
// Shape of the mocked user context. Fields mirror the real UserProvider's
// context; `any` is used deliberately where the concrete type is irrelevant
// to component integration tests.
interface UserContextType {
  user: any;
  isAdmin: boolean;
  isCurator: boolean;
  refreshUser: () => Promise<void>;
  isCloudSuperuser: boolean;
  updateUserAutoScroll: (autoScroll: boolean) => Promise<void>;
  updateUserShortcuts: (enabled: boolean) => Promise<void>;
  toggleAssistantPinnedStatus: (
    currentPinnedAssistantIDs: number[],
    assistantId: number,
    isPinned: boolean
  ) => Promise<boolean>;
  updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise<void>;
  updateUserPersonalization: (personalization: any) => Promise<void>;
}
// Safe defaults for tests: anonymous, non-privileged user; every mutator is
// a resolved no-op. toggleAssistantPinnedStatus resolves true ("success").
const mockUserContext: UserContextType = {
  user: null,
  isAdmin: false,
  isCurator: false,
  refreshUser: async () => {},
  isCloudSuperuser: false,
  updateUserAutoScroll: async () => {},
  updateUserShortcuts: async () => {},
  toggleAssistantPinnedStatus: async () => true,
  updateUserTemperatureOverrideEnabled: async () => {},
  updateUserPersonalization: async () => {},
};

// NOTE: the default value is the mock itself (not undefined), so useUser()
// resolves even when no <UserProvider> wraps the component under test.
const UserContext = createContext<UserContextType | undefined>(mockUserContext);
export function useUser() {
const context = useContext(UserContext);
if (context === undefined) {
throw new Error("useUser must be used within a UserProvider");
}
return context;
}
export function UserProvider({ children }: { children: React.ReactNode }) {
return (
<UserContext.Provider value={mockUserContext}>
{children}
</UserContext.Provider>
);
}

View File

@@ -0,0 +1,92 @@
# Test Mocks
This directory contains Jest mocks for dependencies that are difficult to use in tests in their original form.
## Why These Mocks Exist
### `@/components/user/UserProvider.tsx`
**Problem:** The real `UserProvider` requires complex setup with props like:
- `authTypeMetadata` (auth configuration)
- `settings` (CombinedSettings object)
- `user` (User object)
**Solution:** This mock provides a simple `useUser()` hook that returns safe default values, allowing components that depend on user context to render in tests without extensive setup.
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper. Components using `useUser()` will get the mock instead of the real provider.
**Mock values:**
```typescript
{
user: null,
isAdmin: false,
isCurator: false,
isCloudSuperuser: false,
refreshUser: async () => {},
updateUserAutoScroll: async () => {},
updateUserShortcuts: async () => {},
toggleAssistantPinnedStatus: async () => true,
updateUserTemperatureOverrideEnabled: async () => {},
updateUserPersonalization: async () => {},
}
```
### `react-markdown.tsx`
**Problem:** The `react-markdown` library uses ESM (ECMAScript Modules) which Jest cannot parse by default without extensive configuration. Many components (like `Field.tsx`) import this library.
**Solution:** This mock provides a simple component that renders markdown content as plain text, avoiding ESM parsing issues.
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper.
**Limitation:** Markdown is not actually rendered/parsed in tests - content is displayed as-is. If you need to test markdown rendering, you'll need to configure Jest to handle ESM properly.
### `remark-gfm.ts`
**Problem:** The `remark-gfm` library (GitHub Flavored Markdown plugin for react-markdown) also uses ESM.
**Solution:** This mock provides a no-op plugin function that does nothing but allows components to import it without errors.
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper.
## When to Add New Mocks
Add mocks to this directory when:
1. **ESM compatibility issues** - Library uses `export/import` syntax that Jest can't parse
2. **Complex setup requirements** - Component/provider requires extensive configuration that's not relevant to your test
3. **External dependencies** - Library makes network calls, accesses browser APIs not available in tests, etc.
## When NOT to Mock
Avoid mocking when:
1. **Testing the actual behavior** - If you're testing how markdown renders, don't mock react-markdown
2. **Simple to configure** - If it's easy to provide real props, use the real component
3. **Core business logic** - Don't mock your own application logic
## Configuration
These mocks are configured in `jest.config.js`:
```javascript
moduleNameMapper: {
// Mock react-markdown and related packages
"^react-markdown$": "<rootDir>/tests/setup/__mocks__/react-markdown.tsx",
"^remark-gfm$": "<rootDir>/tests/setup/__mocks__/remark-gfm.ts",
// Mock UserProvider
"^@/components/user/UserProvider$": "<rootDir>/tests/setup/__mocks__/@/components/user/UserProvider.tsx",
// ... other mappings
}
```
**Important:** Specific mocks must come BEFORE the generic `@/` path alias, otherwise the path alias will match first and the mock won't be applied.
## Debugging Mock Issues
If a component isn't getting the mock:
1. Check that the import path exactly matches the moduleNameMapper pattern
2. Clear Jest cache: `npx jest --clearCache`
3. Check that the mock file path is correct relative to `<rootDir>`
4. Verify the mock comes BEFORE generic path aliases in moduleNameMapper

View File

@@ -0,0 +1,22 @@
/**
 * Mock for react-markdown
 *
 * react-markdown is published as ESM, which Jest's default transform cannot
 * parse; components such as Field.tsx import it, so without this stub those
 * test files fail to load. jest.config.js moduleNameMapper swaps the real
 * package for this file automatically.
 *
 * Limitation: markdown is NOT parsed or rendered here — the raw source string
 * is emitted as plain text. Configure Jest for ESM if real rendering is ever
 * needed in a test.
 */
import React from "react";

/** Renders the raw markdown source as plain text inside a <div>. */
function ReactMarkdown({ children }: { children: string }) {
  return <div>{children}</div>;
}

export default ReactMarkdown;

View File

@@ -0,0 +1,18 @@
/**
 * Mock for remark-gfm
 *
 * remark-gfm (the GitHub Flavored Markdown plugin pulled in alongside
 * react-markdown) ships as ESM, which Jest cannot parse by default, so
 * jest.config.js moduleNameMapper substitutes this stub.
 *
 * Limitation: GFM features (tables, strikethrough, etc.) are NOT processed
 * in tests — the plugin below is deliberately inert.
 */

// A remark plugin is a function that returns a transformer. Ours returns a
// transformer that does nothing, which is enough for imports to succeed.
export default function remarkGfm() {
  function inertTransformer() {}
  return inertTransformer;
}

View File

@@ -0,0 +1 @@
// Jest stub for static asset imports (images, fonts, media, …).
// Any such file matched by jest.config moduleNameMapper resolves to this
// fixed string instead of the binary content.
const FILE_STUB = "test-file-stub";

module.exports = FILE_STUB;

View File

@@ -0,0 +1,82 @@
// Global Jest setup: installs polyfills and browser-API mocks before each
// test file runs. Order matters here — polyfills are assigned before any
// test code can touch the globals they back.
import "@testing-library/jest-dom";
import { TextEncoder, TextDecoder } from "util";
// Polyfill TextEncoder/TextDecoder (required for some libraries)
global.TextEncoder = TextEncoder;
global.TextDecoder = TextDecoder as any;
// Only set up browser-specific mocks if we're in a jsdom environment
// (this same setup file may also run under the node environment, where
// `window` is undefined and none of these globals are needed).
if (typeof window !== "undefined") {
  // Polyfill fetch for jsdom
  // NOTE(review): this dynamic import is fire-and-forget — the fetch
  // polyfill may be installed after early test code runs. Confirm the Jest
  // runner awaits setup-file module evaluation, or that no test relies on
  // fetch existing synchronously at import time.
  // @ts-ignore
  import("whatwg-fetch");
  // Mock BroadcastChannel for JSDOM
  // (jsdom does not implement it; every method is a no-op).
  global.BroadcastChannel = class BroadcastChannel {
    constructor(public name: string) {}
    postMessage() {}
    close() {}
    addEventListener() {}
    removeEventListener() {}
    dispatchEvent() {
      return true;
    }
  } as any;
  // Mock window.matchMedia for responsive components
  // (always reports `matches: false`, i.e. no media query ever matches).
  Object.defineProperty(window, "matchMedia", {
    writable: true,
    value: jest.fn().mockImplementation((query) => ({
      matches: false,
      media: query,
      onchange: null,
      addListener: jest.fn(), // deprecated
      removeListener: jest.fn(), // deprecated
      addEventListener: jest.fn(),
      removeEventListener: jest.fn(),
      dispatchEvent: jest.fn(),
    })),
  });
  // Mock IntersectionObserver
  // (no-op: observers never fire, takeRecords always returns an empty list).
  global.IntersectionObserver = class IntersectionObserver {
    constructor() {}
    disconnect() {}
    observe() {}
    takeRecords() {
      return [];
    }
    unobserve() {}
  } as any;
  // Mock ResizeObserver
  // (no-op: resize callbacks are never invoked in tests).
  global.ResizeObserver = class ResizeObserver {
    constructor() {}
    disconnect() {}
    observe() {}
    unobserve() {}
  } as any;
  // Mock window.scrollTo
  // (jsdom throws "Not implemented" for scrollTo; a jest.fn() silences that
  // and lets tests assert on scroll calls if needed).
  global.scrollTo = jest.fn();
}
// Suppress console errors in tests (optional - comment out if you want to see them)
// const originalError = console.error;
// beforeAll(() => {
//   console.error = (...args: any[]) => {
//     // Filter out known React warnings that are not actionable in tests
//     if (
//       typeof args[0] === "string" &&
//       (args[0].includes("Warning: ReactDOM.render") ||
//         args[0].includes("Not implemented: HTMLFormElement.prototype.submit"))
//     ) {
//       return;
//     }
//     originalError.call(console, ...args);
//   };
// });
// afterAll(() => {
//   console.error = originalError;
// });

View File

@@ -0,0 +1,119 @@
import React, { ReactElement } from "react";
import { render, RenderOptions } from "@testing-library/react";
import userEvent from "@testing-library/user-event";
import { SWRConfig } from "swr";
/**
* Custom render function that wraps components with common providers
* used throughout the Onyx application.
*/
interface AllProvidersProps {
  children: React.ReactNode;
  swrConfig?: Record<string, any>;
}

/**
 * Test-only provider stack. Wraps children in an SWRConfig tuned for fast,
 * isolated tests; extend this component as more global providers are
 * discovered in the app.
 */
function AllTheProviders({ children, swrConfig = {} }: AllProvidersProps) {
  // Base test config: no request deduping, a fresh Map-backed cache so SWR
  // state never leaks between tests, and no retry-on-error so failures
  // surface immediately. Caller overrides are spread last so they win.
  const testSwrValue = {
    dedupingInterval: 0,
    provider: () => new Map(),
    shouldRetryOnError: false,
    ...swrConfig,
  };

  return <SWRConfig value={testSwrValue}>{children}</SWRConfig>;
}
interface CustomRenderOptions extends Omit<RenderOptions, "wrapper"> {
  swrConfig?: Record<string, any>;
}

/**
 * Drop-in replacement for @testing-library/react's `render` that mounts the
 * component inside AllTheProviders. Import this instead of the library's
 * render in tests.
 *
 * @example
 * import { render, screen } from '@tests/setup/test-utils';
 *
 * test('renders component', () => {
 *   render(<MyComponent />);
 *   expect(screen.getByText('Hello')).toBeInTheDocument();
 * });
 *
 * @example
 * // With custom SWR config to mock API responses
 * render(<MyComponent />, {
 *   swrConfig: {
 *     fallback: {
 *       '/api/credentials': mockCredentials,
 *     },
 *   },
 * });
 */
const customRender = (
  ui: ReactElement,
  options: CustomRenderOptions = {}
) => {
  const { swrConfig, ...renderOptions } = options;
  // Build a fresh wrapper per call so each render gets its own provider
  // instances (and its own SWR cache via AllTheProviders).
  return render(ui, {
    wrapper: ({ children }: { children: React.ReactNode }) => (
      <AllTheProviders swrConfig={swrConfig}>{children}</AllTheProviders>
    ),
    ...renderOptions,
  });
};

// Re-export everything from @testing-library/react
export * from "@testing-library/react";
export { userEvent };

// Override render with our custom render
export { customRender as render };
/**
 * Build a userEvent instance whose every interaction is automatically wrapped
 * in act(), so call sites never have to import act() or wrap interactions
 * themselves. Use this instead of calling userEvent.setup() directly.
 *
 * @example
 * const user = setupUser();
 * await user.click(button);
 * await user.type(input, "text");
 */
export function setupUser(options = {}) {
  // delay: null types instantly, which lets React 18 batch the resulting
  // state updates and keeps act() warnings down.
  const baseUser = userEvent.setup({
    delay: null,
    ...options,
  });

  // Intercept property access on the userEvent object: plain values pass
  // through untouched, while methods are replaced with async wrappers that
  // execute the original call inside act().
  const actWrappingHandler: ProxyHandler<typeof baseUser> = {
    get(target, prop) {
      const member = Reflect.get(target, prop);
      if (typeof member !== "function") {
        return member;
      }
      return async (...args: any[]) => {
        const { act } = await import("@testing-library/react");
        return act(async () => (member as Function).apply(target, args));
      };
    },
  };

  return new Proxy(baseUser, actWrappingHandler);
}