Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-17 07:45:47 +00:00)

Compare commits: dump-scrip...v2.5.5 (5 commits)
| Author | SHA1 | Date |
|---|---|---|
| | a1df56df13 | |
| | 90c206d9e1 | |
| | 5e1c89d673 | |
| | 2239a58b1d | |
| | 825edba531 | |
@@ -2,6 +2,7 @@ from collections.abc import Sequence
from typing import Any

from sqlalchemy import select
+from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session

from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -269,7 +270,9 @@ def fetch_slack_channel_config_for_channel_or_default(
    # attempt to find channel-specific config first
    if channel_name is not None:
        sc_config = db_session.scalar(
-            select(SlackChannelConfig).where(
+            select(SlackChannelConfig)
+            .options(joinedload(SlackChannelConfig.persona))
+            .where(
                SlackChannelConfig.slack_bot_id == slack_bot_id,
                SlackChannelConfig.channel_config["channel_name"].astext
                == channel_name,
@@ -283,7 +286,9 @@ def fetch_slack_channel_config_for_channel_or_default(

    # if none found, see if there is a default
    default_sc = db_session.scalar(
-        select(SlackChannelConfig).where(
+        select(SlackChannelConfig)
+        .options(joinedload(SlackChannelConfig.persona))
+        .where(
            SlackChannelConfig.slack_bot_id == slack_bot_id,
            SlackChannelConfig.is_default == True,  # noqa: E712
        )
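The `.options(joinedload(...))` call is the core of this change: the persona relationship is loaded in the same query as the channel config, so it survives later session changes (see the handle_regular_answer hunk and the regression test further down). A minimal sketch of the behavior, using toy models rather than the actual Onyx schema:

```python
# Minimal sketch, assuming toy models (not the Onyx schema): with joinedload()
# the related row is fetched in the same SELECT, so it stays readable after
# the object is detached from its session.
from sqlalchemy import ForeignKey, create_engine, select
from sqlalchemy.orm import (
    DeclarativeBase,
    Mapped,
    Session,
    joinedload,
    mapped_column,
    relationship,
)


class Base(DeclarativeBase):
    pass


class Persona(Base):
    __tablename__ = "persona"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[str]


class ChannelConfig(Base):
    __tablename__ = "channel_config"
    id: Mapped[int] = mapped_column(primary_key=True)
    persona_id: Mapped[int] = mapped_column(ForeignKey("persona.id"))
    persona: Mapped[Persona] = relationship()


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # cascade adds the related Persona along with the config
    session.add(ChannelConfig(id=1, persona=Persona(id=1, name="default")))
    session.commit()

with Session(engine) as session:
    config = session.scalar(
        select(ChannelConfig).options(joinedload(ChannelConfig.persona))
    )

# The session is closed here, but the persona was loaded eagerly with the
# config, so this still works; without joinedload() the same access would
# raise DetachedInstanceError once the object is detached.
print(config.persona.name)  # -> "default"
```

Without the eager load, a lazy access to `persona` outside the original session is what previously pushed the Slack flow onto the default persona.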
@@ -1066,6 +1066,17 @@ class InformationContentClassificationModel:
        self,
        queries: list[str],
    ) -> list[ContentClassificationPrediction]:
+        if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
+            logger.info(
+                "DISABLE_MODEL_SERVER is set, returning default classifications"
+            )
+            return [
+                ContentClassificationPrediction(
+                    predicted_label=1, content_boost_factor=1.0
+                )
+                for _ in queries
+            ]
+
        response = requests.post(self.content_server_endpoint, json=queries)
        response.raise_for_status()
@@ -1092,6 +1103,14 @@ class ConnectorClassificationModel:
        query: str,
        available_connectors: list[str],
    ) -> list[str]:
+        # Check if model server is disabled
+        if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
+            logger.info(
+                "DISABLE_MODEL_SERVER is set, returning all available connectors"
+            )
+            # Return all available connectors when model server is disabled
+            return available_connectors
+
        connector_classification_request = ConnectorClassificationRequest(
            available_connectors=available_connectors,
            query=query,
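Both early returns hinge on the exact value of the flag: per the checks above and the unit tests further down, only the literal string "true" (compared case-insensitively) disables the model server. A small sketch of that semantics, using a hypothetical helper that mirrors the inline checks:

```python
# Sketch of the flag semantics used throughout this compare: only "true"
# (any casing) counts as disabled; "1", "yes", or unset leave the model
# server enabled.
import os


def model_server_disabled() -> bool:
    return os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true"


os.environ.pop("DISABLE_MODEL_SERVER", None)
assert not model_server_disabled()  # unset -> enabled

os.environ["DISABLE_MODEL_SERVER"] = "True"
assert model_server_disabled()  # "True" -> disabled (case-insensitive)

os.environ["DISABLE_MODEL_SERVER"] = "1"
assert not model_server_disabled()  # "1" -> still enabled

os.environ["DISABLE_MODEL_SERVER"] = "yes"
assert not model_server_disabled()  # "yes" -> still enabled
```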
@@ -117,12 +117,14 @@ def handle_regular_answer(
    # This way slack flow always has a persona
    persona = slack_channel_config.persona
    if not persona:
        logger.warning("No persona found for channel config, using default persona")
        with get_session_with_current_tenant() as db_session:
            persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
            document_set_names = [
                document_set.name for document_set in persona.document_sets
            ]
    else:
        logger.info(f"Using persona {persona.name} for channel config")
        document_set_names = [
            document_set.name for document_set in persona.document_sets
        ]
@@ -1,3 +1,4 @@
import os
from functools import lru_cache

import requests
@@ -13,16 +14,14 @@ logger = setup_logger()


def _get_gpu_status_from_model_server(indexing: bool) -> bool:
    if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
        logger.info("DISABLE_MODEL_SERVER is set, assuming no GPU available")
        return False
    if indexing:
        model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
    else:
        model_server_url = f"{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"

    # If model server is disabled, return False (no GPU available)
    if model_server_url in ["disabled", "disabled:9000"]:
        logger.info("Model server is disabled, assuming no GPU available")
        return False

    if "http" not in model_server_url:
        model_server_url = f"http://{model_server_url}"
@@ -10,14 +10,24 @@ SLACK_CHANNEL_ID = "channel_id"
# Default to True (skip warmup) if not set, otherwise respect the value
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "true").lower() == "true"

MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
# Check if model server is disabled
DISABLE_MODEL_SERVER = os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true"

# If model server is disabled, use "disabled" as host to trigger proper handling
if DISABLE_MODEL_SERVER:
    MODEL_SERVER_HOST = "disabled"
    MODEL_SERVER_ALLOWED_HOST = "disabled"
    INDEXING_MODEL_SERVER_HOST = "disabled"
else:
    MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
    MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
    INDEXING_MODEL_SERVER_HOST = (
        os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
    )

MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
# Model server for indexing should use a separate one to not allow indexing to introduce delay
# for inference
INDEXING_MODEL_SERVER_HOST = (
    os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
INDEXING_MODEL_SERVER_PORT = int(
    os.environ.get("INDEXING_MODEL_SERVER_PORT") or MODEL_SERVER_PORT
)
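Taken together with the gpu_utils hunk above, the intent is that either the environment flag or the "disabled" sentinel host short-circuits GPU probing before any HTTP request is attempted. A hedged sketch of that interaction, with hypothetical helper names rather than the actual Onyx functions:

```python
# Hypothetical helpers mirroring the checks shown above in the shared config
# and in _get_gpu_status_from_model_server: the env flag collapses the host to
# a "disabled" sentinel, and the GPU probe bails out on either signal.
import os


def resolve_model_server_host() -> str:
    if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
        return "disabled"
    return os.environ.get("MODEL_SERVER_HOST") or "localhost"


def should_probe_gpu(host: str, port: int = 9000) -> bool:
    if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
        return False  # flag set: never probe
    if host == "disabled" or f"{host}:{port}" == "disabled:9000":
        return False  # sentinel host: never probe
    return True


os.environ["DISABLE_MODEL_SERVER"] = "true"
host = resolve_model_server_host()
assert host == "disabled"
assert should_probe_gpu(host) is False
```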
@@ -6,9 +6,11 @@ from unittest.mock import patch
from uuid import uuid4

# Set environment variables to disable model server for testing
os.environ["DISABLE_MODEL_SERVER"] = "true"
os.environ["MODEL_SERVER_HOST"] = "disabled"
os.environ["MODEL_SERVER_PORT"] = "9000"

from sqlalchemy import inspect
from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
@@ -760,3 +762,76 @@ def test_multiple_missing_scopes_resilience(
    # Should still return available channels
    assert len(result) == 1, f"Expected 1 channel, got {len(result)}"
    assert result["C1234567890"]["name"] == "general"


def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
    """Test that fetch_slack_channel_config_for_channel_or_default eagerly loads persona.

    This prevents lazy loading failures when the session context changes later
    in the request handling flow (e.g., in handle_regular_answer).
    """
    from onyx.db.slack_channel_config import (
        fetch_slack_channel_config_for_channel_or_default,
    )

    unique_id = str(uuid4())[:8]

    # Create a persona (using same fields as _create_test_persona_with_slack_config)
    persona = Persona(
        name=f"test_eager_load_persona_{unique_id}",
        description="Test persona for eager loading test",
        chunks_above=0,
        chunks_below=0,
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        system_prompt="You are a helpful assistant.",
        task_prompt="Answer the user's question.",
    )
    db_session.add(persona)
    db_session.flush()

    # Create a slack bot
    slack_bot = SlackBot(
        name=f"Test Bot {unique_id}",
        bot_token=f"xoxb-test-{unique_id}",
        app_token=f"xapp-test-{unique_id}",
        enabled=True,
    )
    db_session.add(slack_bot)
    db_session.flush()

    # Create slack channel config with persona
    channel_name = f"test-channel-{unique_id}"
    slack_channel_config = SlackChannelConfig(
        slack_bot_id=slack_bot.id,
        persona_id=persona.id,
        channel_config={"channel_name": channel_name, "disabled": False},
        enable_auto_filters=False,
        is_default=False,
    )
    db_session.add(slack_channel_config)
    db_session.commit()

    # Fetch the config using the function under test
    fetched_config = fetch_slack_channel_config_for_channel_or_default(
        db_session=db_session,
        slack_bot_id=slack_bot.id,
        channel_name=channel_name,
    )

    assert fetched_config is not None, "Should find the channel config"

    # Check that persona relationship is already loaded (not pending lazy load)
    insp = inspect(fetched_config)
    assert insp is not None, "Should be able to inspect the config"
    assert "persona" not in insp.unloaded, (
        "Persona should be eagerly loaded, not pending lazy load. "
        "This is required to prevent fallback to default persona when "
        "session context changes in handle_regular_answer."
    )

    # Verify the persona is correct
    assert fetched_config.persona is not None, "Persona should not be None"
    assert fetched_config.persona.id == persona.id, "Should load the correct persona"
    assert fetched_config.persona.name == persona.name
@@ -1,3 +1,4 @@
import os
from collections.abc import AsyncGenerator
from typing import List
from unittest.mock import AsyncMock
@@ -9,6 +10,12 @@ from httpx import AsyncClient
from litellm.exceptions import RateLimitError

from onyx.natural_language_processing.search_nlp_models import CloudEmbedding
from onyx.natural_language_processing.search_nlp_models import (
    ConnectorClassificationModel,
)
from onyx.natural_language_processing.search_nlp_models import (
    InformationContentClassificationModel,
)
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@@ -81,3 +88,95 @@ async def test_rate_limit_handling() -> None:
        model_name="fake-model",
        text_type=EmbedTextType.QUERY,
    )


class TestInformationContentClassificationModel:
    """Test cases for InformationContentClassificationModel with DISABLE_MODEL_SERVER"""

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
    def test_predict_with_disable_model_server(self) -> None:
        """Test that predict returns default classifications when DISABLE_MODEL_SERVER is true"""
        model = InformationContentClassificationModel()
        queries = ["What is AI?", "How does Python work?"]

        results = model.predict(queries)

        assert len(results) == 2
        for result in results:
            assert result.predicted_label == 1  # 1 indicates informational content
            assert result.content_boost_factor == 1.0  # Default boost factor

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
    @patch("requests.post")
    def test_predict_with_model_server_enabled(self, mock_post: MagicMock) -> None:
        """Test that predict makes request when DISABLE_MODEL_SERVER is false"""
        mock_response = MagicMock()
        mock_response.json.return_value = [
            {"predicted_label": 1, "content_boost_factor": 1.0},
            {"predicted_label": 0, "content_boost_factor": 0.8},
        ]
        mock_post.return_value = mock_response

        model = InformationContentClassificationModel()
        queries = ["test1", "test2"]

        results = model.predict(queries)

        assert len(results) == 2
        assert results[0].predicted_label == 1
        assert results[0].content_boost_factor == 1.0
        assert results[1].predicted_label == 0
        assert results[1].content_boost_factor == 0.8
        mock_post.assert_called_once()


class TestConnectorClassificationModel:
    """Test cases for ConnectorClassificationModel with DISABLE_MODEL_SERVER"""

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
    def test_predict_with_disable_model_server(self) -> None:
        """Test that predict returns all connectors when DISABLE_MODEL_SERVER is true"""
        model = ConnectorClassificationModel()
        query = "Search for documentation"
        available_connectors = ["confluence", "slack", "github"]

        results = model.predict(query, available_connectors)

        assert results == available_connectors

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
    @patch("requests.post")
    def test_predict_with_model_server_enabled(self, mock_post: MagicMock) -> None:
        """Test that predict makes request when DISABLE_MODEL_SERVER is false"""
        mock_response = MagicMock()
        mock_response.json.return_value = {"connectors": ["confluence", "github"]}
        mock_post.return_value = mock_response

        model = ConnectorClassificationModel()
        query = "Search for documentation"
        available_connectors = ["confluence", "slack", "github"]

        results = model.predict(query, available_connectors)

        assert results == ["confluence", "github"]
        mock_post.assert_called_once()

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "1"})
    @patch("requests.post")
    def test_predict_with_disable_model_server_numeric(
        self, mock_post: MagicMock
    ) -> None:
        """Test that predict makes request when DISABLE_MODEL_SERVER is 1 (not 'true')"""
        # "1" should NOT trigger disable (only "true" should)
        mock_response = MagicMock()
        mock_response.json.return_value = {"connectors": ["github"]}
        mock_post.return_value = mock_response

        model = ConnectorClassificationModel()
        query = "Find issues"
        available_connectors = ["jira", "github"]

        results = model.predict(query, available_connectors)

        assert results == ["github"]
        mock_post.assert_called_once()
backend/tests/unit/onyx/utils/test_gpu_utils.py (new file, 103 lines)
@@ -0,0 +1,103 @@
"""
Test cases for onyx/utils/gpu_utils.py with DISABLE_MODEL_SERVER environment variable
"""

import os
from unittest import TestCase
from unittest.mock import MagicMock
from unittest.mock import patch

import requests

from onyx.utils.gpu_utils import _get_gpu_status_from_model_server


class TestGPUUtils(TestCase):
    """Test cases for GPU utilities with DISABLE_MODEL_SERVER support"""

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
    def test_disable_model_server_true(self) -> None:
        """Test that GPU status returns False when DISABLE_MODEL_SERVER is true"""
        result = _get_gpu_status_from_model_server(indexing=False)
        assert result is False

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "True"})
    def test_disable_model_server_capital_true(self) -> None:
        """Test that GPU status returns False when DISABLE_MODEL_SERVER is True (capital)"""
        # "True" WILL trigger disable because .lower() is called
        result = _get_gpu_status_from_model_server(indexing=False)
        assert result is False

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "1"})
    @patch("requests.get")
    def test_disable_model_server_one(self, mock_get: MagicMock) -> None:
        """Test that GPU status makes request when DISABLE_MODEL_SERVER is 1"""
        # "1" should NOT trigger disable (only "true" should)
        mock_response = MagicMock()
        mock_response.json.return_value = {"gpu_available": True}
        mock_get.return_value = mock_response

        result = _get_gpu_status_from_model_server(indexing=False)
        assert result is True
        mock_get.assert_called_once()

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "yes"})
    @patch("requests.get")
    def test_disable_model_server_yes(self, mock_get: MagicMock) -> None:
        """Test that GPU status makes request when DISABLE_MODEL_SERVER is yes"""
        # "yes" should NOT trigger disable (only "true" should)
        mock_response = MagicMock()
        mock_response.json.return_value = {"gpu_available": False}
        mock_get.return_value = mock_response

        result = _get_gpu_status_from_model_server(indexing=True)
        assert result is False
        mock_get.assert_called_once()

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
    @patch("requests.get")
    def test_disable_model_server_false(self, mock_get: MagicMock) -> None:
        """Test that GPU status makes request when DISABLE_MODEL_SERVER is false"""
        mock_response = MagicMock()
        mock_response.json.return_value = {"gpu_available": True}
        mock_get.return_value = mock_response

        result = _get_gpu_status_from_model_server(indexing=True)
        assert result is True
        mock_get.assert_called_once()

    @patch.dict(os.environ, {}, clear=True)
    @patch("requests.get")
    def test_disable_model_server_not_set(self, mock_get: MagicMock) -> None:
        """Test that GPU status makes request when DISABLE_MODEL_SERVER is not set"""
        mock_response = MagicMock()
        mock_response.json.return_value = {"gpu_available": False}
        mock_get.return_value = mock_response

        result = _get_gpu_status_from_model_server(indexing=False)
        assert result is False
        mock_get.assert_called_once()

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
    def test_disabled_host_fallback(self) -> None:
        """Test that disabled host is handled correctly via environment variable"""
        result = _get_gpu_status_from_model_server(indexing=True)
        assert result is False

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
    @patch("requests.get")
    def test_request_exception_handling(self, mock_get: MagicMock) -> None:
        """Test that exceptions are properly raised when GPU status request fails"""
        mock_get.side_effect = requests.RequestException("Connection error")

        with self.assertRaises(requests.RequestException):
            _get_gpu_status_from_model_server(indexing=False)

    @patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
    @patch("requests.get")
    def test_gpu_status_request_with_disable(self, mock_get: MagicMock) -> None:
        """Test that no request is made when DISABLE_MODEL_SERVER is true"""
        result = _get_gpu_status_from_model_server(indexing=True)
        assert result is False
        # Verify that no HTTP request was made
        mock_get.assert_not_called()
@@ -101,7 +101,7 @@ export default function AIMessage({
    // Toggle logic
    if (currentFeedback === clickedFeedback) {
      // Clicking same button - remove feedback
-     await handleFeedbackChange(nodeId, null);
+     await handleFeedbackChange(messageId, null);
    }

    // Clicking like (will automatically clear dislike if it was active).
@@ -113,12 +113,12 @@ export default function AIMessage({
      // Open modal for positive feedback
      setFeedbackModalProps({
        feedbackType: "like",
-       messageId: nodeId,
+       messageId,
      });
      modal.toggle(true);
    } else {
      // No modal needed - just submit like (this replaces any existing feedback)
-     await handleFeedbackChange(nodeId, "like");
+     await handleFeedbackChange(messageId, "like");
    }
  }
@@ -3,24 +3,14 @@ import {
  FederatedConnectorDetail,
  FederatedConnectorConfig,
  federatedSourceToRegularSource,
  ValidSources,
} from "@/lib/types";
import { SourceIcon } from "@/components/SourceIcon";
import SvgX from "@/icons/x";
import SvgSettings from "@/icons/settings";
import { Label } from "@/components/ui/label";
import { ErrorMessage } from "formik";
import Text from "@/refresh-components/texts/Text";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Input } from "@/components/ui/input";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import {
  Dialog,
  DialogContent,
  DialogHeader,
  DialogTitle,
} from "@/components/ui/dialog";
import Button from "@/refresh-components/buttons/Button";

interface FederatedConnectorSelectorProps {
  name: string;
@@ -33,194 +23,6 @@ interface FederatedConnectorSelectorProps {
  showError?: boolean;
}

interface EntityConfigDialogProps {
  connectorId: number;
  connectorName: string;
  connectorSource: ValidSources | null;
  currentEntities: Record<string, any>;
  onSave: (entities: Record<string, any>) => void;
  onClose: () => void;
  isOpen: boolean;
}

const EntityConfigDialog = ({
  connectorId,
  connectorName,
  connectorSource,
  currentEntities,
  onSave,
  onClose,
  isOpen,
}: EntityConfigDialogProps) => {
  const [entities, setEntities] =
    useState<Record<string, any>>(currentEntities);
  const [entitySchema, setEntitySchema] = useState<Record<string, any> | null>(
    null
  );
  const [isLoading, setIsLoading] = useState(false);
  const [error, setError] = useState<string | null>(null);

  useEffect(() => {
    if (isOpen) {
      setEntities(currentEntities || {});
    }
  }, [currentEntities, isOpen]);

  useEffect(() => {
    if (isOpen && connectorId) {
      const fetchEntitySchema = async () => {
        setIsLoading(true);
        setError(null);
        try {
          const response = await fetch(
            `/api/federated/${connectorId}/entities`
          );
          if (!response.ok) {
            throw new Error(
              `Failed to fetch entity schema: ${response.statusText}`
            );
          }
          const data = await response.json();
          setEntitySchema(data.entities);
        } catch (err) {
          setError(
            err instanceof Error ? err.message : "Failed to load entity schema"
          );
        } finally {
          setIsLoading(false);
        }
      };
      fetchEntitySchema();
    }
  }, [isOpen, connectorId]);

  const handleSave = () => {
    onSave(entities);
    onClose();
  };

  const handleEntityChange = (key: string, value: any) => {
    setEntities((prev) => ({
      ...prev,
      [key]: value,
    }));
  };

  if (!connectorSource) {
    return null;
  }

  return (
    <Dialog open={isOpen} onOpenChange={onClose}>
      <DialogContent className="max-w-md">
        <DialogHeader>
          <DialogTitle className="flex items-center gap-2">
            <SourceIcon
              sourceType={federatedSourceToRegularSource(connectorSource)}
              iconSize={20}
            />
            Configure {connectorName}
          </DialogTitle>
        </DialogHeader>

        <div className="space-y-4">
          {isLoading && (
            <div className="text-center py-4">
              <div className="animate-spin h-6 w-6 border-2 border-blue-500 border-t-transparent rounded-full mx-auto mb-2"></div>
              <p className="text-sm text-muted-foreground">
                Loading configuration...
              </p>
            </div>
          )}

          {error && (
            <div className="text-red-500 text-sm p-3 bg-red-50 rounded-md">
              {error}
            </div>
          )}

          {entitySchema && !isLoading && (
            <div className="space-y-3">
              <p className="text-sm text-muted-foreground">
                Configure which entities to include from this connector:
              </p>

              {Object.entries(entitySchema).map(
                ([key, field]: [string, any]) => (
                  <div key={key} className="space-y-2">
                    <Label className="text-sm font-medium">
                      {field.description || key}
                      {field.required && (
                        <span className="text-red-500 ml-1">*</span>
                      )}
                    </Label>

                    {field.type === "list" ? (
                      <div className="space-y-2">
                        <Input
                          type="text"
                          placeholder={
                            field.example || `Enter ${key} (comma-separated)`
                          }
                          value={
                            Array.isArray(entities[key])
                              ? entities[key].join(", ")
                              : ""
                          }
                          onChange={(e) => {
                            const value = e.target.value;
                            const list = value
                              ? value
                                  .split(",")
                                  .map((item) => item.trim())
                                  .filter(Boolean)
                              : [];
                            handleEntityChange(key, list);
                          }}
                        />
                        <p className="text-xs text-muted-foreground">
                          {field.description && field.description !== key
                            ? field.description
                            : `Enter ${key} separated by commas`}
                        </p>
                      </div>
                    ) : (
                      <div className="space-y-2">
                        <Input
                          type="text"
                          placeholder={field.example || `Enter ${key}`}
                          value={entities[key] || ""}
                          onChange={(e) =>
                            handleEntityChange(key, e.target.value)
                          }
                        />
                        {field.description && field.description !== key && (
                          <p className="text-xs text-muted-foreground">
                            {field.description}
                          </p>
                        )}
                      </div>
                    )}
                  </div>
                )
              )}
            </div>
          )}

          <div className="flex justify-end gap-2 pt-4">
            <Button secondary onClick={onClose}>
              Cancel
            </Button>
            <Button onClick={handleSave} disabled={isLoading}>
              Save Configuration
            </Button>
          </div>
        </div>
      </DialogContent>
    </Dialog>
  );
};

export const FederatedConnectorSelector = ({
  name,
  label,
@@ -233,19 +35,6 @@
}: FederatedConnectorSelectorProps) => {
  const [open, setOpen] = useState(false);
  const [searchQuery, setSearchQuery] = useState("");
  const [configDialogState, setConfigDialogState] = useState<{
    isOpen: boolean;
    connectorId: number | null;
    connectorName: string;
    connectorSource: ValidSources | null;
    currentEntities: Record<string, any>;
  }>({
    isOpen: false,
    connectorId: null,
    connectorName: "",
    connectorSource: null,
    currentEntities: {},
  });
  const dropdownRef = useRef<HTMLDivElement>(null);
  const inputRef = useRef<HTMLInputElement>(null);
@@ -307,36 +96,6 @@ export const FederatedConnectorSelector = ({
    );
  };

  const openConfigDialog = (connectorId: number) => {
    const connector = federatedConnectors.find((c) => c.id === connectorId);
    const config = selectedConfigs.find(
      (c) => c.federated_connector_id === connectorId
    );

    if (connector) {
      setConfigDialogState({
        isOpen: true,
        connectorId,
        connectorName: connector.name,
        connectorSource: connector.source,
        currentEntities: config?.entities || {},
      });
    }
  };

  const saveEntityConfig = (entities: Record<string, any>) => {
    const updatedConfigs = selectedConfigs.map((config) => {
      if (config.federated_connector_id === configDialogState.connectorId) {
        return {
          ...config,
          entities,
        };
      }
      return config;
    });
    onChange(updatedConfigs);
  };

  useEffect(() => {
    const handleClickOutside = (event: MouseEvent) => {
      if (
@@ -472,14 +231,6 @@
                  )}
                </div>
                <div className="flex items-center ml-2 gap-1">
                  <IconButton
                    internal
                    type="button"
                    tooltip="Configure entities"
                    aria-label="Configure entities"
                    onClick={() => openConfigDialog(connector.id)}
                    icon={SvgSettings}
                  />
                  <IconButton
                    internal
                    type="button"
@@ -500,18 +251,6 @@
        </div>
      )}

      <EntityConfigDialog
        connectorId={configDialogState.connectorId!}
        connectorName={configDialogState.connectorName}
        connectorSource={configDialogState.connectorSource}
        currentEntities={configDialogState.currentEntities}
        onSave={saveEntityConfig}
        onClose={() =>
          setConfigDialogState((prev) => ({ ...prev, isOpen: false }))
        }
        isOpen={configDialogState.isOpen}
      />

      {showError && (
        <ErrorMessage
          name={name}
@@ -93,19 +93,8 @@ export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
  process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;

-// Add support for custom URL protocols in markdown links
-export const ALLOWED_URL_PROTOCOLS = [
-  "http:",
-  "https:",
-  "mailto:",
-  "tel:",
-  "slack:",
-  "vscode:",
-  "file:",
-  "sms:",
-  "spotify:",
-  "zoommtg:",
-];
+// Restrict markdown links to safe protocols
+export const ALLOWED_URL_PROTOCOLS = ["http:", "https:", "mailto:"] as const;

export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;
export const MAX_STARTER_MESSAGES = 4;
@@ -11,28 +11,33 @@ export const truncateString = (str: string, maxLength: number) => {
};

/**
- * Custom URL transformer function for ReactMarkdown
- * Allows specific protocols to be used in markdown links
- * We use this with the urlTransform prop in ReactMarkdown
+ * Custom URL transformer function for ReactMarkdown.
+ * Only allows a small, safe set of protocols and strips everything else.
+ * Returning null removes the href attribute entirely.
 */
-export function transformLinkUri(href: string) {
-  if (!href) return href;
+export function transformLinkUri(href: string): string | null {
+  if (!href) return null;
+
+  const trimmedHref = href.trim();
+  if (!trimmedHref) return null;

-  const url = href.trim();
  try {
-    const parsedUrl = new URL(url);
-    if (
-      ALLOWED_URL_PROTOCOLS.some((protocol) =>
-        parsedUrl.protocol.startsWith(protocol)
-      )
-    ) {
-      return url;
+    const parsedUrl = new URL(trimmedHref);
+    const protocol = parsedUrl.protocol.toLowerCase();
+
+    if (ALLOWED_URL_PROTOCOLS.some((allowed) => allowed === protocol)) {
+      return trimmedHref;
    }
+
+    return null;
  } catch {
-    // If it's not a valid URL with protocol, return the original href
-    return href;
+    // Allow relative URLs, but drop anything that looks like a protocol-prefixed link
+    if (/^[a-zA-Z][a-zA-Z\d+.-]*:\S*/.test(trimmedHref)) {
+      return null;
+    }
+
+    return trimmedHref;
  }
-  return href;
}

export function isSubset(parent: string[], child: string[]): boolean {
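For reference, the same allow-list idea sketched in Python (the actual change is the TypeScript above; the helper below is illustrative only): parse the URL, accept only a small set of schemes, keep relative links, and drop everything else.

```python
# Illustrative Python sketch of the protocol allow-list behavior adopted above;
# names are hypothetical, not part of the Onyx codebase.
from urllib.parse import urlparse

ALLOWED_SCHEMES = {"http", "https", "mailto"}


def transform_link(href: str) -> str | None:
    href = (href or "").strip()
    if not href:
        return None
    scheme = urlparse(href).scheme.lower()
    if not scheme:
        return href  # relative URL: keep as-is
    return href if scheme in ALLOWED_SCHEMES else None


assert transform_link("https://example.com") == "https://example.com"
assert transform_link("javascript:alert(1)") is None  # disallowed scheme dropped
assert transform_link("/docs/page") == "/docs/page"   # relative link preserved
```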