mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-21 23:52:43 +00:00
Compare commits
1 Commits
bo/query_p
...
v2.1.0-bet
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
53494324aa |
@@ -109,6 +109,15 @@ jobs:
|
||||
# Needed for trivyignore
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
|
||||
@@ -120,6 +120,15 @@ jobs:
|
||||
if: needs.precheck.outputs.should-run == 'true'
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Check if stable release version
|
||||
id: check_version
|
||||
run: |
|
||||
if [[ "${{ github.ref_name }}" =~ ^v[0-9]+\.[0-9]+\.[0-9]+$ ]] && [[ "${{ github.ref_name }}" != *"cloud"* ]]; then
|
||||
echo "is_stable=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "is_stable=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Download digests
|
||||
uses: actions/download-artifact@v4
|
||||
with:
|
||||
|
||||
4
.github/workflows/docker-tag-latest.yml
vendored
4
.github/workflows/docker-tag-latest.yml
vendored
@@ -35,3 +35,7 @@ jobs:
|
||||
- name: Pull, Tag and Push API Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-backend:latest onyxdotapp/onyx-backend:${{ github.event.inputs.version }}
|
||||
|
||||
- name: Pull, Tag and Push Model Server Image
|
||||
run: |
|
||||
docker buildx imagetools create -t onyxdotapp/onyx-model-server:latest onyxdotapp/onyx-model-server:${{ github.event.inputs.version }}
|
||||
|
||||
35
.github/workflows/pr-jest-tests.yml
vendored
Normal file
35
.github/workflows/pr-jest-tests.yml
vendored
Normal file
@@ -0,0 +1,35 @@
|
||||
name: Run Jest Tests
|
||||
concurrency:
|
||||
group: Run-Jest-Tests-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
|
||||
cancel-in-progress: true
|
||||
|
||||
on: push
|
||||
|
||||
jobs:
|
||||
jest-tests:
|
||||
name: Jest Tests
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Setup node
|
||||
uses: actions/setup-node@v4
|
||||
with:
|
||||
node-version: 22
|
||||
|
||||
- name: Install node dependencies
|
||||
working-directory: ./web
|
||||
run: npm ci
|
||||
|
||||
- name: Run Jest tests
|
||||
working-directory: ./web
|
||||
run: npm test -- --ci --coverage --maxWorkers=50%
|
||||
|
||||
- name: Upload coverage reports
|
||||
if: always()
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
name: jest-coverage-${{ github.run_id }}
|
||||
path: ./web/coverage
|
||||
retention-days: 7
|
||||
@@ -50,6 +50,25 @@ def get_empty_chat_messages_entries__paginated(
|
||||
if message.message_type != MessageType.USER:
|
||||
continue
|
||||
|
||||
# Get user email
|
||||
user_email = chat_session.user.email if chat_session.user else None
|
||||
|
||||
# Get assistant name (from session persona, or alternate if specified)
|
||||
assistant_name = None
|
||||
if message.alternate_assistant_id:
|
||||
# If there's an alternate assistant, we need to fetch it
|
||||
from onyx.db.models import Persona
|
||||
|
||||
alternate_persona = (
|
||||
db_session.query(Persona)
|
||||
.filter(Persona.id == message.alternate_assistant_id)
|
||||
.first()
|
||||
)
|
||||
if alternate_persona:
|
||||
assistant_name = alternate_persona.name
|
||||
elif chat_session.persona:
|
||||
assistant_name = chat_session.persona.name
|
||||
|
||||
message_skeletons.append(
|
||||
ChatMessageSkeleton(
|
||||
message_id=message.id,
|
||||
@@ -57,6 +76,9 @@ def get_empty_chat_messages_entries__paginated(
|
||||
user_id=str(chat_session.user_id) if chat_session.user_id else None,
|
||||
flow_type=flow_type,
|
||||
time_sent=message.time_sent,
|
||||
assistant_name=assistant_name,
|
||||
user_email=user_email,
|
||||
number_of_tokens=message.token_count,
|
||||
)
|
||||
)
|
||||
if len(chat_sessions) == 0:
|
||||
|
||||
@@ -48,7 +48,17 @@ def generate_chat_messages_report(
|
||||
max_size=MAX_IN_MEMORY_SIZE, mode="w+"
|
||||
) as temp_file:
|
||||
csvwriter = csv.writer(temp_file, delimiter=",")
|
||||
csvwriter.writerow(["session_id", "user_id", "flow_type", "time_sent"])
|
||||
csvwriter.writerow(
|
||||
[
|
||||
"session_id",
|
||||
"user_id",
|
||||
"flow_type",
|
||||
"time_sent",
|
||||
"assistant_name",
|
||||
"user_email",
|
||||
"number_of_tokens",
|
||||
]
|
||||
)
|
||||
for chat_message_skeleton_batch in get_all_empty_chat_message_entries(
|
||||
db_session, period
|
||||
):
|
||||
@@ -59,6 +69,9 @@ def generate_chat_messages_report(
|
||||
chat_message_skeleton.user_id,
|
||||
chat_message_skeleton.flow_type,
|
||||
chat_message_skeleton.time_sent.isoformat(),
|
||||
chat_message_skeleton.assistant_name,
|
||||
chat_message_skeleton.user_email,
|
||||
chat_message_skeleton.number_of_tokens,
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
@@ -16,6 +16,9 @@ class ChatMessageSkeleton(BaseModel):
|
||||
user_id: str | None
|
||||
flow_type: FlowType
|
||||
time_sent: datetime
|
||||
assistant_name: str | None
|
||||
user_email: str | None
|
||||
number_of_tokens: int
|
||||
|
||||
|
||||
class UserSkeleton(BaseModel):
|
||||
|
||||
@@ -37,6 +37,8 @@ from onyx.configs.model_configs import LITELLM_EXTRA_BODY
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.llm.interfaces import ToolChoiceOptions
|
||||
from onyx.llm.llm_provider_options import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.llm_provider_options import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.utils import model_is_reasoning_model
|
||||
from onyx.server.utils import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -49,8 +51,6 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
_LLM_PROMPT_LONG_TERM_LOG_CATEGORY = "llm_prompt"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
LEGACY_MAX_TOKENS_KWARG = "max_tokens"
|
||||
STANDARD_MAX_TOKENS_KWARG = "max_completion_tokens"
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.llm_provider_options import OLLAMA_API_KEY_CONFIG_KEY
|
||||
from onyx.llm.llm_provider_options import OLLAMA_PROVIDER_NAME
|
||||
from onyx.llm.llm_provider_options import OPENROUTER_PROVIDER_NAME
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import get_max_input_tokens_from_llm_provider
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
@@ -26,19 +27,24 @@ logger = setup_logger()
|
||||
def _build_provider_extra_headers(
|
||||
provider: str, custom_config: dict[str, str] | None
|
||||
) -> dict[str, str]:
|
||||
if provider != OLLAMA_PROVIDER_NAME or not custom_config:
|
||||
return {}
|
||||
# Ollama Cloud: allow passing Bearer token via custom config for cloud instances
|
||||
if provider == OLLAMA_PROVIDER_NAME and custom_config:
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if not api_key:
|
||||
return {}
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
return {"Authorization": api_key}
|
||||
|
||||
raw_api_key = custom_config.get(OLLAMA_API_KEY_CONFIG_KEY)
|
||||
# Passing these will put Onyx on the OpenRouter leaderboard
|
||||
elif provider == OPENROUTER_PROVIDER_NAME:
|
||||
return {
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
|
||||
api_key = raw_api_key.strip() if raw_api_key else None
|
||||
if not api_key:
|
||||
return {}
|
||||
|
||||
if not api_key.lower().startswith("bearer "):
|
||||
api_key = f"Bearer {api_key}"
|
||||
|
||||
return {"Authorization": api_key}
|
||||
return {}
|
||||
|
||||
|
||||
def get_main_llm_from_tuple(
|
||||
|
||||
@@ -2,8 +2,6 @@ from enum import Enum
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.llm.chat_llm import VERTEX_CREDENTIALS_FILE_KWARG
|
||||
from onyx.llm.chat_llm import VERTEX_LOCATION_KWARG
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import ModelConfigurationView
|
||||
|
||||
@@ -137,18 +135,8 @@ BEDROCK_REGION_OPTIONS = _build_bedrock_region_options()
|
||||
OLLAMA_PROVIDER_NAME = "ollama"
|
||||
OLLAMA_API_KEY_CONFIG_KEY = "OLLAMA_API_KEY"
|
||||
|
||||
|
||||
def get_bedrock_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
|
||||
# litellm has split them into two lists :(
|
||||
return [
|
||||
model
|
||||
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
|
||||
# OpenRouter
|
||||
OPENROUTER_PROVIDER_NAME = "openrouter"
|
||||
|
||||
IGNORABLE_ANTHROPIC_MODELS = [
|
||||
"claude-2",
|
||||
@@ -157,17 +145,6 @@ IGNORABLE_ANTHROPIC_MODELS = [
|
||||
]
|
||||
ANTHROPIC_PROVIDER_NAME = "anthropic"
|
||||
|
||||
|
||||
def get_anthropic_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
return [
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
|
||||
|
||||
ANTHROPIC_VISIBLE_MODEL_NAMES = [
|
||||
"claude-sonnet-4-5-20250929",
|
||||
"claude-sonnet-4-20250514",
|
||||
@@ -177,6 +154,8 @@ AZURE_PROVIDER_NAME = "azure"
|
||||
|
||||
|
||||
VERTEXAI_PROVIDER_NAME = "vertex_ai"
|
||||
VERTEX_CREDENTIALS_FILE_KWARG = "vertex_credentials"
|
||||
VERTEX_LOCATION_KWARG = "vertex_location"
|
||||
VERTEXAI_DEFAULT_MODEL = "gemini-2.0-flash"
|
||||
VERTEXAI_DEFAULT_FAST_MODEL = "gemini-2.0-flash-lite"
|
||||
VERTEXAI_MODEL_NAMES = [
|
||||
@@ -223,15 +202,39 @@ def _get_provider_to_models_map() -> dict[str, list[str]]:
|
||||
ANTHROPIC_PROVIDER_NAME: get_anthropic_model_names(),
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
OPENROUTER_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
|
||||
def get_bedrock_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
# bedrock_converse_models are just extensions of the bedrock_models, not sure why
|
||||
# litellm has split them into two lists :(
|
||||
return [
|
||||
model
|
||||
for model in list(litellm.bedrock_models.union(litellm.bedrock_converse_models))
|
||||
if "/" not in model and "embed" not in model
|
||||
][::-1]
|
||||
|
||||
|
||||
def get_anthropic_model_names() -> list[str]:
|
||||
import litellm
|
||||
|
||||
return [
|
||||
model
|
||||
for model in litellm.anthropic_models
|
||||
if model not in IGNORABLE_ANTHROPIC_MODELS
|
||||
][::-1]
|
||||
|
||||
|
||||
_PROVIDER_TO_VISIBLE_MODELS_MAP = {
|
||||
OPENAI_PROVIDER_NAME: OPEN_AI_VISIBLE_MODEL_NAMES,
|
||||
BEDROCK_PROVIDER_NAME: [],
|
||||
ANTHROPIC_PROVIDER_NAME: ANTHROPIC_VISIBLE_MODEL_NAMES,
|
||||
VERTEXAI_PROVIDER_NAME: VERTEXAI_VISIBLE_MODEL_NAMES,
|
||||
OLLAMA_PROVIDER_NAME: [],
|
||||
OPENROUTER_PROVIDER_NAME: [],
|
||||
}
|
||||
|
||||
|
||||
@@ -372,6 +375,20 @@ def fetch_available_well_known_llms() -> list[WellKnownLLMProviderDescriptor]:
|
||||
default_model=VERTEXAI_DEFAULT_MODEL,
|
||||
default_fast_model=VERTEXAI_DEFAULT_MODEL,
|
||||
),
|
||||
WellKnownLLMProviderDescriptor(
|
||||
name=OPENROUTER_PROVIDER_NAME,
|
||||
display_name="OpenRouter",
|
||||
api_key_required=True,
|
||||
api_base_required=True,
|
||||
api_version_required=False,
|
||||
custom_config_keys=[],
|
||||
model_configurations=fetch_model_configurations_for_provider(
|
||||
OPENROUTER_PROVIDER_NAME
|
||||
),
|
||||
default_model=None,
|
||||
default_fast_model=None,
|
||||
default_api_base="https://openrouter.ai/api/v1",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -644,41 +644,37 @@ def get_max_input_tokens_from_llm_provider(
|
||||
|
||||
|
||||
def model_supports_image_input(model_name: str, model_provider: str) -> bool:
|
||||
# TODO: Add support to check model config for any provider
|
||||
# TODO: Circular import means OLLAMA_PROVIDER_NAME is not available here
|
||||
|
||||
if model_provider == "ollama":
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.join(
|
||||
LLMProvider,
|
||||
ModelConfiguration.llm_provider_id == LLMProvider.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.name == model_name,
|
||||
LLMProvider.provider == model_provider,
|
||||
)
|
||||
)
|
||||
if model_config and model_config.supports_image_input is not None:
|
||||
return model_config.supports_image_input
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
|
||||
)
|
||||
|
||||
model_map = get_model_map()
|
||||
# First, try to read an explicit configuration from the model_configuration table
|
||||
try:
|
||||
model_obj = find_model_obj(
|
||||
model_map,
|
||||
model_provider,
|
||||
model_name,
|
||||
)
|
||||
if not model_obj:
|
||||
raise RuntimeError(
|
||||
f"No litellm entry found for {model_provider}/{model_name}"
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
model_config = db_session.scalar(
|
||||
select(ModelConfiguration)
|
||||
.join(
|
||||
LLMProvider,
|
||||
ModelConfiguration.llm_provider_id == LLMProvider.id,
|
||||
)
|
||||
.where(
|
||||
ModelConfiguration.name == model_name,
|
||||
LLMProvider.provider == model_provider,
|
||||
)
|
||||
)
|
||||
if model_config and model_config.supports_image_input is not None:
|
||||
return model_config.supports_image_input
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to query database for {model_provider} model {model_name} image support: {e}"
|
||||
)
|
||||
|
||||
# Fallback to looking up the model in the litellm model_cost dict
|
||||
try:
|
||||
model_obj = find_model_obj(get_model_map(), model_provider, model_name)
|
||||
if not model_obj:
|
||||
logger.warning(
|
||||
f"No litellm entry found for {model_provider}/{model_name}, "
|
||||
"this model may or may not support image input."
|
||||
)
|
||||
return False
|
||||
return model_obj.get("supports_vision", False)
|
||||
except Exception:
|
||||
logger.exception(
|
||||
|
||||
@@ -46,6 +46,9 @@ from onyx.server.manage.llm.models import ModelConfigurationUpsertRequest
|
||||
from onyx.server.manage.llm.models import OllamaFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OllamaModelDetails
|
||||
from onyx.server.manage.llm.models import OllamaModelsRequest
|
||||
from onyx.server.manage.llm.models import OpenRouterFinalModelResponse
|
||||
from onyx.server.manage.llm.models import OpenRouterModelDetails
|
||||
from onyx.server.manage.llm.models import OpenRouterModelsRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -577,3 +580,75 @@ def get_ollama_available_models(
|
||||
)
|
||||
|
||||
return all_models_with_context_size_and_vision
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
"""Perform GET to OpenRouter /models and return parsed JSON."""
|
||||
cleaned_api_base = api_base.strip().rstrip("/")
|
||||
url = f"{cleaned_api_base}/models"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {api_key}",
|
||||
# Optional headers recommended by OpenRouter for attribution
|
||||
"HTTP-Referer": "https://onyx.app",
|
||||
"X-Title": "Onyx",
|
||||
}
|
||||
try:
|
||||
response = httpx.get(url, headers=headers, timeout=10.0)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Failed to fetch OpenRouter models: {e}",
|
||||
)
|
||||
|
||||
|
||||
@admin_router.post("/openrouter/available-models")
|
||||
def get_openrouter_available_models(
|
||||
request: OpenRouterModelsRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[OpenRouterFinalModelResponse]:
|
||||
"""Fetch available models from OpenRouter `/models` endpoint.
|
||||
|
||||
Parses id, context_length, and architecture.input_modalities to infer vision support.
|
||||
"""
|
||||
|
||||
response_json = _get_openrouter_models_response(
|
||||
api_base=request.api_base, api_key=request.api_key
|
||||
)
|
||||
|
||||
data = response_json.get("data", [])
|
||||
if not isinstance(data, list) or len(data) == 0:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="No models found from your OpenRouter endpoint",
|
||||
)
|
||||
|
||||
results: list[OpenRouterFinalModelResponse] = []
|
||||
for item in data:
|
||||
try:
|
||||
model_details = OpenRouterModelDetails.model_validate(item)
|
||||
|
||||
# NOTE: This should be removed if we ever support dynamically fetching embedding models.
|
||||
if model_details.is_embedding_model:
|
||||
continue
|
||||
|
||||
results.append(
|
||||
OpenRouterFinalModelResponse(
|
||||
name=model_details.id,
|
||||
max_input_tokens=model_details.context_length,
|
||||
supports_image_input=model_details.supports_image_input,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
"Failed to parse OpenRouter model entry",
|
||||
extra={"error": str(e), "item": str(item)[:1000]},
|
||||
)
|
||||
|
||||
if not results:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="No compatible models found from OpenRouter"
|
||||
)
|
||||
|
||||
return sorted(results, key=lambda m: m.name.lower())
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import Field
|
||||
from pydantic import field_validator
|
||||
|
||||
from onyx.llm.utils import get_max_input_tokens
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
@@ -34,6 +35,12 @@ class TestLLMRequest(BaseModel):
|
||||
# if try and use the existing API key
|
||||
api_key_changed: bool
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: str) -> str:
|
||||
"""Normalize provider name by stripping whitespace and lowercasing."""
|
||||
return value.strip().lower()
|
||||
|
||||
|
||||
class LLMProviderDescriptor(BaseModel):
|
||||
"""A descriptor for an LLM provider that can be safely viewed by
|
||||
@@ -91,6 +98,12 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
api_key_changed: bool = False
|
||||
model_configurations: list["ModelConfigurationUpsertRequest"] = []
|
||||
|
||||
@field_validator("provider", mode="before")
|
||||
@classmethod
|
||||
def normalize_provider(cls, value: str) -> str:
|
||||
"""Normalize provider name by stripping whitespace and lowercasing."""
|
||||
return value.strip().lower()
|
||||
|
||||
|
||||
class LLMProviderView(LLMProvider):
|
||||
"""Stripped down representation of LLMProvider for display / limited access info only"""
|
||||
@@ -224,3 +237,36 @@ class OllamaModelDetails(BaseModel):
|
||||
def supports_image_input(self) -> bool:
|
||||
"""Check if this model supports image input"""
|
||||
return "vision" in self.capabilities
|
||||
|
||||
|
||||
# OpenRouter dynamic models fetch
|
||||
class OpenRouterModelsRequest(BaseModel):
|
||||
api_base: str
|
||||
api_key: str
|
||||
|
||||
|
||||
class OpenRouterModelDetails(BaseModel):
|
||||
"""Response model for OpenRouter /api/v1/models endpoint"""
|
||||
|
||||
# This is used to ignore any extra fields that are returned from the API
|
||||
model_config = {"extra": "ignore"}
|
||||
|
||||
id: str
|
||||
context_length: int
|
||||
architecture: dict[str, Any] # Contains 'input_modalities' key
|
||||
|
||||
@property
|
||||
def supports_image_input(self) -> bool:
|
||||
input_modalities = self.architecture.get("input_modalities", [])
|
||||
return isinstance(input_modalities, list) and "image" in input_modalities
|
||||
|
||||
@property
|
||||
def is_embedding_model(self) -> bool:
|
||||
output_modalities = self.architecture.get("output_modalities", [])
|
||||
return isinstance(output_modalities, list) and "embeddings" in output_modalities
|
||||
|
||||
|
||||
class OpenRouterFinalModelResponse(BaseModel):
|
||||
name: str
|
||||
max_input_tokens: int
|
||||
supports_image_input: bool
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.server.manage.models import FullModelVersionResponse
|
||||
from onyx.server.models import IdReturn
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import ALT_INDEX_SUFFIX
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
|
||||
router = APIRouter(prefix="/search-settings")
|
||||
logger = setup_logger()
|
||||
@@ -50,6 +51,13 @@ def set_new_search_settings(
|
||||
if search_settings_new.index_name:
|
||||
logger.warning("Index name was specified by request, this is not suggested")
|
||||
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings_new.enable_contextual_rag:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
# Validate cloud provider exists or create new LiteLLM provider
|
||||
if search_settings_new.provider_type is not None:
|
||||
cloud_provider = get_embedding_provider_from_provider_type(
|
||||
@@ -217,6 +225,13 @@ def update_saved_search_settings(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
# Disallow contextual RAG for cloud deployments
|
||||
if MULTI_TENANT and search_settings.enable_contextual_rag:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Contextual RAG disabled in Onyx Cloud",
|
||||
)
|
||||
|
||||
update_current_search_settings(
|
||||
search_settings=search_settings, db_session=db_session
|
||||
)
|
||||
|
||||
@@ -6,6 +6,7 @@ Full tenant analysis script that:
|
||||
3. Analyzes the collected data
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import csv
|
||||
import json
|
||||
import os
|
||||
@@ -361,17 +362,35 @@ def find_recent_tenant_data() -> tuple[list[dict[str, Any]] | None, str | None]:
|
||||
|
||||
|
||||
def main() -> None:
|
||||
# Parse command-line arguments
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Analyze tenant data and identify gated tenants with no recent queries"
|
||||
)
|
||||
parser.add_argument(
|
||||
"--skip-cache",
|
||||
action="store_true",
|
||||
help="Skip cached tenant data and collect fresh data from pod",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
# Step 0: Collect control plane data
|
||||
control_plane_data = collect_control_plane_data()
|
||||
|
||||
# Step 1: Check for recent tenant data (< 7 days old)
|
||||
tenant_data, cached_file = find_recent_tenant_data()
|
||||
# Step 1: Check for recent tenant data (< 7 days old) unless --skip-cache is set
|
||||
tenant_data = None
|
||||
cached_file = None
|
||||
|
||||
if not args.skip_cache:
|
||||
tenant_data, cached_file = find_recent_tenant_data()
|
||||
|
||||
if tenant_data:
|
||||
print(f"Using cached tenant data from: {cached_file}")
|
||||
print(f"Total tenants in cache: {len(tenant_data)}")
|
||||
else:
|
||||
if args.skip_cache:
|
||||
print("\n⚠ Skipping cache (--skip-cache flag set)")
|
||||
|
||||
# Step 2a: Find the heavy worker pod
|
||||
pod_name = find_worker_pod()
|
||||
|
||||
|
||||
@@ -21,10 +21,12 @@ Examples:
|
||||
python backend/scripts/cleanup_tenant.py --csv gated_tenants_no_query_3mo.csv --force
|
||||
"""
|
||||
|
||||
import csv
|
||||
import json
|
||||
import signal
|
||||
import subprocess
|
||||
import sys
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
from scripts.tenant_cleanup.cleanup_utils import confirm_step
|
||||
@@ -420,13 +422,17 @@ def cleanup_control_plane(tenant_id: str, force: bool = False) -> None:
|
||||
raise
|
||||
|
||||
|
||||
def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
|
||||
def cleanup_tenant(tenant_id: str, pod_name: str, force: bool = False) -> bool:
|
||||
"""
|
||||
Main cleanup function that orchestrates all cleanup steps.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to clean up
|
||||
pod_name: The Kubernetes pod name to execute operations on
|
||||
force: If True, skip all confirmation prompts
|
||||
|
||||
Returns:
|
||||
True if cleanup was performed, False if skipped
|
||||
"""
|
||||
print(f"Starting cleanup for tenant: {tenant_id}")
|
||||
|
||||
@@ -445,44 +451,52 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
|
||||
)
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
if force:
|
||||
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
|
||||
return False
|
||||
|
||||
# Always ask for confirmation if not gated, even in force mode
|
||||
response = input(
|
||||
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
print("Cleanup aborted - tenant is not GATED_ACCESS")
|
||||
return
|
||||
return False
|
||||
elif tenant_status == "GATED_ACCESS":
|
||||
print("✓ Tenant status is GATED_ACCESS - safe to proceed with cleanup")
|
||||
elif tenant_status is None:
|
||||
print("⚠️ WARNING: Could not determine tenant status!")
|
||||
|
||||
if force:
|
||||
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
|
||||
return False
|
||||
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
print("Cleanup aborted - could not verify tenant status")
|
||||
return
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
|
||||
|
||||
if force:
|
||||
print(f"Skipping cleanup for tenant {tenant_id} in force mode")
|
||||
return False
|
||||
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
print("Cleanup aborted - could not verify tenant status")
|
||||
return
|
||||
return False
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Find heavy worker pod for Vespa and schema operations
|
||||
try:
|
||||
pod_name = find_worker_pod()
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
|
||||
print("Cannot proceed with Vespa and schema cleanup")
|
||||
return
|
||||
|
||||
# Fetch tenant users for informational purposes (non-blocking)
|
||||
print(f"\n{'=' * 80}")
|
||||
try:
|
||||
get_tenant_users(pod_name, tenant_id)
|
||||
except Exception as e:
|
||||
print(f"⚠ Could not fetch tenant users: {e}")
|
||||
print(f"{'=' * 80}\n")
|
||||
# Skip in force mode as it's only informational
|
||||
if not force:
|
||||
print(f"\n{'=' * 80}")
|
||||
try:
|
||||
get_tenant_users(pod_name, tenant_id)
|
||||
except Exception as e:
|
||||
print(f"⚠ Could not fetch tenant users: {e}")
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Step 1: Make sure all documents are deleted
|
||||
print(f"\n{'=' * 80}")
|
||||
@@ -498,7 +512,7 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
|
||||
print(
|
||||
"You may need to mark connectors for deletion and wait for cleanup to complete."
|
||||
)
|
||||
return
|
||||
return False
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
# Step 2: Drop data plane schema
|
||||
@@ -514,7 +528,7 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
|
||||
response = input("Continue with control plane cleanup? (y/n): ")
|
||||
if response.lower() != "y":
|
||||
print("Cleanup aborted by user")
|
||||
return
|
||||
return False
|
||||
else:
|
||||
print("[FORCE MODE] Continuing despite schema cleanup failure")
|
||||
else:
|
||||
@@ -535,11 +549,12 @@ def cleanup_tenant(tenant_id: str, force: bool = False) -> None:
|
||||
print("[FORCE MODE] Control plane cleanup failed but continuing")
|
||||
else:
|
||||
print("Step 3 skipped by user")
|
||||
return
|
||||
return False
|
||||
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"✓ Cleanup completed for tenant: {tenant_id}")
|
||||
print(f"{'=' * 80}")
|
||||
return True
|
||||
|
||||
|
||||
def main() -> None:
|
||||
@@ -643,43 +658,87 @@ def main() -> None:
|
||||
f"⚠ FORCE MODE: Running cleanup for {len(tenant_ids)} tenants without confirmations"
|
||||
)
|
||||
|
||||
# Find heavy worker pod once for all tenants
|
||||
try:
|
||||
pod_name = find_worker_pod()
|
||||
print(f"✓ Found worker pod: {pod_name}\n")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
|
||||
print("Cannot proceed with cleanup")
|
||||
sys.exit(1)
|
||||
|
||||
# Run cleanup for each tenant
|
||||
failed_tenants = []
|
||||
successful_tenants = []
|
||||
skipped_tenants = []
|
||||
|
||||
for idx, tenant_id in enumerate(tenant_ids, 1):
|
||||
if len(tenant_ids) > 1:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
|
||||
print(f"{'=' * 80}")
|
||||
# Open CSV file for writing successful cleanups in real-time
|
||||
csv_output_path = "cleaned_tenants.csv"
|
||||
with open(csv_output_path, "w", newline="") as csv_file:
|
||||
csv_writer = csv.writer(csv_file)
|
||||
csv_writer.writerow(["tenant_id", "cleaned_at"])
|
||||
csv_file.flush() # Ensure header is written immediately
|
||||
|
||||
try:
|
||||
cleanup_tenant(tenant_id, force)
|
||||
successful_tenants.append(tenant_id)
|
||||
except Exception as e:
|
||||
print(f"✗ Cleanup failed for tenant {tenant_id}: {e}", file=sys.stderr)
|
||||
failed_tenants.append((tenant_id, str(e)))
|
||||
print(f"Writing successful cleanups to: {csv_output_path}\n")
|
||||
|
||||
# If not in force mode and there are more tenants, ask if we should continue
|
||||
if not force and idx < len(tenant_ids):
|
||||
response = input(
|
||||
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
|
||||
)
|
||||
if response.lower() != "y":
|
||||
print("Cleanup aborted by user")
|
||||
break
|
||||
for idx, tenant_id in enumerate(tenant_ids, 1):
|
||||
if len(tenant_ids) > 1:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
# Print summary if multiple tenants
|
||||
if len(tenant_ids) > 1:
|
||||
try:
|
||||
was_cleaned = cleanup_tenant(tenant_id, pod_name, force)
|
||||
|
||||
if was_cleaned:
|
||||
# Only record if actually cleaned up (not skipped)
|
||||
successful_tenants.append(tenant_id)
|
||||
|
||||
# Write to CSV immediately after successful cleanup
|
||||
timestamp = datetime.utcnow().isoformat()
|
||||
csv_writer.writerow([tenant_id, timestamp])
|
||||
csv_file.flush() # Ensure real-time write
|
||||
print(f"✓ Recorded cleanup in {csv_output_path}")
|
||||
else:
|
||||
skipped_tenants.append(tenant_id)
|
||||
print(f"⚠ Tenant {tenant_id} was skipped (not recorded in CSV)")
|
||||
|
||||
except Exception as e:
|
||||
print(f"✗ Cleanup failed for tenant {tenant_id}: {e}", file=sys.stderr)
|
||||
failed_tenants.append((tenant_id, str(e)))
|
||||
|
||||
# If not in force mode and there are more tenants, ask if we should continue
|
||||
if not force and idx < len(tenant_ids):
|
||||
response = input(
|
||||
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
|
||||
)
|
||||
if response.lower() != "y":
|
||||
print("Cleanup aborted by user")
|
||||
break
|
||||
|
||||
# Print summary
|
||||
if len(tenant_ids) == 1:
|
||||
if successful_tenants:
|
||||
print(f"\n✓ Successfully cleaned tenant written to: {csv_output_path}")
|
||||
elif skipped_tenants:
|
||||
print("\n⚠ Tenant was skipped")
|
||||
elif len(tenant_ids) > 1:
|
||||
print(f"\n{'=' * 80}")
|
||||
print("CLEANUP SUMMARY")
|
||||
print(f"{'=' * 80}")
|
||||
print(f"Total tenants: {len(tenant_ids)}")
|
||||
print(f"Successful: {len(successful_tenants)}")
|
||||
print(f"Skipped: {len(skipped_tenants)}")
|
||||
print(f"Failed: {len(failed_tenants)}")
|
||||
print(f"\nSuccessfully cleaned tenants written to: {csv_output_path}")
|
||||
|
||||
if skipped_tenants:
|
||||
print(f"\nSkipped tenants ({len(skipped_tenants)}):")
|
||||
for tenant_id in skipped_tenants:
|
||||
print(f" - {tenant_id}")
|
||||
|
||||
if failed_tenants:
|
||||
print("\nFailed tenants:")
|
||||
print(f"\nFailed tenants ({len(failed_tenants)}):")
|
||||
for tenant_id, error in failed_tenants:
|
||||
print(f" - {tenant_id}: {error}")
|
||||
|
||||
|
||||
@@ -7,137 +7,65 @@ Mark connectors for deletion script that:
|
||||
4. Triggers the cleanup task
|
||||
|
||||
Usage:
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force]
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force]
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force] [--concurrency N]
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force] [--concurrency N]
|
||||
|
||||
Arguments:
|
||||
tenant_id The tenant ID to process (required if not using --csv)
|
||||
--csv PATH Path to CSV file containing tenant IDs to process
|
||||
--force Skip all confirmation prompts (optional)
|
||||
--concurrency N Process N tenants concurrently (default: 1)
|
||||
|
||||
Examples:
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789 --force
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv --force
|
||||
python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py \
|
||||
--csv gated_tenants_no_query_3mo.csv --force --concurrency 16
|
||||
"""
|
||||
|
||||
import json
|
||||
import subprocess
|
||||
import sys
|
||||
from concurrent.futures import as_completed
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any
|
||||
|
||||
from scripts.tenant_cleanup.cleanup_utils import confirm_step
|
||||
from scripts.tenant_cleanup.cleanup_utils import find_worker_pod
|
||||
from scripts.tenant_cleanup.cleanup_utils import get_tenant_status
|
||||
from scripts.tenant_cleanup.cleanup_utils import read_tenant_ids_from_csv
|
||||
|
||||
|
||||
def get_tenant_connectors(pod_name: str, tenant_id: str) -> list[dict]:
|
||||
"""Get list of connector credential pairs for the tenant.
|
||||
|
||||
Args:
|
||||
pod_name: The Kubernetes pod name to execute on
|
||||
tenant_id: The tenant ID to query
|
||||
|
||||
Returns:
|
||||
List of connector credential pair dicts with id, connector_id, credential_id, name, status
|
||||
"""
|
||||
print(f"Fetching connector credential pairs for tenant: {tenant_id}")
|
||||
|
||||
# Get the path to the script
|
||||
script_dir = Path(__file__).parent
|
||||
get_connectors_script = script_dir / "on_pod_scripts" / "get_tenant_connectors.py"
|
||||
|
||||
if not get_connectors_script.exists():
|
||||
raise FileNotFoundError(
|
||||
f"get_tenant_connectors.py not found at {get_connectors_script}"
|
||||
)
|
||||
|
||||
try:
|
||||
# Copy script to pod
|
||||
print(" Copying script to pod...")
|
||||
subprocess.run(
|
||||
[
|
||||
"kubectl",
|
||||
"cp",
|
||||
str(get_connectors_script),
|
||||
f"{pod_name}:/tmp/get_tenant_connectors.py",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
)
|
||||
|
||||
# Execute script on pod
|
||||
print(" Executing script on pod...")
|
||||
result = subprocess.run(
|
||||
[
|
||||
"kubectl",
|
||||
"exec",
|
||||
pod_name,
|
||||
"--",
|
||||
"python",
|
||||
"/tmp/get_tenant_connectors.py",
|
||||
tenant_id,
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Show progress messages from stderr
|
||||
if result.stderr:
|
||||
print(f" {result.stderr}", end="")
|
||||
|
||||
# Parse JSON result from stdout
|
||||
result_data = json.loads(result.stdout)
|
||||
status = result_data.get("status")
|
||||
|
||||
if status == "success":
|
||||
connectors = result_data.get("connectors", [])
|
||||
if connectors:
|
||||
print(f"✓ Found {len(connectors)} connector credential pair(s):")
|
||||
for cc in connectors:
|
||||
print(
|
||||
f" - CC Pair ID: {cc['id']}, Name: {cc['name']}, Status: {cc['status']}"
|
||||
)
|
||||
else:
|
||||
print(" No connector credential pairs found for tenant")
|
||||
return connectors
|
||||
else:
|
||||
message = result_data.get("message", "Unknown error")
|
||||
print(f"⚠ Could not fetch connectors: {message}")
|
||||
return []
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(f"⚠ Failed to get connectors for tenant {tenant_id}: {e}")
|
||||
if e.stderr:
|
||||
print(f" Error details: {e.stderr}")
|
||||
return []
|
||||
except Exception as e:
|
||||
print(f"⚠ Failed to get connectors for tenant {tenant_id}: {e}")
|
||||
return []
|
||||
# Global lock for thread-safe printing
|
||||
_print_lock: Lock = Lock()
|
||||
|
||||
|
||||
def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int) -> None:
|
||||
"""Mark a connector credential pair for deletion.
|
||||
def safe_print(*args: Any, **kwargs: Any) -> None:
|
||||
"""Thread-safe print function."""
|
||||
with _print_lock:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def run_connector_deletion(pod_name: str, tenant_id: str) -> None:
|
||||
"""Mark all connector credential pairs for deletion.
|
||||
|
||||
Args:
|
||||
pod_name: The Kubernetes pod name to execute on
|
||||
tenant_id: The tenant ID
|
||||
cc_pair_id: The connector credential pair ID to mark for deletion
|
||||
"""
|
||||
print(f" Marking CC pair {cc_pair_id} for deletion...")
|
||||
safe_print(" Marking all connector credential pairs for deletion...")
|
||||
|
||||
# Get the path to the script
|
||||
script_dir = Path(__file__).parent
|
||||
mark_deletion_script = (
|
||||
script_dir / "on_pod_scripts" / "mark_connector_for_deletion.py"
|
||||
script_dir / "on_pod_scripts" / "execute_connector_deletion.py"
|
||||
)
|
||||
|
||||
if not mark_deletion_script.exists():
|
||||
raise FileNotFoundError(
|
||||
f"mark_connector_for_deletion.py not found at {mark_deletion_script}"
|
||||
f"execute_connector_deletion.py not found at {mark_deletion_script}"
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -147,7 +75,7 @@ def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int)
|
||||
"kubectl",
|
||||
"cp",
|
||||
str(mark_deletion_script),
|
||||
f"{pod_name}:/tmp/mark_connector_for_deletion.py",
|
||||
f"{pod_name}:/tmp/execute_connector_deletion.py",
|
||||
],
|
||||
check=True,
|
||||
capture_output=True,
|
||||
@@ -161,166 +89,119 @@ def mark_connector_for_deletion(pod_name: str, tenant_id: str, cc_pair_id: int)
|
||||
pod_name,
|
||||
"--",
|
||||
"python",
|
||||
"/tmp/mark_connector_for_deletion.py",
|
||||
"/tmp/execute_connector_deletion.py",
|
||||
tenant_id,
|
||||
str(cc_pair_id),
|
||||
"--all",
|
||||
],
|
||||
capture_output=True,
|
||||
text=True,
|
||||
check=True,
|
||||
)
|
||||
|
||||
# Show progress messages from stderr
|
||||
if result.stderr:
|
||||
print(f" {result.stderr}", end="")
|
||||
|
||||
# Parse JSON result from stdout
|
||||
result_data = json.loads(result.stdout)
|
||||
status = result_data.get("status")
|
||||
message = result_data.get("message")
|
||||
|
||||
if status == "success":
|
||||
print(f" ✓ {message}")
|
||||
else:
|
||||
print(f" ✗ {message}", file=sys.stderr)
|
||||
raise RuntimeError(message)
|
||||
if result.returncode != 0:
|
||||
raise RuntimeError(result.stderr)
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
print(
|
||||
f" ✗ Failed to mark CC pair {cc_pair_id} for deletion: {e}",
|
||||
safe_print(
|
||||
f" ✗ Failed to mark all connector credential pairs for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if e.stderr:
|
||||
print(f" Error details: {e.stderr}", file=sys.stderr)
|
||||
safe_print(f" Error details: {e.stderr}", file=sys.stderr)
|
||||
raise
|
||||
except Exception as e:
|
||||
print(
|
||||
f" ✗ Failed to mark CC pair {cc_pair_id} for deletion: {e}",
|
||||
safe_print(
|
||||
f" ✗ Failed to mark all connector credential pairs for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def mark_tenant_connectors_for_deletion(tenant_id: str, force: bool = False) -> None:
|
||||
def mark_tenant_connectors_for_deletion(
|
||||
tenant_id: str, pod_name: str, force: bool = False
|
||||
) -> None:
|
||||
"""
|
||||
Main function to mark all connectors for a tenant for deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID to process
|
||||
pod_name: The Kubernetes pod name to execute on
|
||||
force: If True, skip all confirmation prompts
|
||||
"""
|
||||
print(f"Processing connectors for tenant: {tenant_id}")
|
||||
safe_print(f"Processing connectors for tenant: {tenant_id}")
|
||||
|
||||
# Check tenant status first
|
||||
print(f"\n{'=' * 80}")
|
||||
safe_print(f"\n{'=' * 80}")
|
||||
try:
|
||||
tenant_status = get_tenant_status(tenant_id)
|
||||
|
||||
# If tenant is not GATED_ACCESS, require explicit confirmation even in force mode
|
||||
if tenant_status and tenant_status != "GATED_ACCESS":
|
||||
print(
|
||||
safe_print(
|
||||
f"\n⚠️ WARNING: Tenant status is '{tenant_status}', not 'GATED_ACCESS'!"
|
||||
)
|
||||
print(
|
||||
safe_print(
|
||||
"This tenant may be active and should not have connectors deleted without careful review."
|
||||
)
|
||||
print(f"{'=' * 80}\n")
|
||||
safe_print(f"{'=' * 80}\n")
|
||||
|
||||
# Always ask for confirmation if not gated, even in force mode
|
||||
response = input(
|
||||
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
print("Operation aborted - tenant is not GATED_ACCESS")
|
||||
return
|
||||
# Note: In parallel mode with force, this will still block
|
||||
if not force:
|
||||
response = input(
|
||||
"Are you ABSOLUTELY SURE you want to proceed? Type 'yes' to confirm: "
|
||||
)
|
||||
if response.lower() != "yes":
|
||||
safe_print("Operation aborted - tenant is not GATED_ACCESS")
|
||||
raise RuntimeError(f"Tenant {tenant_id} is not GATED_ACCESS")
|
||||
else:
|
||||
raise RuntimeError(f"Tenant {tenant_id} is not GATED_ACCESS")
|
||||
elif tenant_status == "GATED_ACCESS":
|
||||
print("✓ Tenant status is GATED_ACCESS - safe to proceed")
|
||||
safe_print("✓ Tenant status is GATED_ACCESS - safe to proceed")
|
||||
elif tenant_status is None:
|
||||
print("⚠️ WARNING: Could not determine tenant status!")
|
||||
safe_print("⚠️ WARNING: Could not determine tenant status!")
|
||||
if not force:
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
safe_print("Operation aborted - could not verify tenant status")
|
||||
raise RuntimeError(
|
||||
f"Could not verify tenant status for {tenant_id}"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(f"Could not verify tenant status for {tenant_id}")
|
||||
except Exception as e:
|
||||
safe_print(f"⚠️ WARNING: Failed to check tenant status: {e}")
|
||||
if not force:
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
print("Operation aborted - could not verify tenant status")
|
||||
return
|
||||
except Exception as e:
|
||||
print(f"⚠️ WARNING: Failed to check tenant status: {e}")
|
||||
response = input("Continue anyway? Type 'yes' to confirm: ")
|
||||
if response.lower() != "yes":
|
||||
print("Operation aborted - could not verify tenant status")
|
||||
return
|
||||
print(f"{'=' * 80}\n")
|
||||
safe_print("Operation aborted - could not verify tenant status")
|
||||
raise
|
||||
else:
|
||||
raise RuntimeError(f"Failed to check tenant status for {tenant_id}")
|
||||
safe_print(f"{'=' * 80}\n")
|
||||
|
||||
# Find heavy worker pod for operations
|
||||
try:
|
||||
pod_name = find_worker_pod()
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
|
||||
print("Cannot proceed with marking connectors for deletion")
|
||||
return
|
||||
|
||||
# Fetch connectors
|
||||
print(f"\n{'=' * 80}")
|
||||
try:
|
||||
connectors = get_tenant_connectors(pod_name, tenant_id)
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to fetch connectors: {e}", file=sys.stderr)
|
||||
return
|
||||
print(f"{'=' * 80}\n")
|
||||
|
||||
if not connectors:
|
||||
print(f"No connectors found for tenant {tenant_id}, nothing to do.")
|
||||
return
|
||||
|
||||
# Confirm before proceeding
|
||||
# Confirm before proceeding (only in non-force mode)
|
||||
if not confirm_step(
|
||||
f"Mark {len(connectors)} connector credential pair(s) for deletion?",
|
||||
f"Mark all connector credential pairs for deletion for tenant {tenant_id}?",
|
||||
force,
|
||||
):
|
||||
print("Operation cancelled by user")
|
||||
return
|
||||
safe_print("Operation cancelled by user")
|
||||
raise ValueError("Operation cancelled by user")
|
||||
|
||||
# Mark each connector for deletion
|
||||
failed_connectors = []
|
||||
successful_connectors = []
|
||||
|
||||
for cc in connectors:
|
||||
cc_pair_id = cc["id"]
|
||||
cc_name = cc["name"]
|
||||
cc_status = cc["status"]
|
||||
|
||||
# Skip if already marked for deletion
|
||||
if cc_status == "DELETING":
|
||||
print(
|
||||
f" Skipping CC pair {cc_pair_id} ({cc_name}) - already marked for deletion"
|
||||
)
|
||||
continue
|
||||
|
||||
try:
|
||||
mark_connector_for_deletion(pod_name, tenant_id, cc_pair_id)
|
||||
successful_connectors.append(cc_pair_id)
|
||||
except Exception as e:
|
||||
print(
|
||||
f" ✗ Failed to mark CC pair {cc_pair_id} ({cc_name}) for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
failed_connectors.append((cc_pair_id, cc_name, str(e)))
|
||||
run_connector_deletion(pod_name, tenant_id)
|
||||
|
||||
# Print summary
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"✓ Marked {len(successful_connectors)} connector(s) for deletion")
|
||||
if failed_connectors:
|
||||
print(f"✗ Failed to mark {len(failed_connectors)} connector(s):")
|
||||
for cc_id, cc_name, error in failed_connectors:
|
||||
print(f" - CC Pair {cc_id} ({cc_name}): {error}")
|
||||
print(f"{'=' * 80}")
|
||||
safe_print(
|
||||
f"✓ Marked all connector credential pairs for deletion for tenant {tenant_id}"
|
||||
)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) < 2:
|
||||
print(
|
||||
"Usage: python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force]"
|
||||
"Usage: python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py <tenant_id> [--force] "
|
||||
"[--concurrency N]"
|
||||
)
|
||||
print(
|
||||
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv <csv_file_path> [--force]"
|
||||
" [--concurrency N]"
|
||||
)
|
||||
print("\nArguments:")
|
||||
print(
|
||||
@@ -328,6 +209,7 @@ def main() -> None:
|
||||
)
|
||||
print(" --csv PATH Path to CSV file containing tenant IDs to process")
|
||||
print(" --force Skip all confirmation prompts (optional)")
|
||||
print(" --concurrency N Process N tenants concurrently (default: 1)")
|
||||
print("\nExamples:")
|
||||
print(
|
||||
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py tenant_abc123-def456-789"
|
||||
@@ -339,23 +221,48 @@ def main() -> None:
|
||||
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv"
|
||||
)
|
||||
print(
|
||||
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv --force"
|
||||
" python backend/scripts/tenant_cleanup/mark_connectors_for_deletion.py --csv gated_tenants_no_query_3mo.csv "
|
||||
"--force --concurrency 16"
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Parse arguments
|
||||
force = "--force" in sys.argv
|
||||
tenant_ids = []
|
||||
tenant_ids: list[str] = []
|
||||
|
||||
# Parse concurrency
|
||||
concurrency: int = 1
|
||||
if "--concurrency" in sys.argv:
|
||||
try:
|
||||
concurrency_index = sys.argv.index("--concurrency")
|
||||
if concurrency_index + 1 >= len(sys.argv):
|
||||
print("Error: --concurrency flag requires a number", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
concurrency = int(sys.argv[concurrency_index + 1])
|
||||
if concurrency < 1:
|
||||
print("Error: concurrency must be at least 1", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
except ValueError:
|
||||
print("Error: --concurrency value must be an integer", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Validate: concurrency > 1 requires --force
|
||||
if concurrency > 1 and not force:
|
||||
print(
|
||||
"Error: --concurrency > 1 requires --force flag (interactive mode not supported with parallel processing)",
|
||||
file=sys.stderr,
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
# Check for CSV mode
|
||||
if "--csv" in sys.argv:
|
||||
try:
|
||||
csv_index = sys.argv.index("--csv")
|
||||
csv_index: int = sys.argv.index("--csv")
|
||||
if csv_index + 1 >= len(sys.argv):
|
||||
print("Error: --csv flag requires a file path", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
csv_path = sys.argv[csv_index + 1]
|
||||
csv_path: str = sys.argv[csv_index + 1]
|
||||
tenant_ids = read_tenant_ids_from_csv(csv_path)
|
||||
|
||||
if not tenant_ids:
|
||||
@@ -371,6 +278,16 @@ def main() -> None:
|
||||
# Single tenant mode
|
||||
tenant_ids = [sys.argv[1]]
|
||||
|
||||
# Find heavy worker pod once before processing
|
||||
try:
|
||||
print("Finding worker pod...")
|
||||
pod_name: str = find_worker_pod()
|
||||
print(f"✓ Using worker pod: {pod_name}")
|
||||
except Exception as e:
|
||||
print(f"✗ Failed to find heavy worker pod: {e}", file=sys.stderr)
|
||||
print("Cannot proceed with marking connectors for deletion")
|
||||
sys.exit(1)
|
||||
|
||||
# Initial confirmation (unless --force is used)
|
||||
if not force:
|
||||
print(f"\n{'=' * 80}")
|
||||
@@ -387,6 +304,7 @@ def main() -> None:
|
||||
print(
|
||||
f"Mode: {'FORCE (no confirmations)' if force else 'Interactive (will ask for confirmation at each step)'}"
|
||||
)
|
||||
print(f"Concurrency: {concurrency} tenant(s) at a time")
|
||||
print("\nThis will:")
|
||||
print(" 1. Fetch all connector credential pairs for each tenant")
|
||||
print(" 2. Cancel any scheduled indexing attempts for each connector")
|
||||
@@ -409,37 +327,78 @@ def main() -> None:
|
||||
)
|
||||
else:
|
||||
print(
|
||||
f"⚠ FORCE MODE: Marking connectors for deletion for {len(tenant_ids)} tenants without confirmations"
|
||||
f"⚠ FORCE MODE: Marking connectors for deletion for {len(tenant_ids)} tenants "
|
||||
f"(concurrency: {concurrency}) without confirmations"
|
||||
)
|
||||
|
||||
# Process each tenant
|
||||
failed_tenants = []
|
||||
successful_tenants = []
|
||||
# Process tenants (in parallel if concurrency > 1)
|
||||
failed_tenants: list[tuple[str, str]] = []
|
||||
successful_tenants: list[str] = []
|
||||
|
||||
for idx, tenant_id in enumerate(tenant_ids, 1):
|
||||
if len(tenant_ids) > 1:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
|
||||
print(f"{'=' * 80}")
|
||||
if concurrency == 1:
|
||||
# Sequential processing
|
||||
for idx, tenant_id in enumerate(tenant_ids, 1):
|
||||
if len(tenant_ids) > 1:
|
||||
print(f"\n{'=' * 80}")
|
||||
print(f"Processing tenant {idx}/{len(tenant_ids)}: {tenant_id}")
|
||||
print(f"{'=' * 80}")
|
||||
|
||||
try:
|
||||
mark_tenant_connectors_for_deletion(tenant_id, force)
|
||||
successful_tenants.append(tenant_id)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"✗ Failed to process tenant {tenant_id}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
failed_tenants.append((tenant_id, str(e)))
|
||||
|
||||
# If not in force mode and there are more tenants, ask if we should continue
|
||||
if not force and idx < len(tenant_ids):
|
||||
response = input(
|
||||
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
|
||||
try:
|
||||
mark_tenant_connectors_for_deletion(tenant_id, pod_name, force)
|
||||
successful_tenants.append(tenant_id)
|
||||
except Exception as e:
|
||||
print(
|
||||
f"✗ Failed to process tenant {tenant_id}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
if response.lower() != "y":
|
||||
print("Operation aborted by user")
|
||||
break
|
||||
failed_tenants.append((tenant_id, str(e)))
|
||||
|
||||
# If not in force mode and there are more tenants, ask if we should continue
|
||||
if not force and idx < len(tenant_ids):
|
||||
response = input(
|
||||
f"\nContinue with remaining {len(tenant_ids) - idx} tenant(s)? (y/n): "
|
||||
)
|
||||
if response.lower() != "y":
|
||||
print("Operation aborted by user")
|
||||
break
|
||||
else:
|
||||
# Parallel processing
|
||||
print(
|
||||
f"\nProcessing {len(tenant_ids)} tenant(s) with concurrency={concurrency}"
|
||||
)
|
||||
|
||||
def process_tenant(tenant_id: str) -> tuple[str, bool, str | None]:
|
||||
"""Process a single tenant. Returns (tenant_id, success, error_message)."""
|
||||
try:
|
||||
mark_tenant_connectors_for_deletion(tenant_id, pod_name, force)
|
||||
return (tenant_id, True, None)
|
||||
except Exception as e:
|
||||
return (tenant_id, False, str(e))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=concurrency) as executor:
|
||||
# Submit all tasks
|
||||
future_to_tenant = {
|
||||
executor.submit(process_tenant, tenant_id): tenant_id
|
||||
for tenant_id in tenant_ids
|
||||
}
|
||||
|
||||
# Process results as they complete
|
||||
completed: int = 0
|
||||
for future in as_completed(future_to_tenant):
|
||||
completed += 1
|
||||
tenant_id, success, error = future.result()
|
||||
|
||||
if success:
|
||||
successful_tenants.append(tenant_id)
|
||||
safe_print(
|
||||
f"[{completed}/{len(tenant_ids)}] ✓ Successfully processed {tenant_id}"
|
||||
)
|
||||
else:
|
||||
failed_tenants.append((tenant_id, error or "Unknown error"))
|
||||
safe_print(
|
||||
f"[{completed}/{len(tenant_ids)}] ✗ Failed to process {tenant_id}: {error}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Print summary if multiple tenants
|
||||
if len(tenant_ids) > 1:
|
||||
|
||||
@@ -51,13 +51,14 @@ def check_documents_deleted(tenant_id: str) -> dict:
|
||||
cc_count = cc_count or 0
|
||||
doc_count = doc_count or 0
|
||||
|
||||
# If any records remain, return error status
|
||||
if cc_count > 0 or doc_count > 0:
|
||||
# If any records remain beyond acceptable thresholds, return error status
|
||||
if cc_count > 0 or doc_count > 5:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": (
|
||||
f"Found {cc_count} ConnectorCredentialPair(s) and {doc_count} Document(s) "
|
||||
"still remaining. All documents must be deleted before cleanup."
|
||||
"still remaining. Must have 0 ConnectorCredentialPairs and no more than "
|
||||
"5 Documents before cleanup."
|
||||
),
|
||||
"connector_credential_pair_count": cc_count,
|
||||
"document_count": doc_count,
|
||||
|
||||
@@ -0,0 +1,348 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to mark connector credential pairs for deletion.
|
||||
Runs on a Kubernetes pod with access to the data plane database.
|
||||
|
||||
Usage:
|
||||
# Mark a specific connector for deletion
|
||||
python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>
|
||||
|
||||
# Mark all connectors for deletion
|
||||
python mark_connector_for_deletion.py <tenant_id> --all
|
||||
|
||||
Output:
|
||||
JSON to stdout with structure:
|
||||
{
|
||||
"status": "success" | "error",
|
||||
"message": str,
|
||||
"deleted_count": int (when using --all),
|
||||
"timing": {
|
||||
"total_seconds": float,
|
||||
"per_connector": [...]
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pairs
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
|
||||
|
||||
def mark_connector_for_deletion(
|
||||
tenant_id: str, cc_pair_id: int, db_session: Session | None = None
|
||||
) -> dict[str, Any]:
|
||||
"""Mark a connector credential pair for deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
cc_pair_id: The connector credential pair ID
|
||||
db_session: Optional database session (if None, creates a new one)
|
||||
|
||||
Returns:
|
||||
Dict with status, message, and timing
|
||||
"""
|
||||
timing: dict[str, float] = {}
|
||||
start_time: float = time.time()
|
||||
|
||||
try:
|
||||
print(
|
||||
f"Marking connector credential pair {cc_pair_id} for deletion",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
def _mark_deletion(db_sess: Session) -> dict[str, Any]:
|
||||
# Get the connector credential pair
|
||||
fetch_start: float = time.time()
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_sess,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
timing["fetch_cc_pair_seconds"] = time.time() - fetch_start
|
||||
|
||||
if not cc_pair:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Connector credential pair {cc_pair_id} not found",
|
||||
"timing": timing,
|
||||
}
|
||||
|
||||
# Cancel any scheduled indexing attempts
|
||||
print(
|
||||
f"Canceling indexing attempts for CC pair {cc_pair_id}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
cancel_start: float = time.time()
|
||||
cancel_indexing_attempts_for_ccpair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_sess,
|
||||
include_secondary_index=True,
|
||||
)
|
||||
timing["cancel_indexing_seconds"] = time.time() - cancel_start
|
||||
|
||||
# Mark as deleting
|
||||
print(
|
||||
f"Updating CC pair {cc_pair_id} status to DELETING",
|
||||
file=sys.stderr,
|
||||
)
|
||||
update_start: float = time.time()
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_sess,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
timing["update_status_seconds"] = time.time() - update_start
|
||||
|
||||
commit_start: float = time.time()
|
||||
db_sess.commit()
|
||||
timing["commit_seconds"] = time.time() - commit_start
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Marked connector credential pair {cc_pair_id} for deletion",
|
||||
"timing": timing,
|
||||
}
|
||||
|
||||
result: dict[str, Any]
|
||||
if db_session:
|
||||
result = _mark_deletion(db_session)
|
||||
else:
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_sess:
|
||||
result = _mark_deletion(db_sess)
|
||||
|
||||
# Trigger the deletion check task
|
||||
print(
|
||||
"Triggering connector deletion check task",
|
||||
file=sys.stderr,
|
||||
)
|
||||
task_start: float = time.time()
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
timing["send_task_seconds"] = time.time() - task_start
|
||||
timing["total_seconds"] = time.time() - start_time
|
||||
|
||||
result["timing"] = timing
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error marking connector for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
timing["total_seconds"] = time.time() - start_time
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"timing": timing,
|
||||
}
|
||||
|
||||
|
||||
def mark_all_connectors_for_deletion(tenant_id: str) -> dict[str, Any]:
|
||||
"""Mark all connector credential pairs for a tenant for deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
|
||||
Returns:
|
||||
Dict with status, message, deleted_count, and timing
|
||||
"""
|
||||
overall_start: float = time.time()
|
||||
per_connector_timing: list[dict[str, Any]] = []
|
||||
|
||||
try:
|
||||
print(
|
||||
f"Marking all connector credential pairs for tenant {tenant_id} for deletion",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Get all connector credential pairs
|
||||
fetch_all_start: float = time.time()
|
||||
cc_pairs = get_connector_credential_pairs(db_session=db_session)
|
||||
fetch_all_time: float = time.time() - fetch_all_start
|
||||
|
||||
print(
|
||||
f"Found {len(cc_pairs)} connector credential pairs to delete",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
if not cc_pairs:
|
||||
return {
|
||||
"status": "success",
|
||||
"message": "No connector credential pairs found for tenant",
|
||||
"deleted_count": 0,
|
||||
"timing": {
|
||||
"fetch_all_seconds": fetch_all_time,
|
||||
"total_seconds": time.time() - overall_start,
|
||||
},
|
||||
}
|
||||
|
||||
deleted_count: int = 0
|
||||
errors: list[str] = []
|
||||
|
||||
for cc_pair in cc_pairs:
|
||||
connector_start: float = time.time()
|
||||
print(
|
||||
f"Processing CC pair {cc_pair.id} ({deleted_count + 1}/{len(cc_pairs)})",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
# Cancel any scheduled indexing attempts
|
||||
cancel_start: float = time.time()
|
||||
cancel_indexing_attempts_for_ccpair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
include_secondary_index=True,
|
||||
)
|
||||
cancel_time: float = time.time() - cancel_start
|
||||
|
||||
# Mark as deleting
|
||||
update_start: float = time.time()
|
||||
try:
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
deleted_count += 1
|
||||
except Exception as e:
|
||||
errors.append(f"CC pair {cc_pair.id}: {str(e)}")
|
||||
print(
|
||||
f"Error updating CC pair {cc_pair.id}: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
update_time: float = time.time() - update_start
|
||||
connector_total_time: float = time.time() - connector_start
|
||||
|
||||
per_connector_timing.append(
|
||||
{
|
||||
"cc_pair_id": cc_pair.id,
|
||||
"cancel_indexing_seconds": cancel_time,
|
||||
"update_status_seconds": update_time,
|
||||
"total_seconds": connector_total_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Commit all changes
|
||||
commit_start: float = time.time()
|
||||
db_session.commit()
|
||||
commit_time: float = time.time() - commit_start
|
||||
|
||||
# Trigger the deletion check task
|
||||
print(
|
||||
"Triggering connector deletion check task",
|
||||
file=sys.stderr,
|
||||
)
|
||||
task_start: float = time.time()
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
task_time: float = time.time() - task_start
|
||||
|
||||
total_time: float = time.time() - overall_start
|
||||
|
||||
result: dict[str, Any] = {
|
||||
"status": "success",
|
||||
"message": f"Marked {deleted_count} connector credential pairs for deletion",
|
||||
"deleted_count": deleted_count,
|
||||
"timing": {
|
||||
"fetch_all_seconds": fetch_all_time,
|
||||
"commit_seconds": commit_time,
|
||||
"send_task_seconds": task_time,
|
||||
"total_seconds": total_time,
|
||||
"per_connector": per_connector_timing,
|
||||
},
|
||||
}
|
||||
|
||||
if errors:
|
||||
result["errors"] = errors
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error marking all connectors for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
"timing": {
|
||||
"total_seconds": time.time() - overall_start,
|
||||
"per_connector": per_connector_timing,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) < 2 or len(sys.argv) > 3:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Usage: python mark_connector_for_deletion.py <tenant_id> [<cc_pair_id>|--all]",
|
||||
}
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
tenant_id: str = sys.argv[1]
|
||||
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
result: dict[str, Any]
|
||||
# Check if we should mark all connectors or just one
|
||||
if len(sys.argv) == 3:
|
||||
second_arg: str = sys.argv[2]
|
||||
if second_arg == "--all":
|
||||
result = mark_all_connectors_for_deletion(tenant_id)
|
||||
else:
|
||||
try:
|
||||
cc_pair_id: int = int(second_arg)
|
||||
result = mark_connector_for_deletion(tenant_id, cc_pair_id)
|
||||
except ValueError:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "cc_pair_id must be an integer or use --all",
|
||||
}
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
else:
|
||||
# If only tenant_id is provided, show error
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Usage: python mark_connector_for_deletion.py <tenant_id> [<cc_pair_id>|--all]",
|
||||
}
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
print(json.dumps(result, indent=2))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,144 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Script to mark a connector credential pair for deletion.
|
||||
Runs on a Kubernetes pod with access to the data plane database.
|
||||
|
||||
Usage:
|
||||
python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>
|
||||
|
||||
Output:
|
||||
JSON to stdout with structure:
|
||||
{
|
||||
"status": "success" | "error",
|
||||
"message": str
|
||||
}
|
||||
"""
|
||||
|
||||
import json
|
||||
import sys
|
||||
|
||||
from onyx.background.celery.versioned_apps.client import app as client_app
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair_from_id
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.db.engine.sql_engine import SqlEngine
|
||||
from onyx.db.enums import ConnectorCredentialPairStatus
|
||||
from onyx.db.index_attempt import cancel_indexing_attempts_for_ccpair
|
||||
|
||||
|
||||
def mark_connector_for_deletion(tenant_id: str, cc_pair_id: int) -> dict:
|
||||
"""Mark a connector credential pair for deletion.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID
|
||||
cc_pair_id: The connector credential pair ID
|
||||
|
||||
Returns:
|
||||
Dict with status and message
|
||||
"""
|
||||
try:
|
||||
print(
|
||||
f"Marking connector credential pair {cc_pair_id} for deletion",
|
||||
file=sys.stderr,
|
||||
)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Get the connector credential pair
|
||||
cc_pair = get_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair_id,
|
||||
)
|
||||
|
||||
if not cc_pair:
|
||||
return {
|
||||
"status": "error",
|
||||
"message": f"Connector credential pair {cc_pair_id} not found",
|
||||
}
|
||||
|
||||
# Cancel any scheduled indexing attempts
|
||||
print(
|
||||
f"Canceling indexing attempts for CC pair {cc_pair_id}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
cancel_indexing_attempts_for_ccpair(
|
||||
cc_pair_id=cc_pair.id,
|
||||
db_session=db_session,
|
||||
include_secondary_index=True,
|
||||
)
|
||||
|
||||
# Mark as deleting
|
||||
print(
|
||||
f"Updating CC pair {cc_pair_id} status to DELETING",
|
||||
file=sys.stderr,
|
||||
)
|
||||
update_connector_credential_pair_from_id(
|
||||
db_session=db_session,
|
||||
cc_pair_id=cc_pair.id,
|
||||
status=ConnectorCredentialPairStatus.DELETING,
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
# Trigger the deletion check task
|
||||
print(
|
||||
"Triggering connector deletion check task",
|
||||
file=sys.stderr,
|
||||
)
|
||||
client_app.send_task(
|
||||
OnyxCeleryTask.CHECK_FOR_CONNECTOR_DELETION,
|
||||
priority=OnyxCeleryPriority.HIGH,
|
||||
kwargs={"tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
return {
|
||||
"status": "success",
|
||||
"message": f"Marked connector credential pair {cc_pair_id} for deletion",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Error marking connector for deletion: {e}",
|
||||
file=sys.stderr,
|
||||
)
|
||||
return {
|
||||
"status": "error",
|
||||
"message": str(e),
|
||||
}
|
||||
|
||||
|
||||
def main() -> None:
|
||||
if len(sys.argv) != 3:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "Usage: python mark_connector_for_deletion.py <tenant_id> <cc_pair_id>",
|
||||
}
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
tenant_id = sys.argv[1]
|
||||
try:
|
||||
cc_pair_id = int(sys.argv[2])
|
||||
except ValueError:
|
||||
print(
|
||||
json.dumps(
|
||||
{
|
||||
"status": "error",
|
||||
"message": "cc_pair_id must be an integer",
|
||||
}
|
||||
)
|
||||
)
|
||||
sys.exit(1)
|
||||
|
||||
SqlEngine.init_engine(pool_size=5, max_overflow=2)
|
||||
|
||||
result = mark_connector_for_deletion(tenant_id, cc_pair_id)
|
||||
print(json.dumps(result))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,9 +1,11 @@
|
||||
import csv
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from io import BytesIO
|
||||
from io import StringIO
|
||||
from zipfile import ZipFile
|
||||
|
||||
import pytest
|
||||
@@ -257,6 +259,48 @@ class TestUsageExportAPI:
|
||||
assert "chat_messages.csv" in file_names
|
||||
assert "users.csv" in file_names
|
||||
|
||||
# Verify chat_messages.csv has the expected columns
|
||||
with zip_file.open("chat_messages.csv") as csv_file:
|
||||
csv_content = csv_file.read().decode("utf-8")
|
||||
csv_reader = csv.DictReader(StringIO(csv_content))
|
||||
|
||||
# Check that all expected columns are present
|
||||
expected_columns = {
|
||||
"session_id",
|
||||
"user_id",
|
||||
"flow_type",
|
||||
"time_sent",
|
||||
"assistant_name",
|
||||
"user_email",
|
||||
"number_of_tokens",
|
||||
}
|
||||
actual_columns = set(csv_reader.fieldnames or [])
|
||||
assert (
|
||||
expected_columns == actual_columns
|
||||
), f"Expected columns {expected_columns}, but got {actual_columns}"
|
||||
|
||||
# Verify there's at least one row of data
|
||||
rows = list(csv_reader)
|
||||
assert len(rows) > 0, "Expected at least one message in the report"
|
||||
|
||||
# Verify the first row has non-empty values for all columns
|
||||
first_row = rows[0]
|
||||
for column in expected_columns:
|
||||
assert column in first_row, f"Column {column} not found in row"
|
||||
assert first_row[
|
||||
column
|
||||
], f"Column {column} has empty value in first row"
|
||||
|
||||
# Verify specific new fields have appropriate values
|
||||
assert first_row["assistant_name"], "assistant_name should not be empty"
|
||||
assert first_row["user_email"], "user_email should not be empty"
|
||||
assert first_row[
|
||||
"number_of_tokens"
|
||||
].isdigit(), "number_of_tokens should be a numeric value"
|
||||
assert (
|
||||
int(first_row["number_of_tokens"]) >= 0
|
||||
), "number_of_tokens should be non-negative"
|
||||
|
||||
def test_read_nonexistent_report(self, reset: None, admin_user: DATestUser) -> None:
|
||||
# Try to download a report that doesn't exist
|
||||
response = requests.get(
|
||||
|
||||
1
web/.gitignore
vendored
1
web/.gitignore
vendored
@@ -43,3 +43,4 @@ next-env.d.ts
|
||||
|
||||
# generated clients ... in particular, the API to the Onyx backend itself!
|
||||
/src/lib/generated
|
||||
.jest-cache
|
||||
|
||||
@@ -1,11 +1,100 @@
|
||||
module.exports = {
|
||||
/**
|
||||
* Jest configuration with separate projects for different test environments.
|
||||
*
|
||||
* We use two separate projects:
|
||||
* 1. "unit" - Node environment for pure unit tests (no DOM needed)
|
||||
* 2. "integration" - jsdom environment for React integration tests
|
||||
*
|
||||
* This allows us to run tests with the correct environment automatically
|
||||
* without needing @jest-environment comments in every test file.
|
||||
*/
|
||||
|
||||
// Shared configuration
|
||||
const sharedConfig = {
|
||||
preset: "ts-jest",
|
||||
testEnvironment: "node",
|
||||
setupFilesAfterEnv: ["<rootDir>/tests/setup/jest.setup.ts"],
|
||||
|
||||
// Performance: Use 50% of CPU cores for parallel execution
|
||||
maxWorkers: "50%",
|
||||
|
||||
moduleNameMapper: {
|
||||
// Mock react-markdown and related packages
|
||||
"^react-markdown$": "<rootDir>/tests/setup/__mocks__/react-markdown.tsx",
|
||||
"^remark-gfm$": "<rootDir>/tests/setup/__mocks__/remark-gfm.ts",
|
||||
// Mock UserProvider
|
||||
"^@/components/user/UserProvider$":
|
||||
"<rootDir>/tests/setup/__mocks__/@/components/user/UserProvider.tsx",
|
||||
// Path aliases (must come after specific mocks)
|
||||
"^@/(.*)$": "<rootDir>/src/$1",
|
||||
"^@tests/(.*)$": "<rootDir>/tests/$1",
|
||||
// Mock CSS imports
|
||||
"\\.(css|less|scss|sass)$": "identity-obj-proxy",
|
||||
// Mock static file imports
|
||||
"\\.(jpg|jpeg|png|gif|svg|woff|woff2|ttf|eot)$":
|
||||
"<rootDir>/tests/setup/fileMock.js",
|
||||
},
|
||||
testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/"],
|
||||
|
||||
testPathIgnorePatterns: ["/node_modules/", "/tests/e2e/", "/.next/"],
|
||||
|
||||
transformIgnorePatterns: [
|
||||
"/node_modules/(?!(jose|@radix-ui|@headlessui|@phosphor-icons|msw|until-async|react-markdown|remark-gfm|remark-parse|unified|bail|is-plain-obj|trough|vfile|unist-.*|mdast-.*|micromark.*|decode-named-character-reference|character-entities)/)",
|
||||
],
|
||||
|
||||
transform: {
|
||||
"^.+\\.tsx?$": "ts-jest",
|
||||
"^.+\\.tsx?$": [
|
||||
"ts-jest",
|
||||
{
|
||||
// Performance: Disable type-checking in tests (types are checked by tsc)
|
||||
isolatedModules: true,
|
||||
tsconfig: {
|
||||
jsx: "react-jsx",
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
|
||||
// Performance: Cache results between runs
|
||||
cache: true,
|
||||
cacheDirectory: "<rootDir>/.jest-cache",
|
||||
|
||||
collectCoverageFrom: [
|
||||
"src/**/*.{ts,tsx}",
|
||||
"!src/**/*.d.ts",
|
||||
"!src/**/*.stories.tsx",
|
||||
],
|
||||
|
||||
coveragePathIgnorePatterns: ["/node_modules/", "/tests/", "/.next/"],
|
||||
|
||||
// Performance: Clear mocks automatically between tests
|
||||
clearMocks: true,
|
||||
resetMocks: false,
|
||||
restoreMocks: false,
|
||||
};
|
||||
|
||||
module.exports = {
|
||||
projects: [
|
||||
{
|
||||
displayName: "unit",
|
||||
...sharedConfig,
|
||||
testEnvironment: "node",
|
||||
testMatch: [
|
||||
// Pure unit tests that don't need DOM
|
||||
"**/src/**/codeUtils.test.ts",
|
||||
"**/src/lib/**/*.test.ts",
|
||||
// Add more patterns here as you add more unit tests
|
||||
],
|
||||
},
|
||||
{
|
||||
displayName: "integration",
|
||||
...sharedConfig,
|
||||
testEnvironment: "jsdom",
|
||||
testMatch: [
|
||||
// React component integration tests
|
||||
"**/src/app/**/*.test.tsx",
|
||||
"**/src/components/**/*.test.tsx",
|
||||
"**/src/lib/**/*.test.tsx",
|
||||
// Add more patterns here as you add more integration tests
|
||||
],
|
||||
},
|
||||
],
|
||||
};
|
||||
|
||||
938
web/package-lock.json
generated
938
web/package-lock.json
generated
File diff suppressed because it is too large
Load Diff
@@ -11,7 +11,13 @@
|
||||
"lint:unused": "eslint --ext .js,.jsx,.ts,.tsx --rule 'unused-imports/no-unused-imports: error' --quiet --fix=false src/",
|
||||
"lint:fix-unused": "eslint --ext .js,.jsx,.ts,.tsx --rule 'unused-imports/no-unused-imports: error' --quiet --fix src/",
|
||||
"lint:fix-unused-vars": "eslint --ext .js,.jsx,.ts,.tsx --fix --quiet src/",
|
||||
"test": "jest"
|
||||
"test": "jest",
|
||||
"test:watch": "jest --watch",
|
||||
"test:coverage": "jest --coverage",
|
||||
"test:verbose": "jest --verbose",
|
||||
"test:ci": "jest --ci --maxWorkers=2 --silent --bail",
|
||||
"test:changed": "jest --onlyChanged",
|
||||
"test:debug": "node --inspect-brk node_modules/.bin/jest --runInBand"
|
||||
},
|
||||
"dependencies": {
|
||||
"@dnd-kit/core": "^6.1.0",
|
||||
@@ -96,6 +102,9 @@
|
||||
"@chromatic-com/playwright": "^0.10.2",
|
||||
"@playwright/test": "^1.39.0",
|
||||
"@tailwindcss/typography": "^0.5.10",
|
||||
"@testing-library/jest-dom": "^6.9.1",
|
||||
"@testing-library/react": "^14.3.1",
|
||||
"@testing-library/user-event": "^14.6.1",
|
||||
"@types/chrome": "^0.0.287",
|
||||
"@types/jest": "^29.5.14",
|
||||
"@types/js-cookie": "^3.0.6",
|
||||
@@ -109,10 +118,13 @@
|
||||
"eslint": "^8.57.1",
|
||||
"eslint-config-next": "^14.1.0",
|
||||
"eslint-plugin-unused-imports": "^4.1.4",
|
||||
"identity-obj-proxy": "^3.0.0",
|
||||
"jest": "^29.7.0",
|
||||
"jest-environment-jsdom": "^29.7.0",
|
||||
"prettier": "3.1.0",
|
||||
"ts-jest": "^29.2.5",
|
||||
"ts-unused-exports": "^11.0.1"
|
||||
"ts-unused-exports": "^11.0.1",
|
||||
"whatwg-fetch": "^3.6.20"
|
||||
},
|
||||
"overrides": {
|
||||
"react-is": "^19.0.0-rc-69d4b800-20241021"
|
||||
|
||||
@@ -23,6 +23,8 @@ export default defineConfig({
|
||||
// },
|
||||
// ],
|
||||
],
|
||||
// Only run Playwright tests from tests/e2e directory (ignore Jest tests in src/)
|
||||
testMatch: /.*\/tests\/e2e\/.*\.spec\.ts/,
|
||||
projects: [
|
||||
{
|
||||
name: "admin",
|
||||
@@ -31,7 +33,6 @@ export default defineConfig({
|
||||
viewport: { width: 1280, height: 720 },
|
||||
storageState: "admin_auth.json",
|
||||
},
|
||||
testIgnore: ["**/codeUtils.test.ts"],
|
||||
},
|
||||
{
|
||||
name: "no-auth",
|
||||
@@ -39,7 +40,6 @@ export default defineConfig({
|
||||
...devices["Desktop Chrome"],
|
||||
viewport: { width: 1280, height: 720 },
|
||||
},
|
||||
testIgnore: ["**/codeUtils.test.ts"],
|
||||
},
|
||||
],
|
||||
});
|
||||
|
||||
@@ -8,7 +8,7 @@ import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import useSWR, { mutate } from "swr";
|
||||
import { ErrorCallout } from "@/components/ErrorCallout";
|
||||
import { OnyxSparkleIcon } from "@/components/icons/icons";
|
||||
import OnyxLogo from "@/icons/onyx-logo";
|
||||
import { usePopup } from "@/components/admin/connectors/Popup";
|
||||
import { useAgentsContext } from "@/refresh-components/contexts/AgentsContext";
|
||||
import { Switch } from "@/components/ui/switch";
|
||||
@@ -312,7 +312,9 @@ export default function Page() {
|
||||
<div className="mx-auto max-w-4xl w-full">
|
||||
<AdminPageTitle
|
||||
title="Default Assistant"
|
||||
icon={<OnyxSparkleIcon size={32} className="my-auto" />}
|
||||
icon={
|
||||
<OnyxLogo className="my-auto w-[1.5rem] h-[1.5rem] stroke-text-04" />
|
||||
}
|
||||
/>
|
||||
<DefaultAssistantConfig />
|
||||
</div>
|
||||
|
||||
@@ -9,7 +9,7 @@ import { mutate } from "swr";
|
||||
import { Badge } from "@/components/ui/badge";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { isSubset } from "@/lib/utils";
|
||||
import { cn, isSubset } from "@/lib/utils";
|
||||
|
||||
function LLMProviderUpdateModal({
|
||||
llmProviderDescriptor,
|
||||
@@ -70,6 +70,29 @@ function LLMProviderDisplay({
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
async function handleSetAsDefault(): Promise<void> {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to set provider as default: ${errorMsg}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
await mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
setPopup({
|
||||
type: "success",
|
||||
message: "Provider set as default successfully!",
|
||||
});
|
||||
}
|
||||
|
||||
const providerName =
|
||||
existingLlmProvider?.name ||
|
||||
llmProviderDescriptor?.display_name ||
|
||||
@@ -87,29 +110,8 @@ function LLMProviderDisplay({
|
||||
</Text>
|
||||
{!existingLlmProvider.is_default_provider && (
|
||||
<Text
|
||||
className="text-action-link-05"
|
||||
onClick={async () => {
|
||||
const response = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${existingLlmProvider.id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!response.ok) {
|
||||
const errorMsg = (await response.json()).detail;
|
||||
setPopup({
|
||||
type: "error",
|
||||
message: `Failed to set provider as default: ${errorMsg}`,
|
||||
});
|
||||
return;
|
||||
}
|
||||
|
||||
mutate(LLM_PROVIDERS_ADMIN_URL);
|
||||
setPopup({
|
||||
type: "success",
|
||||
message: "Provider set as default successfully!",
|
||||
});
|
||||
}}
|
||||
className={cn("text-action-link-05", "cursor-pointer")}
|
||||
onClick={handleSetAsDefault}
|
||||
>
|
||||
Set as default
|
||||
</Text>
|
||||
|
||||
@@ -0,0 +1,464 @@
|
||||
/**
|
||||
* Integration Test: Custom LLM Provider Configuration Workflow
|
||||
*
|
||||
* Tests the complete user journey for configuring a custom LLM provider.
|
||||
* This tests the full workflow: form fill → test config → save → set as default
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
|
||||
import { CustomLLMProviderUpdateForm } from "./CustomLLMProviderUpdateForm";
|
||||
|
||||
// Mock SWR's mutate function
|
||||
const mockMutate = jest.fn();
|
||||
jest.mock("swr", () => ({
|
||||
...jest.requireActual("swr"),
|
||||
useSWRConfig: () => ({ mutate: mockMutate }),
|
||||
}));
|
||||
|
||||
// Mock usePaidEnterpriseFeaturesEnabled
|
||||
jest.mock("@/components/settings/usePaidEnterpriseFeaturesEnabled", () => ({
|
||||
usePaidEnterpriseFeaturesEnabled: () => false,
|
||||
}));
|
||||
|
||||
describe("Custom LLM Provider Configuration Workflow", () => {
|
||||
let fetchSpy: jest.SpyInstance;
|
||||
const mockOnClose = jest.fn();
|
||||
const mockSetPopup = jest.fn();
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("creates a new custom LLM provider successfully", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/admin/llm/provider?is_creation=true
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
id: 1,
|
||||
name: "My Custom Provider",
|
||||
provider: "openai",
|
||||
api_key: "test-key",
|
||||
default_model_name: "gpt-4",
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
setPopup={mockSetPopup}
|
||||
/>
|
||||
);
|
||||
|
||||
// Fill in the form
|
||||
const nameInput = screen.getByPlaceholderText(/display name/i);
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
/name of the custom provider/i
|
||||
);
|
||||
const apiKeyInput = screen.getByPlaceholderText(/api key/i);
|
||||
|
||||
await user.type(nameInput, "My Custom Provider");
|
||||
await user.type(providerInput, "openai");
|
||||
await user.type(apiKeyInput, "test-key-123");
|
||||
|
||||
// Fill in model configuration (use placeholder to find input)
|
||||
const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
|
||||
await user.type(modelNameInput, "gpt-4");
|
||||
|
||||
// Set default model (there are 2 inputs with this placeholder - default and fast)
|
||||
// We want the first one (Default Model)
|
||||
const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
|
||||
await user.type(defaultModelInputs[0], "gpt-4");
|
||||
|
||||
// Submit the form
|
||||
const submitButton = screen.getByRole("button", { name: /enable/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify test API was called first
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/test",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify create API was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/provider?is_creation=true",
|
||||
expect.objectContaining({
|
||||
method: "PUT",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify success popup
|
||||
await waitFor(() => {
|
||||
expect(mockSetPopup).toHaveBeenCalledWith({
|
||||
type: "success",
|
||||
message: "Provider enabled successfully!",
|
||||
});
|
||||
});
|
||||
|
||||
// Verify onClose was called
|
||||
expect(mockOnClose).toHaveBeenCalled();
|
||||
|
||||
// Verify SWR cache was invalidated
|
||||
expect(mockMutate).toHaveBeenCalledWith("/api/admin/llm/provider");
|
||||
});
|
||||
|
||||
test("shows error when test configuration fails", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/admin/llm/test (failure)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
json: async () => ({ detail: "Invalid API key" }),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
setPopup={mockSetPopup}
|
||||
/>
|
||||
);
|
||||
|
||||
// Fill in the form with invalid credentials
|
||||
const nameInput = screen.getByPlaceholderText(/display name/i);
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
/name of the custom provider/i
|
||||
);
|
||||
const apiKeyInput = screen.getByPlaceholderText(/api key/i);
|
||||
|
||||
await user.type(nameInput, "Bad Provider");
|
||||
await user.type(providerInput, "openai");
|
||||
await user.type(apiKeyInput, "invalid-key");
|
||||
|
||||
// Fill in model configuration
|
||||
const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
|
||||
await user.type(modelNameInput, "gpt-4");
|
||||
|
||||
// Set default model (there are 2 inputs with this placeholder - default and fast)
|
||||
const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
|
||||
await user.type(defaultModelInputs[0], "gpt-4");
|
||||
|
||||
// Submit the form
|
||||
const submitButton = screen.getByRole("button", { name: /enable/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify test API was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/test",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify error is displayed (form should NOT proceed to create)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/invalid api key/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// Verify create API was NOT called
|
||||
expect(
|
||||
fetchSpy.mock.calls.find((call) =>
|
||||
call[0].includes("/api/admin/llm/provider")
|
||||
)
|
||||
).toBeUndefined();
|
||||
});
|
||||
|
||||
test("updates an existing LLM provider", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
const existingProvider = {
|
||||
id: 1,
|
||||
name: "Existing Provider",
|
||||
provider: "anthropic",
|
||||
api_key: "old-key",
|
||||
api_base: "",
|
||||
api_version: "",
|
||||
default_model_name: "claude-3-opus",
|
||||
fast_default_model_name: null,
|
||||
model_configurations: [
|
||||
{ name: "claude-3-opus", is_visible: true, max_input_tokens: null },
|
||||
],
|
||||
custom_config: {},
|
||||
is_public: true,
|
||||
groups: [],
|
||||
deployment_name: null,
|
||||
};
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/admin/llm/provider (update, no is_creation param)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ ...existingProvider, api_key: "new-key" }),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
existingLlmProvider={existingProvider}
|
||||
setPopup={mockSetPopup}
|
||||
/>
|
||||
);
|
||||
|
||||
// Update the API key
|
||||
const apiKeyInput = screen.getByPlaceholderText(/api key/i);
|
||||
await user.clear(apiKeyInput);
|
||||
await user.type(apiKeyInput, "new-key-456");
|
||||
|
||||
// Submit
|
||||
const submitButton = screen.getByRole("button", { name: /update/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify test was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/test",
|
||||
expect.any(Object)
|
||||
);
|
||||
});
|
||||
|
||||
// Verify update API was called (without is_creation param)
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/provider",
|
||||
expect.objectContaining({
|
||||
method: "PUT",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify success message says "updated"
|
||||
await waitFor(() => {
|
||||
expect(mockSetPopup).toHaveBeenCalledWith({
|
||||
type: "success",
|
||||
message: "Provider updated successfully!",
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
test("sets provider as default when shouldMarkAsDefault is true", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/admin/llm/provider?is_creation=true
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
id: 5,
|
||||
name: "New Default Provider",
|
||||
provider: "openai",
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/admin/llm/provider/5/default
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
setPopup={mockSetPopup}
|
||||
shouldMarkAsDefault={true}
|
||||
/>
|
||||
);
|
||||
|
||||
// Fill form
|
||||
const nameInput = screen.getByPlaceholderText(/display name/i);
|
||||
await user.type(nameInput, "New Default Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
/name of the custom provider/i
|
||||
);
|
||||
await user.type(providerInput, "openai");
|
||||
|
||||
// Fill in model configuration
|
||||
const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
|
||||
await user.type(modelNameInput, "gpt-4");
|
||||
|
||||
// Set default model (there are 2 inputs with this placeholder - default and fast)
|
||||
const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
|
||||
await user.type(defaultModelInputs[0], "gpt-4");
|
||||
|
||||
// Submit
|
||||
const submitButton = screen.getByRole("button", { name: /enable/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify set as default API was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/admin/llm/provider/5/default",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
})
|
||||
);
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error when provider creation fails", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/admin/llm/provider?is_creation=true (failure)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
json: async () => ({ detail: "Database error" }),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
setPopup={mockSetPopup}
|
||||
/>
|
||||
);
|
||||
|
||||
// Fill form
|
||||
const nameInput = screen.getByPlaceholderText(/display name/i);
|
||||
await user.type(nameInput, "Test Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
/name of the custom provider/i
|
||||
);
|
||||
await user.type(providerInput, "openai");
|
||||
|
||||
// Fill in model configuration
|
||||
const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
|
||||
await user.type(modelNameInput, "gpt-4");
|
||||
|
||||
// Set default model (there are 2 inputs with this placeholder - default and fast)
|
||||
const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
|
||||
await user.type(defaultModelInputs[0], "gpt-4");
|
||||
|
||||
// Submit
|
||||
const submitButton = screen.getByRole("button", { name: /enable/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify error popup
|
||||
await waitFor(() => {
|
||||
expect(mockSetPopup).toHaveBeenCalledWith({
|
||||
type: "error",
|
||||
message: "Failed to enable provider: Database error",
|
||||
});
|
||||
});
|
||||
|
||||
// Verify onClose was NOT called
|
||||
expect(mockOnClose).not.toHaveBeenCalled();
|
||||
});
|
||||
|
||||
test("adds custom configuration key-value pairs", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/admin/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/admin/llm/provider?is_creation=true
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ id: 1, name: "Provider with Custom Config" }),
|
||||
} as Response);
|
||||
|
||||
render(
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={mockOnClose}
|
||||
setPopup={mockSetPopup}
|
||||
/>
|
||||
);
|
||||
|
||||
// Fill basic fields
|
||||
const nameInput = screen.getByPlaceholderText(/display name/i);
|
||||
await user.type(nameInput, "Cloudflare Provider");
|
||||
|
||||
const providerInput = screen.getByPlaceholderText(
|
||||
/name of the custom provider/i
|
||||
);
|
||||
await user.type(providerInput, "cloudflare");
|
||||
|
||||
// Click "Add New" button for custom config (there are 2 "Add New" buttons - one for custom config, one for models)
|
||||
// The custom config "Add New" appears first
|
||||
const addNewButtons = screen.getAllByRole("button", { name: /add new/i });
|
||||
const customConfigAddButton = addNewButtons[0]; // First "Add New" is for custom config
|
||||
await user.click(customConfigAddButton);
|
||||
|
||||
// Fill in custom config key-value pair
|
||||
const customConfigInputs = screen.getAllByRole("textbox");
|
||||
const keyInput = customConfigInputs.find(
|
||||
(input) => input.getAttribute("name") === "custom_config_list[0][0]"
|
||||
);
|
||||
const valueInput = customConfigInputs.find(
|
||||
(input) => input.getAttribute("name") === "custom_config_list[0][1]"
|
||||
);
|
||||
|
||||
expect(keyInput).toBeDefined();
|
||||
expect(valueInput).toBeDefined();
|
||||
|
||||
await user.type(keyInput!, "CLOUDFLARE_ACCOUNT_ID");
|
||||
await user.type(valueInput!, "my-account-id-123");
|
||||
|
||||
// Fill in model configuration
|
||||
const modelNameInput = screen.getByPlaceholderText(/model-name-1/i);
|
||||
await user.type(modelNameInput, "@cf/meta/llama-2-7b-chat-int8");
|
||||
|
||||
// Set default model (there are 2 inputs with this placeholder - default and fast)
|
||||
const defaultModelInputs = screen.getAllByPlaceholderText(/e\.g\. gpt-4/i);
|
||||
await user.type(defaultModelInputs[0], "@cf/meta/llama-2-7b-chat-int8");
|
||||
|
||||
// Submit
|
||||
const submitButton = screen.getByRole("button", { name: /enable/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
// Verify the custom config was included in the request
|
||||
await waitFor(() => {
|
||||
const createCall = fetchSpy.mock.calls.find((call) =>
|
||||
call[0].includes("/api/admin/llm/provider")
|
||||
);
|
||||
expect(createCall).toBeDefined();
|
||||
|
||||
const requestBody = JSON.parse(createCall![1].body);
|
||||
expect(requestBody.custom_config).toEqual({
|
||||
CLOUDFLARE_ACCOUNT_ID: "my-account-id-123",
|
||||
});
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -88,7 +88,13 @@ export function CustomLLMProviderUpdateForm({
|
||||
Yup.object({
|
||||
name: Yup.string().required("Model name is required"),
|
||||
is_visible: Yup.boolean().required("Visibility is required"),
|
||||
max_input_tokens: Yup.number().nullable().optional(),
|
||||
// Coerce empty string from input field into null so it's optional
|
||||
max_input_tokens: Yup.number()
|
||||
.transform((value, originalValue) =>
|
||||
originalValue === "" || originalValue === undefined ? null : value
|
||||
)
|
||||
.nullable()
|
||||
.optional(),
|
||||
})
|
||||
),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
|
||||
@@ -106,20 +106,25 @@ function AddCustomLLMProvider({
|
||||
existingLlmProviders: LLMProviderView[];
|
||||
}) {
|
||||
const [formIsVisible, setFormIsVisible] = useState(false);
|
||||
const { popup, setPopup } = usePopup();
|
||||
|
||||
if (formIsVisible) {
|
||||
return (
|
||||
<Modal
|
||||
title={`Setup Custom LLM Provider`}
|
||||
onOutsideClick={() => setFormIsVisible(false)}
|
||||
>
|
||||
<div className="max-h-[70vh] overflow-y-auto px-4">
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={() => setFormIsVisible(false)}
|
||||
shouldMarkAsDefault={existingLlmProviders.length === 0}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
<>
|
||||
{popup}
|
||||
<Modal
|
||||
title={`Setup Custom LLM Provider`}
|
||||
onOutsideClick={() => setFormIsVisible(false)}
|
||||
>
|
||||
<div className="max-h-[70vh] overflow-y-auto px-4">
|
||||
<CustomLLMProviderUpdateForm
|
||||
onClose={() => setFormIsVisible(false)}
|
||||
shouldMarkAsDefault={existingLlmProviders.length === 0}
|
||||
setPopup={setPopup}
|
||||
/>
|
||||
</div>
|
||||
</Modal>
|
||||
</>
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -132,12 +132,14 @@ export function LLMProviderUpdateForm({
|
||||
llmProviderDescriptor.default_api_base ??
|
||||
"",
|
||||
api_version: existingLlmProvider?.api_version ?? "",
|
||||
// For Azure OpenAI, combine api_base and api_version into target_uri
|
||||
// For Azure OpenAI, combine api_base, deployment_name, and api_version into target_uri
|
||||
target_uri:
|
||||
llmProviderDescriptor.name === "azure" &&
|
||||
existingLlmProvider?.api_base &&
|
||||
existingLlmProvider?.api_version
|
||||
? `${existingLlmProvider.api_base}/openai/deployments/your-deployment?api-version=${existingLlmProvider.api_version}`
|
||||
? `${existingLlmProvider.api_base}/openai/deployments/${
|
||||
existingLlmProvider.deployment_name || "your-deployment"
|
||||
}/chat/completions?api-version=${existingLlmProvider.api_version}`
|
||||
: "",
|
||||
default_model_name:
|
||||
existingLlmProvider?.default_model_name ??
|
||||
@@ -201,20 +203,22 @@ export function LLMProviderUpdateForm({
|
||||
.required("Target URI is required")
|
||||
.test(
|
||||
"valid-target-uri",
|
||||
"Target URI must be a valid URL with exactly one query parameter (api-version)",
|
||||
"Target URI must be a valid URL with api-version query parameter and the deployment name in the path",
|
||||
(value) => {
|
||||
if (!value) return false;
|
||||
try {
|
||||
const url = new URL(value);
|
||||
const params = new URLSearchParams(url.search);
|
||||
const paramKeys = Array.from(params.keys());
|
||||
const hasApiVersion = !!url.searchParams
|
||||
.get("api-version")
|
||||
?.trim();
|
||||
|
||||
// Check if there's exactly one parameter and it's api-version
|
||||
return (
|
||||
paramKeys.length === 1 &&
|
||||
paramKeys[0] === "api-version" &&
|
||||
!!params.get("api-version")
|
||||
// Check if the path contains a deployment name
|
||||
const pathMatch = url.pathname.match(
|
||||
/\/openai\/deployments\/([^\/]+)/
|
||||
);
|
||||
const hasDeploymentName = pathMatch && pathMatch[1];
|
||||
|
||||
return hasApiVersion && !!hasDeploymentName;
|
||||
} catch {
|
||||
return false;
|
||||
}
|
||||
@@ -240,9 +244,11 @@ export function LLMProviderUpdateForm({
|
||||
),
|
||||
}
|
||||
: {}),
|
||||
deployment_name: llmProviderDescriptor.deployment_name_required
|
||||
? Yup.string().required("Deployment Name is required")
|
||||
: Yup.string().nullable(),
|
||||
deployment_name:
|
||||
llmProviderDescriptor.deployment_name_required &&
|
||||
llmProviderDescriptor.name !== "azure"
|
||||
? Yup.string().required("Deployment Name is required")
|
||||
: Yup.string().nullable(),
|
||||
default_model_name: Yup.string().required("Model name is required"),
|
||||
fast_default_model_name: Yup.string().nullable(),
|
||||
// EE Only
|
||||
@@ -276,15 +282,24 @@ export function LLMProviderUpdateForm({
|
||||
...rest
|
||||
} = values;
|
||||
|
||||
// For Azure OpenAI, parse target_uri to extract api_base and api_version
|
||||
// For Azure OpenAI, parse target_uri to extract api_base, api_version, and deployment_name
|
||||
let finalApiBase = rest.api_base;
|
||||
let finalApiVersion = rest.api_version;
|
||||
let finalDeploymentName = rest.deployment_name;
|
||||
|
||||
if (llmProviderDescriptor.name === "azure" && target_uri) {
|
||||
try {
|
||||
const url = new URL(target_uri);
|
||||
finalApiBase = url.origin; // Only use origin (protocol + hostname + port)
|
||||
finalApiVersion = url.searchParams.get("api-version") || "";
|
||||
|
||||
// Extract deployment name from path: /openai/deployments/{deployment-name}/...
|
||||
const pathMatch = url.pathname.match(
|
||||
/\/openai\/deployments\/([^\/]+)/
|
||||
);
|
||||
if (pathMatch && pathMatch[1]) {
|
||||
finalDeploymentName = pathMatch[1];
|
||||
}
|
||||
} catch (error) {
|
||||
// This should not happen due to validation, but handle gracefully
|
||||
console.error("Failed to parse target_uri:", error);
|
||||
@@ -296,6 +311,7 @@ export function LLMProviderUpdateForm({
|
||||
...rest,
|
||||
api_base: finalApiBase,
|
||||
api_version: finalApiVersion,
|
||||
deployment_name: finalDeploymentName,
|
||||
api_key_changed: values.api_key !== initialValues.api_key,
|
||||
model_configurations: getCurrentModelConfigurations(values).map(
|
||||
(modelConfiguration): ModelConfiguration => ({
|
||||
@@ -451,7 +467,7 @@ export function LLMProviderUpdateForm({
|
||||
name="target_uri"
|
||||
label="Target URI"
|
||||
placeholder="https://your-resource.cognitiveservices.azure.com/openai/deployments/deployment-name/chat/completions?api-version=2025-01-01-preview"
|
||||
subtext="The complete Azure OpenAI endpoint URL including the API version as a query parameter"
|
||||
subtext="The complete target URI for your deployment from the Azure AI portal."
|
||||
/>
|
||||
) : (
|
||||
<>
|
||||
@@ -612,13 +628,14 @@ export function LLMProviderUpdateForm({
|
||||
/>
|
||||
)}
|
||||
|
||||
{llmProviderDescriptor.deployment_name_required && (
|
||||
<TextFormField
|
||||
name="deployment_name"
|
||||
label="Deployment Name"
|
||||
placeholder="Deployment Name"
|
||||
/>
|
||||
)}
|
||||
{llmProviderDescriptor.deployment_name_required &&
|
||||
llmProviderDescriptor.name !== "azure" && (
|
||||
<TextFormField
|
||||
name="deployment_name"
|
||||
label="Deployment Name"
|
||||
placeholder="Deployment Name"
|
||||
/>
|
||||
)}
|
||||
|
||||
{!llmProviderDescriptor.single_model_supported &&
|
||||
(currentModelConfigurations.length > 0 ? (
|
||||
@@ -726,22 +743,24 @@ export function LLMProviderUpdateForm({
|
||||
}
|
||||
|
||||
// If the deleted provider was the default, set the first remaining provider as default
|
||||
const remainingProvidersResponse = await fetch(
|
||||
LLM_PROVIDERS_ADMIN_URL
|
||||
);
|
||||
if (remainingProvidersResponse.ok) {
|
||||
const remainingProviders =
|
||||
await remainingProvidersResponse.json();
|
||||
if (existingLlmProvider.is_default_provider) {
|
||||
const remainingProvidersResponse = await fetch(
|
||||
LLM_PROVIDERS_ADMIN_URL
|
||||
);
|
||||
if (remainingProvidersResponse.ok) {
|
||||
const remainingProviders =
|
||||
await remainingProvidersResponse.json();
|
||||
|
||||
if (remainingProviders.length > 0) {
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
if (remainingProviders.length > 0) {
|
||||
const setDefaultResponse = await fetch(
|
||||
`${LLM_PROVIDERS_ADMIN_URL}/${remainingProviders[0].id}/default`,
|
||||
{
|
||||
method: "POST",
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
console.error("Failed to set new default provider");
|
||||
}
|
||||
);
|
||||
if (!setDefaultResponse.ok) {
|
||||
console.error("Failed to set new default provider");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -150,7 +150,8 @@ export function ModelConfigurationField({
|
||||
arrayHelpers.push({
|
||||
name: "",
|
||||
is_visible: true,
|
||||
max_input_tokens: "",
|
||||
// Use null so Yup.number().nullable() accepts empty inputs
|
||||
max_input_tokens: null,
|
||||
});
|
||||
}}
|
||||
className="mt-3"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import {
|
||||
AnthropicIcon,
|
||||
AmazonIcon,
|
||||
AzureIcon,
|
||||
CPUIcon,
|
||||
MicrosoftIconSVG,
|
||||
MistralIcon,
|
||||
@@ -38,6 +39,8 @@ export const getProviderIcon = (
|
||||
claude: AnthropicIcon,
|
||||
anthropic: AnthropicIcon,
|
||||
openai: OpenAISVG,
|
||||
// Azure OpenAI should display the Azure logo
|
||||
azure: AzureIcon,
|
||||
microsoft: MicrosoftIconSVG,
|
||||
meta: MetaIcon,
|
||||
google: GeminiIcon,
|
||||
@@ -126,6 +129,31 @@ export const dynamicProviderConfigs: Record<
|
||||
successMessage: (count: number) =>
|
||||
`Successfully fetched ${count} models from Ollama.`,
|
||||
},
|
||||
openrouter: {
|
||||
endpoint: "/api/admin/llm/openrouter/available-models",
|
||||
isDisabled: (values) => !values.api_base || !values.api_key,
|
||||
disabledReason:
|
||||
"API Base and API Key are required to fetch OpenRouter models",
|
||||
buildRequestBody: ({ values }) => ({
|
||||
api_base: values.api_base,
|
||||
api_key: values.api_key,
|
||||
}),
|
||||
processResponse: (data: OllamaModelResponse[], llmProviderDescriptor) =>
|
||||
data.map((modelData) => {
|
||||
const existingConfig = llmProviderDescriptor.model_configurations.find(
|
||||
(config) => config.name === modelData.name
|
||||
);
|
||||
return {
|
||||
name: modelData.name,
|
||||
is_visible: existingConfig?.is_visible ?? true,
|
||||
max_input_tokens: modelData.max_input_tokens,
|
||||
supports_image_input: modelData.supports_image_input,
|
||||
};
|
||||
}),
|
||||
getModelNames: (data: OllamaModelResponse[]) => data.map((m) => m.name),
|
||||
successMessage: (count: number) =>
|
||||
`Successfully fetched ${count} models from OpenRouter.`,
|
||||
},
|
||||
};
|
||||
|
||||
export const fetchModels = async (
|
||||
|
||||
@@ -19,14 +19,8 @@ import { localizeAndPrettify } from "@/lib/time";
|
||||
import { getDocsProcessedPerMinute } from "@/lib/indexAttempt";
|
||||
import { InfoIcon } from "@/components/icons/icons";
|
||||
import ExceptionTraceModal from "@/components/modals/ExceptionTraceModal";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
TooltipProvider,
|
||||
TooltipTrigger,
|
||||
} from "@/components/ui/tooltip";
|
||||
import { FaBarsProgress } from "react-icons/fa6";
|
||||
import SimpleTooltip from "@/refresh-components/SimpleTooltip";
|
||||
import SvgClock from "@/icons/clock";
|
||||
|
||||
export interface IndexingAttemptsTableProps {
|
||||
ccPair: CCPairFullInfo;
|
||||
@@ -96,6 +90,14 @@ export function IndexAttemptsTable({
|
||||
{indexAttempts.map((indexAttempt) => {
|
||||
const docsPerMinute =
|
||||
getDocsProcessedPerMinute(indexAttempt)?.toFixed(2);
|
||||
const isReindexInProgress =
|
||||
indexAttempt.status === "in_progress" ||
|
||||
indexAttempt.status === "not_started";
|
||||
const reindexTooltip = `This index attempt ${
|
||||
isReindexInProgress ? "is" : "was"
|
||||
} a full re-index. All documents from the source ${
|
||||
isReindexInProgress ? "are being" : "were"
|
||||
} synced into the system.`;
|
||||
return (
|
||||
<TableRow key={indexAttempt.id}>
|
||||
<TableCell>
|
||||
@@ -136,28 +138,11 @@ export function IndexAttemptsTable({
|
||||
<div className="flex items-center">
|
||||
{indexAttempt.total_docs_indexed}
|
||||
{indexAttempt.from_beginning && (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<span className="cursor-help flex items-center">
|
||||
<FaBarsProgress className="ml-2 h-3.5 w-3.5" />
|
||||
</span>
|
||||
</TooltipTrigger>
|
||||
<TooltipContent>
|
||||
This index attempt{" "}
|
||||
{indexAttempt.status === "in_progress" ||
|
||||
indexAttempt.status === "not_started"
|
||||
? "is"
|
||||
: "was"}{" "}
|
||||
a full re-index. All documents from the source{" "}
|
||||
{indexAttempt.status === "in_progress" ||
|
||||
indexAttempt.status === "not_started"
|
||||
? "are being "
|
||||
: "were "}
|
||||
synced into the system.
|
||||
</TooltipContent>
|
||||
</Tooltip>
|
||||
</TooltipProvider>
|
||||
<SimpleTooltip side="top" tooltip={reindexTooltip}>
|
||||
<span className="cursor-help flex items-center">
|
||||
<SvgClock className="ml-2 h-3.5 w-3.5 stroke-current" />
|
||||
</span>
|
||||
</SimpleTooltip>
|
||||
)}
|
||||
</div>
|
||||
</TableCell>
|
||||
|
||||
@@ -178,7 +178,6 @@ export const DocumentSetCreationForm = ({
|
||||
name="name"
|
||||
label="Name:"
|
||||
placeholder="A name for the document set"
|
||||
disabled={isUpdate}
|
||||
autoCompleteDisabled={true}
|
||||
/>
|
||||
<TextFormField
|
||||
|
||||
@@ -112,7 +112,7 @@ const EditRow = ({
|
||||
|
||||
if (!isEditable) {
|
||||
return (
|
||||
<div className="text-text-darkerfont-medium my-auto p-1">
|
||||
<div className="text-text-darker font-medium my-auto p-1">
|
||||
{documentSet.name}
|
||||
</div>
|
||||
);
|
||||
@@ -125,7 +125,7 @@ const EditRow = ({
|
||||
<TooltipTrigger asChild>
|
||||
<div
|
||||
className={`
|
||||
text-text-darkerfont-medium my-auto p-1 hover:bg-accent-background flex items-center select-none
|
||||
text-text-darker font-medium my-auto p-1 hover:bg-accent-background flex items-center select-none
|
||||
${documentSet.is_up_to_date ? "cursor-pointer" : "cursor-default"}
|
||||
`}
|
||||
style={{ wordBreak: "normal", overflowWrap: "break-word" }}
|
||||
|
||||
@@ -21,6 +21,7 @@ import { getDisplayNameForModel } from "@/lib/hooks";
|
||||
import { errorHandlingFetcher } from "@/lib/fetcher";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import SvgPlusCircle from "@/icons/plus-circle";
|
||||
import { NEXT_PUBLIC_CLOUD_ENABLED } from "@/lib/constants";
|
||||
|
||||
// Number of tokens to show cost calculation for
|
||||
const COST_CALCULATION_TOKENS = 1_000_000;
|
||||
@@ -281,10 +282,15 @@ const AdvancedEmbeddingFormPage = forwardRef<
|
||||
name="disable_rerank_for_streaming"
|
||||
/>
|
||||
<BooleanFormField
|
||||
subtext="Enable contextual RAG for all chunk sizes."
|
||||
subtext={
|
||||
NEXT_PUBLIC_CLOUD_ENABLED
|
||||
? "Contextual RAG disabled in Onyx Cloud"
|
||||
: "Enable contextual RAG for all chunk sizes."
|
||||
}
|
||||
optional
|
||||
label="Contextual RAG"
|
||||
name="enable_contextual_rag"
|
||||
disabled={NEXT_PUBLIC_CLOUD_ENABLED}
|
||||
/>
|
||||
<div>
|
||||
<SelectorFormField
|
||||
|
||||
@@ -197,20 +197,24 @@ export function SettingsForm() {
|
||||
}
|
||||
|
||||
function handleCompanyNameBlur() {
|
||||
updateSettingField([
|
||||
{ fieldName: "company_name", newValue: companyName || null },
|
||||
]);
|
||||
const originalValue = settings?.company_name || "";
|
||||
if (companyName !== originalValue) {
|
||||
updateSettingField([
|
||||
{ fieldName: "company_name", newValue: companyName || null },
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
// NOTE: at the moment there's a small bug where if you click another admin panel page after typing
|
||||
// the field doesn't update correctly
|
||||
function handleCompanyDescriptionBlur() {
|
||||
updateSettingField([
|
||||
{
|
||||
fieldName: "company_description",
|
||||
newValue: companyDescription || null,
|
||||
},
|
||||
]);
|
||||
const originalValue = settings?.company_description || "";
|
||||
if (companyDescription !== originalValue) {
|
||||
updateSettingField([
|
||||
{
|
||||
fieldName: "company_description",
|
||||
newValue: companyDescription || null,
|
||||
},
|
||||
]);
|
||||
}
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
256
web/src/app/auth/login/EmailPasswordForm.test.tsx
Normal file
256
web/src/app/auth/login/EmailPasswordForm.test.tsx
Normal file
@@ -0,0 +1,256 @@
|
||||
/**
|
||||
* Integration Test: Email/Password Authentication Workflow
|
||||
*
|
||||
* Tests the complete user journey for logging in.
|
||||
* This tests the full workflow: form → validation → API call → redirect
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, waitFor, setupUser } from "@tests/setup/test-utils";
|
||||
import EmailPasswordForm from "./EmailPasswordForm";
|
||||
|
||||
// Mock next/navigation (not used by this component, but required by dependencies)
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => ({
|
||||
push: jest.fn(),
|
||||
refresh: jest.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("Email/Password Login Workflow", () => {
|
||||
let fetchSpy: jest.SpyInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
// Mock window.location.href for redirect testing
|
||||
delete (window as any).location;
|
||||
window.location = { href: "" } as any;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("allows user to login with valid credentials", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/auth/login
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<EmailPasswordForm isSignup={false} />);
|
||||
|
||||
// User fills out the form using placeholder text
|
||||
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
|
||||
const passwordInput = screen.getByPlaceholderText(/\*/);
|
||||
|
||||
await user.type(emailInput, "test@example.com");
|
||||
await user.type(passwordInput, "password123");
|
||||
|
||||
// User submits the form
|
||||
const loginButton = screen.getByRole("button", { name: /log in/i });
|
||||
await user.click(loginButton);
|
||||
|
||||
// After successful login, user should be redirected to /chat
|
||||
await waitFor(() => {
|
||||
expect(window.location.href).toBe("/chat");
|
||||
});
|
||||
|
||||
// Verify API was called with correct credentials
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/auth/login",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/x-www-form-urlencoded",
|
||||
},
|
||||
})
|
||||
);
|
||||
|
||||
// Verify the request body contains email and password
|
||||
const callArgs = fetchSpy.mock.calls[0];
|
||||
const body = callArgs[1].body;
|
||||
expect(body.toString()).toContain("username=test%40example.com");
|
||||
expect(body.toString()).toContain("password=password123");
|
||||
});
|
||||
|
||||
test("shows error message when login fails", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/auth/login (failure)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 401,
|
||||
json: async () => ({ detail: "LOGIN_BAD_CREDENTIALS" }),
|
||||
} as Response);
|
||||
|
||||
render(<EmailPasswordForm isSignup={false} />);
|
||||
|
||||
// User fills out form with invalid credentials
|
||||
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
|
||||
const passwordInput = screen.getByPlaceholderText(/\*/);
|
||||
|
||||
await user.type(emailInput, "wrong@example.com");
|
||||
await user.type(passwordInput, "wrongpassword");
|
||||
|
||||
// User submits
|
||||
const loginButton = screen.getByRole("button", { name: /log in/i });
|
||||
await user.click(loginButton);
|
||||
|
||||
// Verify error message is displayed
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/invalid email or password/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
|
||||
describe("Email/Password Signup Workflow", () => {
|
||||
let fetchSpy: jest.SpyInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
// Mock window.location.href
|
||||
delete (window as any).location;
|
||||
window.location = { href: "" } as any;
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("allows user to sign up and login with valid credentials", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/auth/register
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/auth/login (after successful signup)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<EmailPasswordForm isSignup={true} />);
|
||||
|
||||
// User fills out the signup form
|
||||
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
|
||||
const passwordInput = screen.getByPlaceholderText(/\*/);
|
||||
|
||||
await user.type(emailInput, "newuser@example.com");
|
||||
await user.type(passwordInput, "securepassword123");
|
||||
|
||||
// User submits the signup form
|
||||
const signupButton = screen.getByRole("button", { name: /sign up/i });
|
||||
await user.click(signupButton);
|
||||
|
||||
// Verify signup API was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/auth/register",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
headers: {
|
||||
"Content-Type": "application/json",
|
||||
},
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify signup request body
|
||||
const signupCallArgs = fetchSpy.mock.calls[0];
|
||||
const signupBody = JSON.parse(signupCallArgs[1].body);
|
||||
expect(signupBody).toEqual({
|
||||
email: "newuser@example.com",
|
||||
username: "newuser@example.com",
|
||||
password: "securepassword123",
|
||||
referral_source: undefined,
|
||||
});
|
||||
|
||||
// Verify login API was called after successful signup
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/auth/login",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
// Verify success message is shown
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/account created successfully/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error when email already exists", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/auth/register (failure - user exists)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 400,
|
||||
json: async () => ({ detail: "REGISTER_USER_ALREADY_EXISTS" }),
|
||||
} as Response);
|
||||
|
||||
render(<EmailPasswordForm isSignup={true} />);
|
||||
|
||||
// User fills out form with existing email
|
||||
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
|
||||
const passwordInput = screen.getByPlaceholderText(/\*/);
|
||||
|
||||
await user.type(emailInput, "existing@example.com");
|
||||
await user.type(passwordInput, "password123");
|
||||
|
||||
// User submits
|
||||
const signupButton = screen.getByRole("button", { name: /sign up/i });
|
||||
await user.click(signupButton);
|
||||
|
||||
// Verify error message is displayed
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/an account already exists with the specified email/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("shows rate limit error when too many requests", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/auth/register (failure - rate limit)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 429,
|
||||
json: async () => ({ detail: "Too many requests" }),
|
||||
} as Response);
|
||||
|
||||
render(<EmailPasswordForm isSignup={true} />);
|
||||
|
||||
// User fills out form
|
||||
const emailInput = screen.getByPlaceholderText(/email@yourcompany.com/i);
|
||||
const passwordInput = screen.getByPlaceholderText(/\*/);
|
||||
|
||||
await user.type(emailInput, "user@example.com");
|
||||
await user.type(passwordInput, "password123");
|
||||
|
||||
// User submits
|
||||
const signupButton = screen.getByRole("button", { name: /sign up/i });
|
||||
await user.click(signupButton);
|
||||
|
||||
// Verify rate limit message is displayed
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/too many requests\. please try again later/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -43,6 +43,8 @@ export default function EmailPasswordForm({
|
||||
email: defaultEmail ? defaultEmail.toLowerCase() : "",
|
||||
password: "",
|
||||
}}
|
||||
validateOnChange={false}
|
||||
validateOnBlur={true}
|
||||
validationSchema={Yup.object().shape({
|
||||
email: Yup.string()
|
||||
.email()
|
||||
|
||||
@@ -397,18 +397,16 @@ function ChatInputBarInner({
|
||||
|
||||
<div className="w-full h-full flex flex-col shadow-01 bg-background-neutral-00 rounded-16">
|
||||
{currentMessageFiles.length > 0 && (
|
||||
<div className="p-spacing-inline bg-background-neutral-01 rounded-t-16">
|
||||
<div className="flex flex-wrap gap-2">
|
||||
{currentMessageFiles.map((file) => (
|
||||
<FileCard
|
||||
key={file.id}
|
||||
file={file}
|
||||
removeFile={handleRemoveMessageFile}
|
||||
hideProcessingState={hideProcessingState}
|
||||
onFileClick={handleFileClick}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
<div className="p-spacing-inline rounded-t-16 flex flex-wrap gap-spacing-interline">
|
||||
{currentMessageFiles.map((file) => (
|
||||
<FileCard
|
||||
key={file.id}
|
||||
file={file}
|
||||
removeFile={handleRemoveMessageFile}
|
||||
hideProcessingState={hideProcessingState}
|
||||
onFileClick={handleFileClick}
|
||||
/>
|
||||
))}
|
||||
</div>
|
||||
)}
|
||||
|
||||
|
||||
393
web/src/app/chat/input-prompts/InputPrompts.test.tsx
Normal file
393
web/src/app/chat/input-prompts/InputPrompts.test.tsx
Normal file
@@ -0,0 +1,393 @@
|
||||
/**
|
||||
* Integration Test: Input Prompts CRUD Workflow
|
||||
*
|
||||
* Tests the complete user journey for managing prompt shortcuts.
|
||||
* This tests the full workflow: fetch → create → edit → delete
|
||||
*/
|
||||
import React from "react";
|
||||
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
|
||||
import InputPrompts from "./InputPrompts";
|
||||
|
||||
// Mock next/navigation for BackButton
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => ({
|
||||
push: jest.fn(),
|
||||
back: jest.fn(),
|
||||
refresh: jest.fn(),
|
||||
}),
|
||||
}));
|
||||
|
||||
describe("Input Prompts CRUD Workflow", () => {
|
||||
let fetchSpy: jest.SpyInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
jest.clearAllMocks();
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("fetches and displays existing prompts on load", async () => {
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [
|
||||
{
|
||||
id: 1,
|
||||
prompt: "Summarize",
|
||||
content: "Summarize the uploaded document and highlight key points.",
|
||||
is_public: false,
|
||||
},
|
||||
{
|
||||
id: 2,
|
||||
prompt: "Explain",
|
||||
content: "Explain this concept in simple terms.",
|
||||
is_public: true,
|
||||
},
|
||||
],
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Summarize")).toBeInTheDocument();
|
||||
expect(screen.getByText("Explain")).toBeInTheDocument();
|
||||
expect(
|
||||
screen.getByText(
|
||||
/Summarize the uploaded document and highlight key points/i
|
||||
)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("creates a new prompt successfully", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [],
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
id: 3,
|
||||
prompt: "Review",
|
||||
content: "Review this code for potential improvements.",
|
||||
is_public: false,
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
|
||||
});
|
||||
|
||||
const createButton = screen.getByRole("button", {
|
||||
name: /create new prompt/i,
|
||||
});
|
||||
await user.click(createButton);
|
||||
|
||||
expect(
|
||||
await screen.findByPlaceholderText(/prompt shortcut/i)
|
||||
).toBeInTheDocument();
|
||||
expect(screen.getByPlaceholderText(/actual prompt/i)).toBeInTheDocument();
|
||||
|
||||
const shortcutInput = screen.getByPlaceholderText(/prompt shortcut/i);
|
||||
const promptInput = screen.getByPlaceholderText(/actual prompt/i);
|
||||
|
||||
await user.type(shortcutInput, "Review");
|
||||
await user.type(
|
||||
promptInput,
|
||||
"Review this code for potential improvements."
|
||||
);
|
||||
|
||||
const submitButton = screen.getByRole("button", { name: /^create$/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/input_prompt",
|
||||
expect.objectContaining({
|
||||
method: "POST",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
const createCallArgs = fetchSpy.mock.calls[1]; // Second call (first was GET)
|
||||
const createBody = JSON.parse(createCallArgs[1].body);
|
||||
expect(createBody).toEqual({
|
||||
prompt: "Review",
|
||||
content: "Review this code for potential improvements.",
|
||||
is_public: false,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/prompt created successfully/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
expect(await screen.findByText("Review")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("edits an existing user-created prompt", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [
|
||||
{
|
||||
id: 1,
|
||||
prompt: "Summarize",
|
||||
content: "Summarize the document.",
|
||||
is_public: false,
|
||||
},
|
||||
],
|
||||
} as Response);
|
||||
|
||||
// Mock PATCH /api/input_prompt/1
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Summarize")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const dropdownButtons = screen.getAllByRole("button");
|
||||
const moreButton = dropdownButtons.find(
|
||||
(btn) => btn.textContent === "" && btn.querySelector("svg")
|
||||
);
|
||||
expect(moreButton).toBeDefined();
|
||||
await user.click(moreButton!);
|
||||
|
||||
const editOption = await screen.findByRole("menuitem", { name: /edit/i });
|
||||
await user.click(editOption);
|
||||
|
||||
let textareas: HTMLElement[];
|
||||
await waitFor(() => {
|
||||
textareas = screen.getAllByRole("textbox");
|
||||
expect(textareas[0]).toHaveValue("Summarize");
|
||||
expect(textareas[1]).toHaveValue("Summarize the document.");
|
||||
});
|
||||
|
||||
await user.clear(textareas![1]);
|
||||
await user.type(
|
||||
textareas![1],
|
||||
"Summarize the document and provide key insights."
|
||||
);
|
||||
|
||||
const saveButton = screen.getByRole("button", { name: /save/i });
|
||||
await user.click(saveButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/input_prompt/1",
|
||||
expect.objectContaining({
|
||||
method: "PATCH",
|
||||
headers: { "Content-Type": "application/json" },
|
||||
})
|
||||
);
|
||||
});
|
||||
|
||||
const patchCallArgs = fetchSpy.mock.calls[1];
|
||||
const patchBody = JSON.parse(patchCallArgs[1].body);
|
||||
expect(patchBody).toEqual({
|
||||
prompt: "Summarize",
|
||||
content: "Summarize the document and provide key insights.",
|
||||
active: true,
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/prompt updated successfully/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes a user-created prompt", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [
|
||||
{
|
||||
id: 1,
|
||||
prompt: "Summarize",
|
||||
content: "Summarize the document.",
|
||||
is_public: false,
|
||||
},
|
||||
],
|
||||
} as Response);
|
||||
|
||||
// Mock DELETE /api/input_prompt/1
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Summarize")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const dropdownButtons = screen.getAllByRole("button");
|
||||
const moreButton = dropdownButtons.find(
|
||||
(btn) => btn.textContent === "" && btn.querySelector("svg")
|
||||
);
|
||||
await user.click(moreButton!);
|
||||
|
||||
const deleteOption = await screen.findByRole("menuitem", {
|
||||
name: /delete/i,
|
||||
});
|
||||
await user.click(deleteOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt/1", {
|
||||
method: "DELETE",
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/prompt deleted successfully/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("Summarize")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("hides a public prompt instead of deleting it", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [
|
||||
{
|
||||
id: 2,
|
||||
prompt: "Explain",
|
||||
content: "Explain this concept.",
|
||||
is_public: true,
|
||||
},
|
||||
],
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/input_prompt/2/hide
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Explain")).toBeInTheDocument();
|
||||
expect(screen.getByText("Built-in")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
const dropdownButtons = screen.getAllByRole("button");
|
||||
const moreButton = dropdownButtons.find(
|
||||
(btn) => btn.textContent === "" && btn.querySelector("svg")
|
||||
);
|
||||
await user.click(moreButton!);
|
||||
|
||||
// Edit option should NOT be shown for public prompts
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByRole("menuitem", { name: /delete/i })
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
expect(
|
||||
screen.queryByRole("menuitem", { name: /edit/i })
|
||||
).not.toBeInTheDocument();
|
||||
|
||||
// Public prompts use the hide endpoint instead of delete
|
||||
const deleteOption = screen.getByRole("menuitem", { name: /delete/i });
|
||||
await user.click(deleteOption);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt/2/hide", {
|
||||
method: "POST",
|
||||
});
|
||||
});
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/prompt hidden successfully/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error when fetch fails", async () => {
|
||||
// Mock GET /api/input_prompt (failure)
|
||||
fetchSpy.mockRejectedValueOnce(new Error("Network error"));
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(
|
||||
screen.getByText(/failed to fetch prompt shortcuts/i)
|
||||
).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("shows error when create fails", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/input_prompt
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => [],
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/input_prompt (failure)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
} as Response);
|
||||
|
||||
render(<InputPrompts />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/input_prompt");
|
||||
});
|
||||
|
||||
const createButton = screen.getByRole("button", {
|
||||
name: /create new prompt/i,
|
||||
});
|
||||
await user.click(createButton);
|
||||
|
||||
const shortcutInput =
|
||||
await screen.findByPlaceholderText(/prompt shortcut/i);
|
||||
const promptInput = screen.getByPlaceholderText(/actual prompt/i);
|
||||
await user.type(shortcutInput, "Test");
|
||||
await user.type(promptInput, "Test content");
|
||||
|
||||
const submitButton = screen.getByRole("button", { name: /^create$/i });
|
||||
await user.click(submitButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to create prompt/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
@@ -121,10 +121,10 @@ export const preprocessLaTeX = (content: string) => {
|
||||
(_, equation) => `$$${equation}$$`
|
||||
);
|
||||
|
||||
// Replace inline LaTeX delimiters \( \) with $$ $$
|
||||
// Replace inline LaTeX delimiters \( \) with $ $
|
||||
const inlineProcessed = blockProcessed.replace(
|
||||
/\\\(([\s\S]*?)\\\)/g,
|
||||
(_, equation) => `$$${equation}$$`
|
||||
(_, equation) => `$${equation}$`
|
||||
);
|
||||
|
||||
// Restore original dollar signs in code contexts
|
||||
|
||||
@@ -223,14 +223,31 @@ export const ProjectsProvider: React.FC<ProjectsProviderProps> = ({
|
||||
|
||||
const renameProject = useCallback(
|
||||
async (projectId: number, name: string): Promise<Project> => {
|
||||
// Optimistically update the UI immediately
|
||||
setProjects((prev) =>
|
||||
prev.map((p) => (p.id === projectId ? { ...p, name } : p))
|
||||
);
|
||||
|
||||
if (currentProjectId === projectId) {
|
||||
setCurrentProjectDetails((prev) =>
|
||||
prev ? { ...prev, project: { ...prev.project, name } } : prev
|
||||
);
|
||||
}
|
||||
|
||||
try {
|
||||
const updated = await svcRenameProject(projectId, name);
|
||||
// Refresh to get canonical state from server
|
||||
await fetchProjects();
|
||||
if (currentProjectId === projectId) {
|
||||
await refreshCurrentProjectDetails();
|
||||
}
|
||||
return updated;
|
||||
} catch (err) {
|
||||
// Rollback optimistic update on failure
|
||||
await fetchProjects();
|
||||
if (currentProjectId === projectId) {
|
||||
await refreshCurrentProjectDetails();
|
||||
}
|
||||
const message =
|
||||
err instanceof Error ? err.message : "Failed to rename project";
|
||||
throw err;
|
||||
|
||||
@@ -12,7 +12,7 @@ import {
|
||||
SubLabel,
|
||||
TextFormField,
|
||||
} from "@/components/Field";
|
||||
import { Button } from "@/components/ui/button";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import Text from "@/components/ui/text";
|
||||
import { ImageUpload } from "./ImageUpload";
|
||||
import { AdvancedOptionsToggle } from "@/components/AdvancedOptionsToggle";
|
||||
@@ -158,8 +158,7 @@ export function WhitelabelingForm() {
|
||||
/>
|
||||
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
danger
|
||||
type="button"
|
||||
className="mb-8"
|
||||
onClick={async () => {
|
||||
@@ -304,8 +303,7 @@ export function WhitelabelingForm() {
|
||||
/>
|
||||
|
||||
<Button
|
||||
variant="destructive"
|
||||
size="sm"
|
||||
danger
|
||||
type="button"
|
||||
className="mb-8"
|
||||
onClick={async () => {
|
||||
|
||||
@@ -42,8 +42,8 @@
|
||||
--alpha-grey-100-20: #00000033;
|
||||
--alpha-grey-100-15: #00000026;
|
||||
--alpha-grey-100-10: #0000001a;
|
||||
--alpha-grey-100-5: #0000000d;
|
||||
--alpha-grey-100-0: #00000000;
|
||||
--alpha-grey-100-05: #0000000d;
|
||||
--alpha-grey-100-00: #00000000;
|
||||
|
||||
/* Alpha Grey 00 (White with opacity) */
|
||||
--alpha-grey-00-95: #fffffff2;
|
||||
@@ -64,8 +64,8 @@
|
||||
--alpha-grey-00-20: #ffffff33;
|
||||
--alpha-grey-00-15: #ffffff26;
|
||||
--alpha-grey-00-10: #ffffff1a;
|
||||
--alpha-grey-00-5: #ffffff0d;
|
||||
--alpha-grey-00-0: #ffffff00;
|
||||
--alpha-grey-00-05: #ffffff0d;
|
||||
--alpha-grey-00-00: #ffffff00;
|
||||
|
||||
/* Blue Scale */
|
||||
--blue-95: #040e25;
|
||||
|
||||
@@ -580,6 +580,28 @@ code[class*="language-"] {
|
||||
margin-bottom: 1rem;
|
||||
}
|
||||
|
||||
/* MASKING */
|
||||
|
||||
/*
|
||||
TODO: add correct class
|
||||
.mask-01 {
|
||||
background: linear-gradient(to top, var(--mask-01), transparent);
|
||||
-webkit-mask-image: -webkit-gradient(
|
||||
linear,
|
||||
left top,
|
||||
left bottom,
|
||||
from(var(--mask-01)),
|
||||
to(transparent)
|
||||
);
|
||||
mask-image: -webkit-gradient(
|
||||
linear,
|
||||
left top,
|
||||
left bottom,
|
||||
from(var(--mask-01)),
|
||||
to(transparent)
|
||||
);
|
||||
} */
|
||||
|
||||
/* DEBUGGING UTILITIES
|
||||
|
||||
If you ever want to highlight a component for debugging purposes, just type in `className="dbg-red ..."`, and a red box should appear around it.
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
import { usePaidEnterpriseFeaturesEnabled } from "@/components/settings/usePaidEnterpriseFeaturesEnabled";
|
||||
import React, { useState, useEffect } from "react";
|
||||
import { FormikProps, FieldArray, ArrayHelpers, ErrorMessage } from "formik";
|
||||
import Text from "@/components/ui/text";
|
||||
import { FiUsers } from "react-icons/fi";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { UserGroup, UserRole } from "@/lib/types";
|
||||
import { useUserGroups } from "@/lib/hooks";
|
||||
import { BooleanFormField } from "@/components/Field";
|
||||
import { useUser } from "./user/UserProvider";
|
||||
import SvgUsers from "@/icons/users";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
export type IsPublicGroupSelectorFormType = {
|
||||
is_public: boolean;
|
||||
@@ -112,59 +114,48 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
|
||||
userGroups &&
|
||||
userGroups?.length > 0 && (
|
||||
<>
|
||||
<div className="flex mt-4 gap-x-2 items-center">
|
||||
<div className="block font-medium text-base">
|
||||
<div className="flex flex-col gap-3 pt-4">
|
||||
<Text mainUiAction text05>
|
||||
Assign group access for this {objectName}
|
||||
</div>
|
||||
</div>
|
||||
{userGroupsIsLoading ? (
|
||||
<div className="animate-pulse bg-background-200 h-8 w-32 rounded"></div>
|
||||
) : (
|
||||
<Text className="mb-3">
|
||||
{isAdmin || !enforceGroupSelection ? (
|
||||
<>
|
||||
This {objectName} will be visible/accessible by the groups
|
||||
selected below
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Curators must select one or more groups to give access to
|
||||
this {objectName}
|
||||
</>
|
||||
)}
|
||||
</Text>
|
||||
)}
|
||||
{userGroupsIsLoading ? (
|
||||
<div className="animate-pulse bg-background-200 h-8 w-32 rounded" />
|
||||
) : (
|
||||
<Text mainUiMuted text03>
|
||||
{isAdmin || !enforceGroupSelection ? (
|
||||
<>
|
||||
This {objectName} will be visible/accessible by the groups
|
||||
selected below
|
||||
</>
|
||||
) : (
|
||||
<>
|
||||
Curators must select one or more groups to give access to
|
||||
this {objectName}
|
||||
</>
|
||||
)}
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
<FieldArray
|
||||
name="groups"
|
||||
render={(arrayHelpers: ArrayHelpers) => (
|
||||
<div className="flex gap-2 flex-wrap mb-4">
|
||||
<div className="flex flex-wrap gap-2 py-4">
|
||||
{userGroupsIsLoading ? (
|
||||
<div className="animate-pulse bg-background-200 h-8 w-32 rounded"></div>
|
||||
<div className="animate-pulse bg-background-200 h-8 w-32 rounded" />
|
||||
) : (
|
||||
userGroups &&
|
||||
userGroups.map((userGroup: UserGroup) => {
|
||||
const ind = formikProps.values.groups.indexOf(
|
||||
userGroup.id
|
||||
);
|
||||
let isSelected = ind !== -1;
|
||||
const isSelected = ind !== -1;
|
||||
return (
|
||||
<div
|
||||
<Button
|
||||
key={userGroup.id}
|
||||
className={`
|
||||
px-3
|
||||
py-1
|
||||
rounded-lg
|
||||
border
|
||||
border-border
|
||||
w-fit
|
||||
flex
|
||||
cursor-pointer
|
||||
${
|
||||
isSelected
|
||||
? "bg-background-200"
|
||||
: "hover:bg-accent-background-hovered"
|
||||
}
|
||||
`}
|
||||
primary
|
||||
action={isSelected}
|
||||
type="button"
|
||||
leftIcon={SvgUsers}
|
||||
onClick={() => {
|
||||
if (isSelected) {
|
||||
arrayHelpers.remove(ind);
|
||||
@@ -173,11 +164,8 @@ export const IsPublicGroupSelector = <T extends IsPublicGroupSelectorFormType>({
|
||||
}
|
||||
}}
|
||||
>
|
||||
<div className="my-auto flex">
|
||||
<FiUsers className="my-auto mr-2" />{" "}
|
||||
{userGroup.name}
|
||||
</div>
|
||||
</div>
|
||||
{userGroup.name}
|
||||
</Button>
|
||||
);
|
||||
})
|
||||
)}
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
"use client";
|
||||
|
||||
import ErrorPageLayout from "./ErrorPageLayout";
|
||||
import { fetchCustomerPortal } from "@/lib/billing/utils";
|
||||
import { useState } from "react";
|
||||
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
|
||||
import { fetchCustomerPortal } from "@/lib/billing/utils";
|
||||
import { useRouter } from "next/navigation";
|
||||
import Button from "@/refresh-components/buttons/Button";
|
||||
import { logout } from "@/lib/user";
|
||||
@@ -86,59 +86,63 @@ export default function AccessRestricted() {
|
||||
|
||||
return (
|
||||
<ErrorPageLayout>
|
||||
<div className="flex items-center gap-2 mb-4">
|
||||
<div className="flex items-center gap-spacing-interline">
|
||||
<Text headingH2>Access Restricted</Text>
|
||||
<SvgLock className="stroke-status-error-05 w-[1.5rem] h-[1.5rem]" />
|
||||
</div>
|
||||
<div className="space-y-4">
|
||||
<Text text03>
|
||||
We regret to inform you that your access to Onyx has been temporarily
|
||||
suspended due to a lapse in your subscription.
|
||||
</Text>
|
||||
<Text text03>
|
||||
To reinstate your access and continue benefiting from Onyx's
|
||||
powerful features, please update your payment information.
|
||||
</Text>
|
||||
<Text text03>
|
||||
If you're an admin, you can manage your subscription by clicking
|
||||
the button below. For other users, please reach out to your
|
||||
administrator to address this matter.
|
||||
</Text>
|
||||
<div className="flex flex-row gap-spacing-interline">
|
||||
<Button onClick={handleResubscribe} disabled={isLoading}>
|
||||
{isLoading ? "Loading..." : "Resubscribe"}
|
||||
</Button>
|
||||
<Button
|
||||
secondary
|
||||
onClick={handleManageSubscription}
|
||||
disabled={isLoading}
|
||||
>
|
||||
Manage Existing Subscription
|
||||
</Button>
|
||||
<Button
|
||||
secondary
|
||||
onClick={async () => {
|
||||
await logout();
|
||||
window.location.reload();
|
||||
}}
|
||||
>
|
||||
Log out
|
||||
</Button>
|
||||
</div>
|
||||
{error && <Text className="text-status-error-05">{error}</Text>}
|
||||
<Text text03>
|
||||
Need help? Join our{" "}
|
||||
<a
|
||||
className="text-action-link-05 hover:text-action-link-06"
|
||||
href="https://discord.gg/4NA5SbzrWb"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Discord community
|
||||
</a>{" "}
|
||||
for support.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
We regret to inform you that your access to Onyx has been temporarily
|
||||
suspended due to a lapse in your subscription.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
To reinstate your access and continue benefiting from Onyx's
|
||||
powerful features, please update your payment information.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
If you're an admin, you can manage your subscription by clicking
|
||||
the button below. For other users, please reach out to your
|
||||
administrator to address this matter.
|
||||
</Text>
|
||||
|
||||
<div className="flex flex-row gap-spacing-interline">
|
||||
<Button onClick={handleResubscribe} disabled={isLoading}>
|
||||
{isLoading ? "Loading..." : "Resubscribe"}
|
||||
</Button>
|
||||
<Button
|
||||
secondary
|
||||
onClick={handleManageSubscription}
|
||||
disabled={isLoading}
|
||||
>
|
||||
Manage Existing Subscription
|
||||
</Button>
|
||||
<Button
|
||||
secondary
|
||||
onClick={async () => {
|
||||
await logout();
|
||||
window.location.reload();
|
||||
}}
|
||||
>
|
||||
Log out
|
||||
</Button>
|
||||
</div>
|
||||
|
||||
{error && <Text className="text-status-error-05">{error}</Text>}
|
||||
|
||||
<Text text03>
|
||||
Need help? Join our{" "}
|
||||
<a
|
||||
className="text-action-link-05 hover:text-action-link-06"
|
||||
href="https://discord.gg/4NA5SbzrWb"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Discord community
|
||||
</a>{" "}
|
||||
for support.
|
||||
</Text>
|
||||
</ErrorPageLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,22 +1,20 @@
|
||||
"use client";
|
||||
import ErrorPageLayout from "./ErrorPageLayout";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
|
||||
|
||||
export default function CloudError() {
|
||||
return (
|
||||
<ErrorPageLayout>
|
||||
<h1 className="text-2xl font-semibold mb-4 text-gray-800 dark:text-gray-200">
|
||||
Maintenance in Progress
|
||||
</h1>
|
||||
<div className="space-y-4 text-gray-600 dark:text-gray-300">
|
||||
<p>
|
||||
Onyx is currently in a maintenance window. Please check back in a
|
||||
couple of minutes.
|
||||
</p>
|
||||
<p>
|
||||
We apologize for any inconvenience this may cause and appreciate your
|
||||
patience.
|
||||
</p>
|
||||
</div>
|
||||
<Text headingH2>Maintenance in Progress</Text>
|
||||
|
||||
<Text text03>
|
||||
Onyx is currently in a maintenance window. Please check back in a couple
|
||||
of minutes.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
We apologize for any inconvenience this may cause and appreciate your
|
||||
patience.
|
||||
</Text>
|
||||
</ErrorPageLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,47 +1,46 @@
|
||||
import ErrorPageLayout from "./ErrorPageLayout";
|
||||
import ErrorPageLayout from "@/components/errorPages/ErrorPageLayout";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import SvgAlertCircle from "@/icons/alert-circle";
|
||||
|
||||
export default function Error() {
|
||||
return (
|
||||
<ErrorPageLayout>
|
||||
<div className="flex items-center gap-2 mb-4 ">
|
||||
<Text headingH2 inverted>
|
||||
We encountered an issue
|
||||
</Text>
|
||||
<SvgAlertCircle className="w-[1.5rem] h-[1.5rem] stroke-text-inverted-04" />
|
||||
</div>
|
||||
<div className="space-y-4 text-gray-600 dark:text-gray-300">
|
||||
<Text inverted>
|
||||
It seems there was a problem loading your Onyx settings. This could be
|
||||
due to a configuration issue or incomplete setup.
|
||||
</Text>
|
||||
<Text inverted>
|
||||
If you're an admin, please review our{" "}
|
||||
<a
|
||||
className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
href="https://docs.onyx.app/?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
documentation
|
||||
</a>{" "}
|
||||
for proper configuration steps. If you're a user, please contact
|
||||
your admin for assistance.
|
||||
</Text>
|
||||
<Text inverted>
|
||||
Need help? Join our{" "}
|
||||
<a
|
||||
className="text-blue-500 hover:text-blue-700 dark:text-blue-400 dark:hover:text-blue-300"
|
||||
href="https://discord.gg/4NA5SbzrWb"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Discord community
|
||||
</a>{" "}
|
||||
for support.
|
||||
</Text>
|
||||
<div className="flex flex-row items-center gap-spacing-interline">
|
||||
<Text headingH2>We encountered an issue</Text>
|
||||
<SvgAlertCircle className="w-[1.5rem] h-[1.5rem] stroke-text-04" />
|
||||
</div>
|
||||
|
||||
<Text text03>
|
||||
It seems there was a problem loading your Onyx settings. This could be
|
||||
due to a configuration issue or incomplete setup.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
If you're an admin, please review our{" "}
|
||||
<a
|
||||
className="text-action-link-05"
|
||||
href="https://docs.onyx.app/?utm_source=app&utm_medium=error_page&utm_campaign=config_error"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
documentation
|
||||
</a>{" "}
|
||||
for proper configuration steps. If you're a user, please contact
|
||||
your admin for assistance.
|
||||
</Text>
|
||||
|
||||
<Text text03>
|
||||
Need help? Join our{" "}
|
||||
<a
|
||||
className="text-action-link-05"
|
||||
href="https://discord.gg/4NA5SbzrWb"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
Discord community
|
||||
</a>{" "}
|
||||
for support.
|
||||
</Text>
|
||||
</ErrorPageLayout>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
import React from "react";
|
||||
import { LogoType } from "../logo/Logo";
|
||||
import { OnyxLogoTypeIcon } from "@/components/icons/icons";
|
||||
|
||||
interface ErrorPageLayoutProps {
|
||||
children: React.ReactNode;
|
||||
@@ -7,12 +7,10 @@ interface ErrorPageLayoutProps {
|
||||
|
||||
export default function ErrorPageLayout({ children }: ErrorPageLayoutProps) {
|
||||
return (
|
||||
<div className="flex flex-col items-center justify-center min-h-screen">
|
||||
<div className="mb-4 flex items-center max-w-[220px]">
|
||||
<LogoType size="large" />
|
||||
</div>
|
||||
<div className="max-w-xl border border-border w-full bg-white shadow-md rounded-lg overflow-hidden">
|
||||
<div className="p-6 sm:p-8">{children}</div>
|
||||
<div className="flex flex-col items-center justify-center w-full h-screen gap-spacing-paragraph">
|
||||
<OnyxLogoTypeIcon size={120} className="" />
|
||||
<div className="max-w-[40rem] w-full border bg-background-neutral-00 shadow-02 rounded-16 p-padding-content flex flex-col gap-spacing-paragraph">
|
||||
{children}
|
||||
</div>
|
||||
</div>
|
||||
);
|
||||
|
||||
@@ -54,12 +54,9 @@ const TooltipContent = React.forwardRef<
|
||||
sideOffset={sideOffset}
|
||||
side={side}
|
||||
className={cn(
|
||||
`z-[100] rounded-08 ${
|
||||
backgroundColor || "bg-background-neutral-inverted-03"
|
||||
}
|
||||
${width}
|
||||
text-wrap
|
||||
p-padding-button shadow-lg animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2`,
|
||||
"z-[100] rounded-08 text-text-inverted-05 text-wrap p-padding-button shadow-lg animate-in fade-in-0 zoom-in-95 data-[state=closed]:animate-out data-[state=closed]:fade-out-0 data-[state=closed]:zoom-out-95 data-[side=bottom]:slide-in-from-top-2 data-[side=left]:slide-in-from-right-2 data-[side=right]:slide-in-from-left-2 data-[side=top]:slide-in-from-bottom-2",
|
||||
backgroundColor || "bg-background-neutral-inverted-03",
|
||||
width,
|
||||
className
|
||||
)}
|
||||
{...props}
|
||||
|
||||
45
web/src/icons/onyx-logo.tsx
Normal file
45
web/src/icons/onyx-logo.tsx
Normal file
@@ -0,0 +1,45 @@
|
||||
import * as React from "react";
|
||||
import type { SVGProps } from "react";
|
||||
const OnyxLogo = (props: SVGProps<SVGSVGElement>) => (
|
||||
<svg
|
||||
width="16"
|
||||
height="16"
|
||||
viewBox="0 0 16 16"
|
||||
fill="none"
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
{...props}
|
||||
>
|
||||
<g clipPath="url(#clip0_586_577)">
|
||||
<path
|
||||
d="M8 4.00001L4.5 2.50002L8 1.00002L11.5 2.50002L8 4.00001Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M8 12L11.5 13.5L8 15L4.5 13.5L8 12Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M4 8L2.5 11.5L1 8L2.5 4.50002L4 8Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
<path
|
||||
d="M12 8.00002L13.5 4.50002L15 8.00001L13.5 11.5L12 8.00002Z"
|
||||
strokeWidth={1.5}
|
||||
strokeLinecap="round"
|
||||
strokeLinejoin="round"
|
||||
/>
|
||||
</g>
|
||||
<defs>
|
||||
<clipPath id="clip0_586_577">
|
||||
<rect width={16} height={16} fill="white" />
|
||||
</clipPath>
|
||||
</defs>
|
||||
</svg>
|
||||
);
|
||||
export default OnyxLogo;
|
||||
@@ -2,11 +2,7 @@ import React from "react";
|
||||
import crypto from "crypto";
|
||||
import { MinimalPersonaSnapshot } from "@/app/admin/assistants/interfaces";
|
||||
import { buildImgUrl } from "@/app/chat/components/files/images/utils";
|
||||
import {
|
||||
ArtAsistantIcon,
|
||||
GeneralAssistantIcon,
|
||||
OnyxIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import { ArtAsistantIcon, OnyxIcon } from "@/components/icons/icons";
|
||||
import {
|
||||
Tooltip,
|
||||
TooltipContent,
|
||||
@@ -15,6 +11,7 @@ import {
|
||||
} from "@/components/ui/tooltip";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
import { useSettingsContext } from "@/components/settings/SettingsProvider";
|
||||
|
||||
function md5ToBits(str: string): number[] {
|
||||
const md5hex = crypto.createHash("md5").update(str).digest("hex");
|
||||
@@ -94,17 +91,35 @@ export interface AgentIconProps {
|
||||
}
|
||||
|
||||
export function AgentIcon({ agent, size = 24 }: AgentIconProps) {
|
||||
const settings = useSettingsContext();
|
||||
|
||||
// Check if whitelabeling is enabled for the default assistant
|
||||
const shouldUseWhitelabelLogo =
|
||||
agent.id === 0 && settings?.enterpriseSettings?.use_custom_logo === true;
|
||||
|
||||
return (
|
||||
<TooltipProvider>
|
||||
<Tooltip>
|
||||
<TooltipTrigger asChild>
|
||||
<div className="text-text-04">
|
||||
{agent.id == -3 ? (
|
||||
{agent.id === -3 ? (
|
||||
<ArtAsistantIcon size={size} />
|
||||
) : agent.id == 0 ? (
|
||||
<OnyxIcon size={size} />
|
||||
) : agent.id == -1 ? (
|
||||
<GeneralAssistantIcon size={size} />
|
||||
) : agent.id === 0 ? (
|
||||
shouldUseWhitelabelLogo ? (
|
||||
<img
|
||||
alt="Logo"
|
||||
src="/api/enterprise-settings/logo"
|
||||
loading="lazy"
|
||||
className={cn(
|
||||
"rounded-full object-cover object-center transition-opacity duration-300"
|
||||
)}
|
||||
width={size}
|
||||
height={size}
|
||||
style={{ objectFit: "contain" }}
|
||||
/>
|
||||
) : (
|
||||
<OnyxIcon size={size} />
|
||||
)
|
||||
) : agent.uploaded_image_id ? (
|
||||
<img
|
||||
alt={agent.name}
|
||||
|
||||
46
web/src/refresh-components/Logo.tsx
Normal file
46
web/src/refresh-components/Logo.tsx
Normal file
@@ -0,0 +1,46 @@
|
||||
import { OnyxIcon, OnyxLogoTypeIcon } from "@/components/icons/icons";
|
||||
import { useSettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED } from "@/lib/constants";
|
||||
import { cn } from "@/lib/utils";
|
||||
import Text from "@/refresh-components/texts/Text";
|
||||
|
||||
const FOLDED_SIZE = 24;
|
||||
|
||||
export interface LogoProps {
|
||||
folded?: boolean;
|
||||
className?: string;
|
||||
}
|
||||
|
||||
export default function Logo({ folded, className }: LogoProps) {
|
||||
const settings = useSettingsContext();
|
||||
|
||||
const logo = settings.enterpriseSettings?.use_custom_logo ? (
|
||||
<img
|
||||
src="/api/enterprise-settings/logo"
|
||||
alt="Logo"
|
||||
style={{ objectFit: "contain", height: FOLDED_SIZE, width: FOLDED_SIZE }}
|
||||
/>
|
||||
) : (
|
||||
<OnyxIcon size={FOLDED_SIZE} className={cn("flex-shrink-0", className)} />
|
||||
);
|
||||
|
||||
if (folded) return logo;
|
||||
|
||||
return settings.enterpriseSettings?.application_name ? (
|
||||
<div className="flex flex-col">
|
||||
<div className="flex flex-row items-center gap-spacing-interline">
|
||||
{logo}
|
||||
<Text headingH3 className="break-all line-clamp-2">
|
||||
{settings.enterpriseSettings?.application_name}
|
||||
</Text>
|
||||
</div>
|
||||
{!NEXT_PUBLIC_DO_NOT_USE_TOGGLE_OFF_DANSWER_POWERED && (
|
||||
<Text secondaryBody text03 className="ml-[33px]">
|
||||
Powered by Onyx
|
||||
</Text>
|
||||
)}
|
||||
</div>
|
||||
) : (
|
||||
<OnyxLogoTypeIcon size={88} className={className} />
|
||||
);
|
||||
}
|
||||
@@ -58,10 +58,17 @@ export default function VerticalShadowScroller({
|
||||
|
||||
{showBottomShadow && (
|
||||
<div
|
||||
className="absolute bottom-0 left-0 right-0 h-[2rem] pointer-events-none z-[100] dbg-rd"
|
||||
style={{
|
||||
background: "linear-gradient(to top, var(--mask-01), transparent)",
|
||||
}}
|
||||
className={cn(
|
||||
"absolute",
|
||||
"bottom-0",
|
||||
"left-0",
|
||||
"right-0",
|
||||
"h-[2rem]",
|
||||
"pointer-events-none",
|
||||
"z-[100]"
|
||||
// TODO: add masking to match mocks
|
||||
// "mask-01"
|
||||
)}
|
||||
/>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -7,6 +7,7 @@ export default function CreateButton({
|
||||
href,
|
||||
onClick,
|
||||
children,
|
||||
type = "button",
|
||||
...props
|
||||
}: ButtonProps) {
|
||||
return (
|
||||
@@ -15,6 +16,7 @@ export default function CreateButton({
|
||||
onClick={onClick}
|
||||
leftIcon={SvgPlusCircle}
|
||||
href={href}
|
||||
type={type}
|
||||
{...props}
|
||||
>
|
||||
{children || "Create"}
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import type { HTMLAttributes } from "react";
|
||||
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const fonts = {
|
||||
@@ -46,7 +48,7 @@ const colors = {
|
||||
},
|
||||
};
|
||||
|
||||
export interface TextProps extends React.HTMLAttributes<HTMLElement> {
|
||||
export interface TextProps extends HTMLAttributes<HTMLParagraphElement> {
|
||||
nowrap?: boolean;
|
||||
|
||||
// Fonts
|
||||
@@ -106,6 +108,7 @@ export default function Text({
|
||||
inverted,
|
||||
children,
|
||||
className,
|
||||
...rest
|
||||
}: TextProps) {
|
||||
const font = headingH1
|
||||
? "headingH1"
|
||||
@@ -159,6 +162,7 @@ export default function Text({
|
||||
|
||||
return (
|
||||
<p
|
||||
{...rest}
|
||||
className={cn(
|
||||
fonts[font],
|
||||
inverted ? colors.inverted[color] : colors[color],
|
||||
|
||||
@@ -37,12 +37,13 @@ import {
|
||||
SearchIcon,
|
||||
DocumentIcon2,
|
||||
BrainIcon,
|
||||
OnyxSparkleIcon,
|
||||
} from "@/components/icons/icons";
|
||||
import OnyxLogo from "@/icons/onyx-logo";
|
||||
import { CombinedSettings } from "@/app/admin/settings/interfaces";
|
||||
import { FiActivity, FiBarChart2 } from "react-icons/fi";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import VerticalShadowScroller from "@/refresh-components/VerticalShadowScroller";
|
||||
import { cn } from "@/lib/utils";
|
||||
|
||||
const connectors_items = () => [
|
||||
{
|
||||
@@ -160,7 +161,7 @@ const collections = (
|
||||
items: [
|
||||
{
|
||||
name: "Default Assistant",
|
||||
icon: OnyxSparkleIcon,
|
||||
icon: OnyxLogo,
|
||||
link: "/admin/configuration/default-assistant",
|
||||
},
|
||||
{
|
||||
@@ -168,12 +169,16 @@ const collections = (
|
||||
icon: CpuIconSkeleton,
|
||||
link: "/admin/configuration/llm",
|
||||
},
|
||||
{
|
||||
error: settings?.settings.needs_reindexing,
|
||||
name: "Search Settings",
|
||||
icon: SearchIcon,
|
||||
link: "/admin/configuration/search",
|
||||
},
|
||||
...(!enableCloud
|
||||
? [
|
||||
{
|
||||
error: settings?.settings.needs_reindexing,
|
||||
name: "Search Settings",
|
||||
icon: SearchIcon,
|
||||
link: "/admin/configuration/search",
|
||||
},
|
||||
]
|
||||
: []),
|
||||
{
|
||||
name: "Document Processing",
|
||||
icon: DocumentIcon2,
|
||||
@@ -357,7 +362,14 @@ export default function AdminSidebar({
|
||||
))}
|
||||
</VerticalShadowScroller>
|
||||
|
||||
<div className="flex flex-col px-spacing-interline gap-spacing-interline">
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col",
|
||||
"px-spacing-interline",
|
||||
"pt-spacing-interline",
|
||||
"gap-spacing-interline"
|
||||
)}
|
||||
>
|
||||
{combinedSettings.webVersion && (
|
||||
<Text text02 secondaryBody className="px-spacing-interline">
|
||||
{`Onyx version: ${combinedSettings.webVersion}`}
|
||||
|
||||
@@ -349,7 +349,15 @@ function AppSidebarInner() {
|
||||
)}
|
||||
|
||||
<SidebarWrapper folded={folded} setFolded={setFolded}>
|
||||
<div className="flex flex-col px-spacing-interline gap-spacing-interline">
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-col",
|
||||
"px-spacing-interline",
|
||||
"gap-spacing-interline",
|
||||
"pt-spacing-paragraph",
|
||||
"pb-spacing-paragraph"
|
||||
)}
|
||||
>
|
||||
<div data-testid="AppSidebar/new-session">
|
||||
<SidebarTab
|
||||
leftIcon={SvgEditBig}
|
||||
@@ -462,7 +470,7 @@ function AppSidebarInner() {
|
||||
)}
|
||||
</VerticalShadowScroller>
|
||||
|
||||
<div className="px-spacing-interline">
|
||||
<div className="px-spacing-interline pt-spacing-interline-mini">
|
||||
<Settings folded={folded} />
|
||||
</div>
|
||||
</SidebarWrapper>
|
||||
|
||||
@@ -28,13 +28,15 @@ export default function ButtonRenaming({
|
||||
return;
|
||||
}
|
||||
|
||||
// Close immediately for instant feedback
|
||||
onClose();
|
||||
|
||||
// Proceed with the rename operation after closing
|
||||
try {
|
||||
await onRename(newName);
|
||||
} catch (error) {
|
||||
console.error("Failed to rename:", error);
|
||||
}
|
||||
|
||||
onClose();
|
||||
}
|
||||
|
||||
return (
|
||||
|
||||
@@ -42,7 +42,6 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
|
||||
useState(false);
|
||||
const { renameProject, deleteProject } = useProjectsContext();
|
||||
const [isEditing, setIsEditing] = useState(false);
|
||||
const [name, setName] = useState(project.name);
|
||||
const [popoverOpen, setPopoverOpen] = useState(false);
|
||||
const [isHoveringIcon, setIsHoveringIcon] = useState(false);
|
||||
|
||||
@@ -79,7 +78,6 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
|
||||
|
||||
async function handleRename(newName: string) {
|
||||
await renameProject(project.id, newName);
|
||||
setName(newName);
|
||||
}
|
||||
|
||||
const popoverItems = [
|
||||
@@ -180,7 +178,7 @@ function ProjectFolderButtonInner({ project }: ProjectFolderProps) {
|
||||
onClose={() => setIsEditing(false)}
|
||||
/>
|
||||
) : (
|
||||
name
|
||||
project.name
|
||||
)}
|
||||
</SidebarTab>
|
||||
</PopoverAnchor>
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
import React, { Dispatch, SetStateAction } from "react";
|
||||
import React, { Dispatch, SetStateAction, useMemo } from "react";
|
||||
import { cn } from "@/lib/utils";
|
||||
import { OnyxIcon, OnyxLogoTypeIcon } from "@/components/icons/icons";
|
||||
import IconButton from "@/refresh-components/buttons/IconButton";
|
||||
import SvgSidebar from "@/icons/sidebar";
|
||||
import Logo from "@/refresh-components/Logo";
|
||||
|
||||
interface SidebarWrapperProps {
|
||||
export interface SidebarWrapperProps {
|
||||
folded?: boolean;
|
||||
setFolded?: Dispatch<SetStateAction<boolean>>;
|
||||
children: React.ReactNode;
|
||||
children?: React.ReactNode;
|
||||
}
|
||||
|
||||
export default function SidebarWrapper({
|
||||
@@ -15,25 +15,47 @@ export default function SidebarWrapper({
|
||||
setFolded,
|
||||
children,
|
||||
}: SidebarWrapperProps) {
|
||||
const logo = useMemo(
|
||||
() => (
|
||||
<Logo
|
||||
folded={folded}
|
||||
className={cn(folded && "visible group-hover/SidebarWrapper:hidden")}
|
||||
/>
|
||||
),
|
||||
[folded]
|
||||
);
|
||||
|
||||
return (
|
||||
// This extra `div` wrapping needs to be present (for some reason).
|
||||
// Without, the widths of the sidebars don't properly get set to the explicitly declared widths (i.e., `4rem` folded and `15rem` unfolded).
|
||||
<div>
|
||||
<div
|
||||
className={cn(
|
||||
"h-screen flex flex-col bg-background-tint-02 py-spacing-interline justify-between gap-padding-content group/SidebarWrapper",
|
||||
"h-screen",
|
||||
"flex flex-col",
|
||||
"py-spacing-interline",
|
||||
"bg-background-tint-02",
|
||||
"justify-between",
|
||||
"group/SidebarWrapper",
|
||||
folded ? "w-[4rem]" : "w-[15rem]"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className={cn(
|
||||
"flex flex-row items-center px-spacing-paragraph py-spacing-inline flex-shrink-0",
|
||||
"flex",
|
||||
"flex-row",
|
||||
"items-center",
|
||||
"px-spacing-paragraph",
|
||||
"pt-spacing-interline-mini",
|
||||
"pb-spacing-paragraph",
|
||||
"flex-shrink-0",
|
||||
folded ? "justify-center" : "justify-between"
|
||||
)}
|
||||
>
|
||||
{folded ? (
|
||||
<div className="h-[2rem] flex flex-col justify-center items-center">
|
||||
<>
|
||||
{logo}
|
||||
<IconButton
|
||||
icon={SvgSidebar}
|
||||
tertiary
|
||||
@@ -41,15 +63,11 @@ export default function SidebarWrapper({
|
||||
className="hidden group-hover/SidebarWrapper:flex"
|
||||
tooltip="Close Sidebar"
|
||||
/>
|
||||
<OnyxIcon
|
||||
size={24}
|
||||
className="visible group-hover/SidebarWrapper:hidden"
|
||||
/>
|
||||
</>
|
||||
</div>
|
||||
) : (
|
||||
<>
|
||||
<OnyxLogoTypeIcon size={88} />
|
||||
{logo}
|
||||
<IconButton
|
||||
icon={SvgSidebar}
|
||||
tertiary
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { useContext, ReactNode } from "react";
|
||||
import { LogoComponent } from "@/components/logo/FixedLogo";
|
||||
import { ReactNode } from "react";
|
||||
import { SvgProps } from "@/icons";
|
||||
import SidebarTab from "@/refresh-components/buttons/SidebarTab";
|
||||
import SidebarWrapper from "@/sections/sidebar/SidebarWrapper";
|
||||
|
||||
interface StepSidebarProps {
|
||||
export interface StepSidebarProps {
|
||||
children: ReactNode;
|
||||
buttonName: string;
|
||||
buttonIcon: React.FunctionComponent<SvgProps>;
|
||||
@@ -17,30 +16,19 @@ export default function StepSidebar({
|
||||
buttonIcon,
|
||||
buttonHref,
|
||||
}: StepSidebarProps) {
|
||||
const combinedSettings = useContext(SettingsContext);
|
||||
if (!combinedSettings) {
|
||||
return null;
|
||||
}
|
||||
|
||||
const enterpriseSettings = combinedSettings.enterpriseSettings;
|
||||
|
||||
return (
|
||||
<div className="fixed left-0 top-0 flex flex-col h-screen w-[15rem] bg-background-tint-02 py-padding-content px-padding-button gap-padding-content z-10">
|
||||
<div className="flex flex-col items-start justify-center">
|
||||
<LogoComponent enterpriseSettings={enterpriseSettings} />
|
||||
<SidebarWrapper>
|
||||
<div className="px-spacing-interline">
|
||||
<SidebarTab
|
||||
leftIcon={buttonIcon}
|
||||
className="bg-background-tint-00"
|
||||
href={buttonHref}
|
||||
>
|
||||
{buttonName}
|
||||
</SidebarTab>
|
||||
</div>
|
||||
|
||||
<SidebarTab
|
||||
leftIcon={buttonIcon}
|
||||
className="bg-background-tint-00"
|
||||
href={buttonHref}
|
||||
>
|
||||
{buttonName}
|
||||
</SidebarTab>
|
||||
|
||||
<div className="h-full flex">
|
||||
<div className="w-full px-2">{children}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="h-full w-full px-spacing-paragraph">{children}</div>
|
||||
</SidebarWrapper>
|
||||
);
|
||||
}
|
||||
|
||||
@@ -219,10 +219,11 @@ module.exports = {
|
||||
"token-function": "var(--token-function)",
|
||||
"token-regex": "var(--token-regex)",
|
||||
"token-attr-name": "var(--token-attr-name)",
|
||||
// "non-selectable": "var(--non-selectable)",
|
||||
},
|
||||
boxShadow: {
|
||||
"01": "0px 2px 8px 0px var(--shadow-02)",
|
||||
"01": "0px 2px 8px 0px var(--shadow-01)",
|
||||
"02": "0px 2px 8px 0px var(--shadow-02)",
|
||||
"03": "0px 2px 8px 0px var(--shadow-03)",
|
||||
|
||||
// light
|
||||
"tremor-input": "0 1px 2px 0 rgb(0 0 0 / 0.05)",
|
||||
|
||||
788
web/tests/README.md
Normal file
788
web/tests/README.md
Normal file
@@ -0,0 +1,788 @@
|
||||
# React Integration Testing Guide
|
||||
|
||||
Comprehensive guide for writing integration tests in the Onyx web application using Jest and React Testing Library.
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Running Tests](#running-tests)
|
||||
- [Core Concepts](#core-concepts)
|
||||
- [Writing Tests](#writing-tests)
|
||||
- [Query Selectors](#query-selectors)
|
||||
- [User Interactions](#user-interactions)
|
||||
- [Async Operations](#async-operations)
|
||||
- [Mocking](#mocking)
|
||||
- [Common Patterns](#common-patterns)
|
||||
- [Testing Philosophy](#testing-philosophy)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
|
||||
## Running Tests
|
||||
|
||||
```bash
|
||||
# Run all tests
|
||||
npm test
|
||||
|
||||
# Run specific test file
|
||||
npm test -- EmailPasswordForm.test
|
||||
|
||||
# Run tests matching pattern
|
||||
npm test -- --testPathPattern="auth"
|
||||
|
||||
# Run without coverage
|
||||
npm test -- --no-coverage
|
||||
|
||||
# Run in watch mode
|
||||
npm test -- --watch
|
||||
|
||||
# Run with verbose output
|
||||
npm test -- --verbose
|
||||
```
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### Test Structure
|
||||
|
||||
Tests are **co-located** with source files for easy discovery and maintenance:
|
||||
|
||||
```
|
||||
src/app/auth/login/
|
||||
├── EmailPasswordForm.tsx
|
||||
└── EmailPasswordForm.test.tsx
|
||||
```
|
||||
|
||||
### Test Anatomy
|
||||
|
||||
Every test follows this structure:
|
||||
|
||||
```typescript
|
||||
import { render, screen, setupUser, waitFor } from "@tests/setup/test-utils";
|
||||
import MyComponent from "./MyComponent";
|
||||
|
||||
test("descriptive test name explaining user behavior", async () => {
|
||||
// 1. Setup - Create user, mock APIs
|
||||
const user = setupUser();
|
||||
const fetchSpy = jest.spyOn(global, "fetch");
|
||||
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ data: "value" }),
|
||||
} as Response);
|
||||
|
||||
// 2. Render - Display the component
|
||||
render(<MyComponent />);
|
||||
|
||||
// 3. Act - Simulate user interactions
|
||||
await user.type(screen.getByRole("textbox"), "test input");
|
||||
await user.click(screen.getByRole("button", { name: /submit/i }));
|
||||
|
||||
// 4. Assert - Verify expected outcomes
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/success/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
// 5. Cleanup - Restore mocks
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
```
|
||||
|
||||
### setupUser() - Automatic act() Wrapping
|
||||
|
||||
**ALWAYS use `setupUser()` instead of `userEvent.setup()`**
|
||||
|
||||
```typescript
|
||||
// ✅ Correct - Automatic act() wrapping
|
||||
const user = setupUser();
|
||||
await user.click(button);
|
||||
await user.type(input, "text");
|
||||
|
||||
// ❌ Wrong - Manual act() required, verbose
|
||||
const user = userEvent.setup();
|
||||
await act(async () => {
|
||||
await user.click(button);
|
||||
});
|
||||
```
|
||||
|
||||
The `setupUser()` helper automatically wraps all user interactions in React's `act()` to prevent warnings and ensure proper state updates.
|
||||
|
||||
## Writing Tests
|
||||
|
||||
### Query Selectors
|
||||
|
||||
Use queries in this priority order (most accessible first):
|
||||
|
||||
#### 1. Role Queries (Preferred)
|
||||
|
||||
```typescript
|
||||
// Buttons
|
||||
screen.getByRole("button", { name: /submit/i })
|
||||
screen.getByRole("button", { name: /cancel/i })
|
||||
|
||||
// Text inputs
|
||||
screen.getByRole("textbox", { name: /email/i })
|
||||
|
||||
// Checkboxes
|
||||
screen.getByRole("checkbox", { name: /remember me/i })
|
||||
|
||||
// Links
|
||||
screen.getByRole("link", { name: /learn more/i })
|
||||
|
||||
// Headings
|
||||
screen.getByRole("heading", { name: /welcome/i })
|
||||
```
|
||||
|
||||
#### 2. Label Queries
|
||||
|
||||
```typescript
|
||||
// For form inputs with labels
|
||||
screen.getByLabelText(/password/i)
|
||||
screen.getByLabelText(/email address/i)
|
||||
```
|
||||
|
||||
#### 3. Placeholder Queries
|
||||
|
||||
```typescript
|
||||
// When no label exists
|
||||
screen.getByPlaceholderText(/enter email/i)
|
||||
```
|
||||
|
||||
#### 4. Text Queries
|
||||
|
||||
```typescript
|
||||
// For non-interactive text
|
||||
screen.getByText(/welcome back/i)
|
||||
screen.getByText(/error occurred/i)
|
||||
```
|
||||
|
||||
#### Query Variants
|
||||
|
||||
```typescript
|
||||
// getBy - Throws error if not found (immediate)
|
||||
screen.getByRole("button")
|
||||
|
||||
// queryBy - Returns null if not found (checking absence)
|
||||
expect(screen.queryByText(/error/i)).not.toBeInTheDocument()
|
||||
|
||||
// findBy - Returns promise, waits for element (async)
|
||||
expect(await screen.findByText(/success/i)).toBeInTheDocument()
|
||||
|
||||
// getAllBy - Returns array of all matches
|
||||
const inputs = screen.getAllByRole("textbox")
|
||||
```
|
||||
|
||||
### Query Selectors: The Wrong Way
|
||||
|
||||
**❌ Avoid these anti-patterns:**
|
||||
|
||||
```typescript
|
||||
// DON'T query by test IDs
|
||||
screen.getByTestId("submit-button")
|
||||
|
||||
// DON'T query by class names
|
||||
container.querySelector(".submit-btn")
|
||||
|
||||
// DON'T query by element types
|
||||
container.querySelector("button")
|
||||
```
|
||||
|
||||
## User Interactions
|
||||
|
||||
### Basic Interactions
|
||||
|
||||
```typescript
|
||||
const user = setupUser();
|
||||
|
||||
// Click
|
||||
await user.click(screen.getByRole("button", { name: /submit/i }));
|
||||
|
||||
// Type text
|
||||
await user.type(screen.getByRole("textbox"), "test input");
|
||||
|
||||
// Clear and type
|
||||
await user.clear(input);
|
||||
await user.type(input, "new value");
|
||||
|
||||
// Check/uncheck checkbox
|
||||
await user.click(screen.getByRole("checkbox"));
|
||||
|
||||
// Select from dropdown
|
||||
await user.selectOptions(
|
||||
screen.getByRole("combobox"),
|
||||
"option-value"
|
||||
);
|
||||
|
||||
// Upload file
|
||||
const file = new File(["content"], "test.txt", { type: "text/plain" });
|
||||
const input = screen.getByLabelText(/upload/i);
|
||||
await user.upload(input, file);
|
||||
```
|
||||
|
||||
### Form Interactions
|
||||
|
||||
```typescript
|
||||
test("user can fill and submit form", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<ContactForm />);
|
||||
|
||||
await user.type(screen.getByLabelText(/name/i), "John Doe");
|
||||
await user.type(screen.getByLabelText(/email/i), "john@example.com");
|
||||
await user.type(screen.getByLabelText(/message/i), "Hello!");
|
||||
await user.click(screen.getByRole("button", { name: /send/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/message sent/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
## Async Operations
|
||||
|
||||
### Handling Async State Updates
|
||||
|
||||
**Rule**: After triggering state changes, always wait for UI updates before asserting.
|
||||
|
||||
#### Pattern 1: findBy Queries (Simplest)
|
||||
|
||||
```typescript
|
||||
// Element appears after async operation
|
||||
await user.click(createButton);
|
||||
expect(await screen.findByRole("textbox")).toBeInTheDocument();
|
||||
```
|
||||
|
||||
#### Pattern 2: waitFor (Complex Assertions)
|
||||
|
||||
```typescript
|
||||
await user.click(submitButton);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Success")).toBeInTheDocument();
|
||||
expect(screen.getByText("Count: 5")).toBeInTheDocument();
|
||||
});
|
||||
```
|
||||
|
||||
#### Pattern 3: waitForElementToBeRemoved
|
||||
|
||||
```typescript
|
||||
await user.click(deleteButton);
|
||||
|
||||
await waitForElementToBeRemoved(() => screen.queryByText(/item name/i));
|
||||
```
|
||||
|
||||
### Common Async Mistakes
|
||||
|
||||
```typescript
|
||||
// ❌ Wrong - getBy immediately after state change
|
||||
await user.click(button);
|
||||
expect(screen.getByText("Updated")).toBeInTheDocument(); // May fail!
|
||||
|
||||
// ✅ Correct - Wait for state update
|
||||
await user.click(button);
|
||||
expect(await screen.findByText("Updated")).toBeInTheDocument();
|
||||
|
||||
// ❌ Wrong - Multiple getBy calls without waiting
|
||||
await user.click(button);
|
||||
expect(screen.getByText("Success")).toBeInTheDocument();
|
||||
expect(screen.getByText("Data loaded")).toBeInTheDocument();
|
||||
|
||||
// ✅ Correct - Single waitFor with multiple assertions
|
||||
await user.click(button);
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Success")).toBeInTheDocument();
|
||||
expect(screen.getByText("Data loaded")).toBeInTheDocument();
|
||||
});
|
||||
```
|
||||
|
||||
## Mocking
|
||||
|
||||
### Mocking fetch API
|
||||
|
||||
**IMPORTANT**: Always document which endpoint each mock corresponds to using comments.
|
||||
|
||||
```typescript
|
||||
let fetchSpy: jest.SpyInstance;
|
||||
|
||||
beforeEach(() => {
|
||||
fetchSpy = jest.spyOn(global, "fetch");
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
fetchSpy.mockRestore();
|
||||
});
|
||||
|
||||
test("fetches data successfully", async () => {
|
||||
// Mock GET /api/data
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ data: [1, 2, 3] }),
|
||||
} as Response);
|
||||
|
||||
render(<MyComponent />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/data");
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
**Why comment the endpoint?** Sequential mocks can be confusing. Comments make it clear which API call each mock corresponds to, making tests easier to understand and maintain.
|
||||
|
||||
### Multiple API Calls
|
||||
|
||||
**Pattern**: Document each endpoint with a comment, then verify it was called correctly.
|
||||
|
||||
```typescript
|
||||
test("handles multiple API calls", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock GET /api/items
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ items: [] }),
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/items
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ id: 1, name: "New Item" }),
|
||||
} as Response);
|
||||
|
||||
render(<MyComponent />);
|
||||
|
||||
// Verify GET was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith("/api/items");
|
||||
});
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /create/i }));
|
||||
|
||||
// Verify POST was called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/items",
|
||||
expect.objectContaining({ method: "POST" })
|
||||
);
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
**Three API calls example:**
|
||||
|
||||
```typescript
|
||||
test("test, create, and set as default", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
// Mock POST /api/llm/test
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
// Mock PUT /api/llm/provider?is_creation=true
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ id: 5, name: "New Provider" }),
|
||||
} as Response);
|
||||
|
||||
// Mock POST /api/llm/provider/5/default
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<MyForm />);
|
||||
|
||||
await user.type(screen.getByLabelText(/name/i), "New Provider");
|
||||
await user.click(screen.getByRole("button", { name: /create/i }));
|
||||
|
||||
// Verify all three endpoints were called
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/llm/test",
|
||||
expect.objectContaining({ method: "POST" })
|
||||
);
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/llm/provider",
|
||||
expect.objectContaining({ method: "PUT" })
|
||||
);
|
||||
expect(fetchSpy).toHaveBeenCalledWith(
|
||||
"/api/llm/provider/5/default",
|
||||
expect.objectContaining({ method: "POST" })
|
||||
);
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Verifying Request Body
|
||||
|
||||
```typescript
|
||||
test("sends correct data", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<MyForm />);
|
||||
|
||||
await user.type(screen.getByLabelText(/name/i), "Test");
|
||||
await user.click(screen.getByRole("button", { name: /submit/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(fetchSpy).toHaveBeenCalled();
|
||||
});
|
||||
|
||||
const callArgs = fetchSpy.mock.calls[0];
|
||||
const requestBody = JSON.parse(callArgs[1].body);
|
||||
|
||||
expect(requestBody).toEqual({
|
||||
name: "Test",
|
||||
active: true,
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Mocking Errors
|
||||
|
||||
```typescript
|
||||
test("displays error message on failure", async () => {
|
||||
// Mock GET /api/data (network error)
|
||||
fetchSpy.mockRejectedValueOnce(new Error("Network error"));
|
||||
|
||||
render(<MyComponent />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/failed to load/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("handles API error response", async () => {
|
||||
// Mock POST /api/items (server error)
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: false,
|
||||
status: 500,
|
||||
} as Response);
|
||||
|
||||
render(<MyComponent />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/something went wrong/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Mocking Next.js Router
|
||||
|
||||
```typescript
|
||||
// At top of test file
|
||||
jest.mock("next/navigation", () => ({
|
||||
useRouter: () => ({
|
||||
push: jest.fn(),
|
||||
back: jest.fn(),
|
||||
refresh: jest.fn(),
|
||||
}),
|
||||
usePathname: () => "/current-path",
|
||||
}));
|
||||
```
|
||||
|
||||
## Common Patterns
|
||||
|
||||
### Testing CRUD Operations
|
||||
|
||||
```typescript
|
||||
describe("User Management", () => {
|
||||
test("creates new user", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ id: 1, name: "New User" }),
|
||||
} as Response);
|
||||
|
||||
render(<UserForm />);
|
||||
|
||||
await user.type(screen.getByLabelText(/name/i), "New User");
|
||||
await user.click(screen.getByRole("button", { name: /create/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/user created/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("edits existing user", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({ id: 1, name: "Updated User" }),
|
||||
} as Response);
|
||||
|
||||
render(<UserForm initialData={{ id: 1, name: "Old Name" }} />);
|
||||
|
||||
await user.clear(screen.getByLabelText(/name/i));
|
||||
await user.type(screen.getByLabelText(/name/i), "Updated User");
|
||||
await user.click(screen.getByRole("button", { name: /save/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/user updated/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("deletes user", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({}),
|
||||
} as Response);
|
||||
|
||||
render(<UserList />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("John Doe")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /delete/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText("John Doe")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Testing Conditional Rendering
|
||||
|
||||
```typescript
|
||||
test("shows edit form when edit button clicked", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<MyComponent />);
|
||||
|
||||
expect(screen.queryByRole("textbox")).not.toBeInTheDocument();
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /edit/i }));
|
||||
|
||||
expect(await screen.findByRole("textbox")).toBeInTheDocument();
|
||||
});
|
||||
|
||||
test("toggles between states", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<Toggle />);
|
||||
|
||||
const button = screen.getByRole("button", { name: /show details/i });
|
||||
|
||||
await user.click(button);
|
||||
expect(await screen.findByText(/details content/i)).toBeInTheDocument();
|
||||
|
||||
await user.click(button);
|
||||
expect(screen.queryByText(/details content/i)).not.toBeInTheDocument();
|
||||
});
|
||||
```
|
||||
|
||||
### Testing Lists and Tables
|
||||
|
||||
```typescript
|
||||
test("displays list of items", async () => {
|
||||
fetchSpy.mockResolvedValueOnce({
|
||||
ok: true,
|
||||
json: async () => ({
|
||||
items: [
|
||||
{ id: 1, name: "Item 1" },
|
||||
{ id: 2, name: "Item 2" },
|
||||
{ id: 3, name: "Item 3" },
|
||||
],
|
||||
}),
|
||||
} as Response);
|
||||
|
||||
render(<ItemList />);
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Item 1")).toBeInTheDocument();
|
||||
expect(screen.getByText("Item 2")).toBeInTheDocument();
|
||||
expect(screen.getByText("Item 3")).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("filters items", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<FilterableList items={mockItems} />);
|
||||
|
||||
await user.type(screen.getByRole("searchbox"), "specific");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText("Specific Item")).toBeInTheDocument();
|
||||
expect(screen.queryByText("Other Item")).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
### Testing Validation
|
||||
|
||||
```typescript
|
||||
test("shows validation errors", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<LoginForm />);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /submit/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/email is required/i)).toBeInTheDocument();
|
||||
expect(screen.getByText(/password is required/i)).toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
|
||||
test("clears validation on valid input", async () => {
|
||||
const user = setupUser();
|
||||
|
||||
render(<LoginForm />);
|
||||
|
||||
await user.click(screen.getByRole("button", { name: /submit/i }));
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText(/email is required/i)).toBeInTheDocument();
|
||||
});
|
||||
|
||||
await user.type(screen.getByLabelText(/email/i), "valid@email.com");
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.queryByText(/email is required/i)).not.toBeInTheDocument();
|
||||
});
|
||||
});
|
||||
```
|
||||
|
||||
## Testing Philosophy
|
||||
|
||||
### What to Test
|
||||
|
||||
**✅ Test user-visible behavior:**
|
||||
- Forms can be filled and submitted
|
||||
- Buttons trigger expected actions
|
||||
- Success/error messages appear
|
||||
- Navigation works correctly
|
||||
- Data is displayed after loading
|
||||
- Validation errors show and clear appropriately
|
||||
|
||||
**✅ Test integration points:**
|
||||
- API calls are made with correct parameters
|
||||
- Responses are handled properly
|
||||
- Error states are handled
|
||||
- Loading states appear
|
||||
|
||||
**❌ Don't test implementation details:**
|
||||
- Internal state values
|
||||
- Component lifecycle methods
|
||||
- CSS class names
|
||||
- Specific React hooks being used
|
||||
|
||||
### Test Naming
|
||||
|
||||
Write test names that describe user behavior:
|
||||
|
||||
```typescript
|
||||
// ✅ Good - Describes what user can do
|
||||
test("user can create new prompt", async () => {})
|
||||
test("shows error when API call fails", async () => {})
|
||||
test("filters items by search term", async () => {})
|
||||
|
||||
// ❌ Bad - Implementation-focused
|
||||
test("handleSubmit is called", async () => {})
|
||||
test("state updates correctly", async () => {})
|
||||
test("renders without crashing", async () => {})
|
||||
```
|
||||
|
||||
### Minimal Mocking
|
||||
|
||||
Only mock external dependencies:
|
||||
|
||||
```typescript
|
||||
// ✅ Mock external APIs
|
||||
jest.spyOn(global, "fetch")
|
||||
|
||||
// ✅ Mock Next.js router
|
||||
jest.mock("next/navigation")
|
||||
|
||||
// ✅ Mock problematic packages
|
||||
// (configured in tests/setup/__mocks__)
|
||||
|
||||
// ❌ Don't mock application code
|
||||
// ❌ Don't mock component internals
|
||||
// ❌ Don't mock utility functions
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### "Not wrapped in act()" Warning
|
||||
|
||||
**Solution**: Always use `setupUser()` instead of `userEvent.setup()`
|
||||
|
||||
```typescript
|
||||
// ✅ Correct
|
||||
const user = setupUser();
|
||||
|
||||
// ❌ Wrong
|
||||
const user = userEvent.setup();
|
||||
```
|
||||
|
||||
### "Unable to find element" Error
|
||||
|
||||
**Solution**: Element hasn't appeared yet, use `findBy` or `waitFor`
|
||||
|
||||
```typescript
|
||||
// ❌ Wrong - getBy doesn't wait
|
||||
await user.click(button);
|
||||
expect(screen.getByText("Success")).toBeInTheDocument();
|
||||
|
||||
// ✅ Correct - findBy waits
|
||||
await user.click(button);
|
||||
expect(await screen.findByText("Success")).toBeInTheDocument();
|
||||
```
|
||||
|
||||
### "Multiple elements found" Error
|
||||
|
||||
**Solution**: Be more specific with your query
|
||||
|
||||
```typescript
|
||||
// ❌ Too broad
|
||||
screen.getByRole("button")
|
||||
|
||||
// ✅ Specific
|
||||
screen.getByRole("button", { name: /submit/i })
|
||||
```
|
||||
|
||||
### Test Times Out
|
||||
|
||||
**Causes**:
|
||||
1. Async operation never completes
|
||||
2. Waiting for element that never appears
|
||||
3. Missing mock for API call
|
||||
|
||||
**Solutions**:
|
||||
```typescript
|
||||
// Check fetch is mocked
|
||||
expect(fetchSpy).toHaveBeenCalled()
|
||||
|
||||
// Use queryBy to check if element exists
|
||||
expect(screen.queryByText("Text")).toBeInTheDocument()
|
||||
|
||||
// Verify mock is set up before render
|
||||
fetchSpy.mockResolvedValueOnce(...)
|
||||
render(<Component />)
|
||||
```
|
||||
|
||||
## Examples
|
||||
|
||||
See comprehensive test examples:
|
||||
- `src/app/auth/login/EmailPasswordForm.test.tsx` - Login/signup workflows, validation
|
||||
- `src/app/chat/input-prompts/InputPrompts.test.tsx` - CRUD operations, conditional rendering
|
||||
- `src/app/admin/configuration/llm/CustomLLMProviderUpdateForm.test.tsx` - Complex forms, multi-step workflows
|
||||
|
||||
## Built-in Mocks
|
||||
|
||||
Only essential mocks in `tests/setup/__mocks__/`:
|
||||
- `UserProvider` - Removes auth requirement for tests
|
||||
- `react-markdown` / `remark-gfm` - ESM compatibility
|
||||
|
||||
See `tests/setup/__mocks__/README.md` for details.
|
||||
66
web/tests/setup/__mocks__/@/components/user/UserProvider.tsx
Normal file
66
web/tests/setup/__mocks__/@/components/user/UserProvider.tsx
Normal file
@@ -0,0 +1,66 @@
|
||||
/**
|
||||
* Mock for @/components/user/UserProvider
|
||||
*
|
||||
* Why this mock exists:
|
||||
* The real UserProvider requires complex props (authTypeMetadata, settings, user)
|
||||
* that are not relevant for most component integration tests. This mock provides
|
||||
* a simple useUser() hook with safe default values.
|
||||
*
|
||||
* Usage:
|
||||
* Automatically applied via jest.config.js moduleNameMapper.
|
||||
* Any component that imports from "@/components/user/UserProvider" will get this mock.
|
||||
*
|
||||
* To customize user values in a specific test:
|
||||
* You would need to either:
|
||||
* 1. Pass props to the real UserProvider (requires disabling this mock for that test)
|
||||
* 2. Extend this mock to accept custom values via a setup function
|
||||
*/
|
||||
import React, { createContext, useContext } from "react";
|
||||
|
||||
interface UserContextType {
|
||||
user: any;
|
||||
isAdmin: boolean;
|
||||
isCurator: boolean;
|
||||
refreshUser: () => Promise<void>;
|
||||
isCloudSuperuser: boolean;
|
||||
updateUserAutoScroll: (autoScroll: boolean) => Promise<void>;
|
||||
updateUserShortcuts: (enabled: boolean) => Promise<void>;
|
||||
toggleAssistantPinnedStatus: (
|
||||
currentPinnedAssistantIDs: number[],
|
||||
assistantId: number,
|
||||
isPinned: boolean
|
||||
) => Promise<boolean>;
|
||||
updateUserTemperatureOverrideEnabled: (enabled: boolean) => Promise<void>;
|
||||
updateUserPersonalization: (personalization: any) => Promise<void>;
|
||||
}
|
||||
|
||||
const mockUserContext: UserContextType = {
|
||||
user: null,
|
||||
isAdmin: false,
|
||||
isCurator: false,
|
||||
refreshUser: async () => {},
|
||||
isCloudSuperuser: false,
|
||||
updateUserAutoScroll: async () => {},
|
||||
updateUserShortcuts: async () => {},
|
||||
toggleAssistantPinnedStatus: async () => true,
|
||||
updateUserTemperatureOverrideEnabled: async () => {},
|
||||
updateUserPersonalization: async () => {},
|
||||
};
|
||||
|
||||
const UserContext = createContext<UserContextType | undefined>(mockUserContext);
|
||||
|
||||
export function useUser() {
|
||||
const context = useContext(UserContext);
|
||||
if (context === undefined) {
|
||||
throw new Error("useUser must be used within a UserProvider");
|
||||
}
|
||||
return context;
|
||||
}
|
||||
|
||||
export function UserProvider({ children }: { children: React.ReactNode }) {
|
||||
return (
|
||||
<UserContext.Provider value={mockUserContext}>
|
||||
{children}
|
||||
</UserContext.Provider>
|
||||
);
|
||||
}
|
||||
92
web/tests/setup/__mocks__/README.md
Normal file
92
web/tests/setup/__mocks__/README.md
Normal file
@@ -0,0 +1,92 @@
|
||||
# Test Mocks
|
||||
|
||||
This directory contains Jest mocks for dependencies that are difficult to test with in their original form.
|
||||
|
||||
## Why These Mocks Exist
|
||||
|
||||
### `@/components/user/UserProvider.tsx`
|
||||
|
||||
**Problem:** The real `UserProvider` requires complex setup with props like:
|
||||
- `authTypeMetadata` (auth configuration)
|
||||
- `settings` (CombinedSettings object)
|
||||
- `user` (User object)
|
||||
|
||||
**Solution:** This mock provides a simple `useUser()` hook that returns safe default values, allowing components that depend on user context to render in tests without extensive setup.
|
||||
|
||||
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper. Components using `useUser()` will get the mock instead of the real provider.
|
||||
|
||||
**Mock values:**
|
||||
```typescript
|
||||
{
|
||||
user: null,
|
||||
isAdmin: false,
|
||||
isCurator: false,
|
||||
isCloudSuperuser: false,
|
||||
refreshUser: async () => {},
|
||||
updateUserAutoScroll: async () => {},
|
||||
updateUserShortcuts: async () => {},
|
||||
toggleAssistantPinnedStatus: async () => true,
|
||||
updateUserTemperatureOverrideEnabled: async () => {},
|
||||
updateUserPersonalization: async () => {},
|
||||
}
|
||||
```
|
||||
|
||||
### `react-markdown.tsx`
|
||||
|
||||
**Problem:** The `react-markdown` library uses ESM (ECMAScript Modules) which Jest cannot parse by default without extensive configuration. Many components (like `Field.tsx`) import this library.
|
||||
|
||||
**Solution:** This mock provides a simple component that renders markdown content as plain text, avoiding ESM parsing issues.
|
||||
|
||||
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper.
|
||||
|
||||
**Limitation:** Markdown is not actually rendered/parsed in tests - content is displayed as-is. If you need to test markdown rendering, you'll need to configure Jest to handle ESM properly.
|
||||
|
||||
### `remark-gfm.ts`
|
||||
|
||||
**Problem:** The `remark-gfm` library (GitHub Flavored Markdown plugin for react-markdown) also uses ESM.
|
||||
|
||||
**Solution:** This mock provides a no-op plugin function that does nothing but allows components to import it without errors.
|
||||
|
||||
**Usage:** Automatically applied via `jest.config.js` moduleNameMapper.
|
||||
|
||||
## When to Add New Mocks
|
||||
|
||||
Add mocks to this directory when:
|
||||
|
||||
1. **ESM compatibility issues** - Library uses `export/import` syntax that Jest can't parse
|
||||
2. **Complex setup requirements** - Component/provider requires extensive configuration that's not relevant to your test
|
||||
3. **External dependencies** - Library makes network calls, accesses browser APIs not available in tests, etc.
|
||||
|
||||
## When NOT to Mock
|
||||
|
||||
Avoid mocking when:
|
||||
|
||||
1. **Testing the actual behavior** - If you're testing how markdown renders, don't mock react-markdown
|
||||
2. **Simple to configure** - If it's easy to provide real props, use the real component
|
||||
3. **Core business logic** - Don't mock your own application logic
|
||||
|
||||
## Configuration
|
||||
|
||||
These mocks are configured in `jest.config.js`:
|
||||
|
||||
```javascript
|
||||
moduleNameMapper: {
|
||||
// Mock react-markdown and related packages
|
||||
"^react-markdown$": "<rootDir>/tests/setup/__mocks__/react-markdown.tsx",
|
||||
"^remark-gfm$": "<rootDir>/tests/setup/__mocks__/remark-gfm.ts",
|
||||
// Mock UserProvider
|
||||
"^@/components/user/UserProvider$": "<rootDir>/tests/setup/__mocks__/@/components/user/UserProvider.tsx",
|
||||
// ... other mappings
|
||||
}
|
||||
```
|
||||
|
||||
**Important:** Specific mocks must come BEFORE the generic `@/` path alias, otherwise the path alias will match first and the mock won't be applied.
|
||||
|
||||
## Debugging Mock Issues
|
||||
|
||||
If a component isn't getting the mock:
|
||||
|
||||
1. Check that the import path exactly matches the moduleNameMapper pattern
|
||||
2. Clear Jest cache: `npx jest --clearCache`
|
||||
3. Check that the mock file path is correct relative to `<rootDir>`
|
||||
4. Verify the mock comes BEFORE generic path aliases in moduleNameMapper
|
||||
22
web/tests/setup/__mocks__/react-markdown.tsx
Normal file
22
web/tests/setup/__mocks__/react-markdown.tsx
Normal file
@@ -0,0 +1,22 @@
|
||||
/**
|
||||
* Mock for react-markdown
|
||||
*
|
||||
* Why this mock exists:
|
||||
* react-markdown uses ESM (ECMAScript Modules) which Jest cannot parse by default.
|
||||
* Components like Field.tsx import react-markdown, which would cause test failures.
|
||||
*
|
||||
* Limitation:
|
||||
* Markdown is NOT actually rendered/parsed in tests - content is displayed as plain text.
|
||||
* If you need to test actual markdown rendering, you'll need to configure Jest for ESM.
|
||||
*
|
||||
* Usage:
|
||||
* Automatically applied via jest.config.js moduleNameMapper.
|
||||
*/
|
||||
import React from "react";
|
||||
|
||||
// Simple mock that renders markdown content as plain text
|
||||
const ReactMarkdown = ({ children }: { children: string }) => {
|
||||
return <div>{children}</div>;
|
||||
};
|
||||
|
||||
export default ReactMarkdown;
|
||||
18
web/tests/setup/__mocks__/remark-gfm.ts
Normal file
18
web/tests/setup/__mocks__/remark-gfm.ts
Normal file
@@ -0,0 +1,18 @@
|
||||
/**
|
||||
* Mock for remark-gfm
|
||||
*
|
||||
* Why this mock exists:
|
||||
* remark-gfm (GitHub Flavored Markdown plugin) uses ESM which Jest cannot parse.
|
||||
* It's a dependency of react-markdown that components import.
|
||||
*
|
||||
* Limitation:
|
||||
* GFM features (tables, strikethrough, etc.) are not processed in tests.
|
||||
*
|
||||
* Usage:
|
||||
* Automatically applied via jest.config.js moduleNameMapper.
|
||||
*/
|
||||
|
||||
// No-op plugin that does nothing but allows imports to succeed
|
||||
export default function remarkGfm() {
|
||||
return function () {};
|
||||
}
|
||||
1
web/tests/setup/fileMock.js
Normal file
1
web/tests/setup/fileMock.js
Normal file
@@ -0,0 +1 @@
|
||||
module.exports = "test-file-stub";
|
||||
82
web/tests/setup/jest.setup.ts
Normal file
82
web/tests/setup/jest.setup.ts
Normal file
@@ -0,0 +1,82 @@
|
||||
import "@testing-library/jest-dom";
|
||||
import { TextEncoder, TextDecoder } from "util";
|
||||
|
||||
// Polyfill TextEncoder/TextDecoder (required for some libraries)
|
||||
global.TextEncoder = TextEncoder;
|
||||
global.TextDecoder = TextDecoder as any;
|
||||
|
||||
// Only set up browser-specific mocks if we're in a jsdom environment
|
||||
if (typeof window !== "undefined") {
|
||||
// Polyfill fetch for jsdom
|
||||
// @ts-ignore
|
||||
import("whatwg-fetch");
|
||||
|
||||
// Mock BroadcastChannel for JSDOM
|
||||
global.BroadcastChannel = class BroadcastChannel {
|
||||
constructor(public name: string) {}
|
||||
postMessage() {}
|
||||
close() {}
|
||||
addEventListener() {}
|
||||
removeEventListener() {}
|
||||
dispatchEvent() {
|
||||
return true;
|
||||
}
|
||||
} as any;
|
||||
|
||||
// Mock window.matchMedia for responsive components
|
||||
Object.defineProperty(window, "matchMedia", {
|
||||
writable: true,
|
||||
value: jest.fn().mockImplementation((query) => ({
|
||||
matches: false,
|
||||
media: query,
|
||||
onchange: null,
|
||||
addListener: jest.fn(), // deprecated
|
||||
removeListener: jest.fn(), // deprecated
|
||||
addEventListener: jest.fn(),
|
||||
removeEventListener: jest.fn(),
|
||||
dispatchEvent: jest.fn(),
|
||||
})),
|
||||
});
|
||||
|
||||
// Mock IntersectionObserver
|
||||
global.IntersectionObserver = class IntersectionObserver {
|
||||
constructor() {}
|
||||
disconnect() {}
|
||||
observe() {}
|
||||
takeRecords() {
|
||||
return [];
|
||||
}
|
||||
unobserve() {}
|
||||
} as any;
|
||||
|
||||
// Mock ResizeObserver
|
||||
global.ResizeObserver = class ResizeObserver {
|
||||
constructor() {}
|
||||
disconnect() {}
|
||||
observe() {}
|
||||
unobserve() {}
|
||||
} as any;
|
||||
|
||||
// Mock window.scrollTo
|
||||
global.scrollTo = jest.fn();
|
||||
}
|
||||
|
||||
// Suppress console errors in tests (optional - comment out if you want to see them)
|
||||
// const originalError = console.error;
|
||||
// beforeAll(() => {
|
||||
// console.error = (...args: any[]) => {
|
||||
// // Filter out known React warnings that are not actionable in tests
|
||||
// if (
|
||||
// typeof args[0] === "string" &&
|
||||
// (args[0].includes("Warning: ReactDOM.render") ||
|
||||
// args[0].includes("Not implemented: HTMLFormElement.prototype.submit"))
|
||||
// ) {
|
||||
// return;
|
||||
// }
|
||||
// originalError.call(console, ...args);
|
||||
// };
|
||||
// });
|
||||
|
||||
// afterAll(() => {
|
||||
// console.error = originalError;
|
||||
// });
|
||||
119
web/tests/setup/test-utils.tsx
Normal file
119
web/tests/setup/test-utils.tsx
Normal file
@@ -0,0 +1,119 @@
|
||||
import React, { ReactElement } from "react";
|
||||
import { render, RenderOptions } from "@testing-library/react";
|
||||
import userEvent from "@testing-library/user-event";
|
||||
import { SWRConfig } from "swr";
|
||||
|
||||
/**
|
||||
* Custom render function that wraps components with common providers
|
||||
* used throughout the Onyx application.
|
||||
*/
|
||||
|
||||
interface AllProvidersProps {
|
||||
children: React.ReactNode;
|
||||
swrConfig?: Record<string, any>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Wrapper component that provides all necessary context providers for tests.
|
||||
* Customize this as needed when you discover more global providers in the app.
|
||||
*/
|
||||
function AllTheProviders({ children, swrConfig = {} }: AllProvidersProps) {
|
||||
return (
|
||||
<SWRConfig
|
||||
value={{
|
||||
// Disable deduping in tests to ensure each test gets fresh data
|
||||
dedupingInterval: 0,
|
||||
// Use a Map instead of cache to avoid state leaking between tests
|
||||
provider: () => new Map(),
|
||||
// Disable error retries in tests for faster failures
|
||||
shouldRetryOnError: false,
|
||||
// Merge any custom SWR config passed from tests
|
||||
...swrConfig,
|
||||
}}
|
||||
>
|
||||
{children}
|
||||
</SWRConfig>
|
||||
);
|
||||
}
|
||||
|
||||
interface CustomRenderOptions extends Omit<RenderOptions, "wrapper"> {
|
||||
swrConfig?: Record<string, any>;
|
||||
}
|
||||
|
||||
/**
|
||||
* Custom render function that wraps the component with all providers.
|
||||
* Use this instead of @testing-library/react's render in your tests.
|
||||
*
|
||||
* @example
|
||||
* import { render, screen } from '@tests/setup/test-utils';
|
||||
*
|
||||
* test('renders component', () => {
|
||||
* render(<MyComponent />);
|
||||
* expect(screen.getByText('Hello')).toBeInTheDocument();
|
||||
* });
|
||||
*
|
||||
* @example
|
||||
* // With custom SWR config to mock API responses
|
||||
* render(<MyComponent />, {
|
||||
* swrConfig: {
|
||||
* fallback: {
|
||||
* '/api/credentials': mockCredentials,
|
||||
* },
|
||||
* },
|
||||
* });
|
||||
*/
|
||||
const customRender = (
|
||||
ui: ReactElement,
|
||||
{ swrConfig, ...options }: CustomRenderOptions = {}
|
||||
) => {
|
||||
const Wrapper = ({ children }: { children: React.ReactNode }) => (
|
||||
<AllTheProviders swrConfig={swrConfig}>{children}</AllTheProviders>
|
||||
);
|
||||
|
||||
return render(ui, { wrapper: Wrapper, ...options });
|
||||
};
|
||||
|
||||
// Re-export everything from @testing-library/react
|
||||
export * from "@testing-library/react";
|
||||
export { userEvent };
|
||||
|
||||
// Override render with our custom render
|
||||
export { customRender as render };
|
||||
|
||||
/**
|
||||
* Setup userEvent with optimized configuration for testing.
|
||||
* All user interactions are automatically wrapped in act() to prevent warnings.
|
||||
* Use this helper instead of userEvent.setup() directly.
|
||||
*
|
||||
* @example
|
||||
* const user = setupUser();
|
||||
* await user.click(button);
|
||||
* await user.type(input, "text");
|
||||
*/
|
||||
export function setupUser(options = {}) {
|
||||
const baseUser = userEvent.setup({
|
||||
// Configure for React 18 to reduce act warnings
|
||||
delay: null, // Instant typing - batches state updates better
|
||||
...options,
|
||||
});
|
||||
|
||||
// Wrap all user-event methods in act() to prevent act warnings. We add this here
|
||||
// to prevent all callsites from needing to import and wrap user events in act()
|
||||
return new Proxy(baseUser, {
|
||||
get(target, prop) {
|
||||
const value = target[prop as keyof typeof target];
|
||||
|
||||
// Only wrap methods (functions), not properties
|
||||
if (typeof value === "function") {
|
||||
return async (...args: any[]) => {
|
||||
const { act } = await import("@testing-library/react");
|
||||
return act(async () => {
|
||||
return (value as Function).apply(target, args);
|
||||
});
|
||||
};
|
||||
}
|
||||
|
||||
return value;
|
||||
},
|
||||
});
|
||||
}
|
||||
Reference in New Issue
Block a user