Compare commits

...

5 Commits

Author SHA1 Message Date
Nikolas Garza
a1df56df13 chore: remove fed slack entities button on doc set edit page (#6385) 2025-12-02 16:50:12 -08:00
Nikolas Garza
90c206d9e1 fix: eager load persona in slack channel config (#6535) 2025-12-02 16:50:12 -08:00
きわみざむらい
5e1c89d673 fix: Add proper DISABLE_MODEL_SERVER environment variable support (#6468)
Co-authored-by: Jamison Lahman <jamison@lahman.dev>
2025-12-02 16:50:12 -08:00
Emerson Gomes
2239a58b1d Harden markdown link protocol handling (#6517) 2025-12-02 16:50:12 -08:00
Justin Tahara
825edba531 fix(feedback): API Endpoint fix (#6500) 2025-12-02 16:50:12 -08:00
12 changed files with 350 additions and 305 deletions

View File

@@ -2,6 +2,7 @@ from collections.abc import Sequence
from typing import Any
from sqlalchemy import select
from sqlalchemy.orm import joinedload
from sqlalchemy.orm import Session
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
@@ -269,7 +270,9 @@ def fetch_slack_channel_config_for_channel_or_default(
# attempt to find channel-specific config first
if channel_name is not None:
sc_config = db_session.scalar(
select(SlackChannelConfig).where(
select(SlackChannelConfig)
.options(joinedload(SlackChannelConfig.persona))
.where(
SlackChannelConfig.slack_bot_id == slack_bot_id,
SlackChannelConfig.channel_config["channel_name"].astext
== channel_name,
@@ -283,7 +286,9 @@ def fetch_slack_channel_config_for_channel_or_default(
# if none found, see if there is a default
default_sc = db_session.scalar(
select(SlackChannelConfig).where(
select(SlackChannelConfig)
.options(joinedload(SlackChannelConfig.persona))
.where(
SlackChannelConfig.slack_bot_id == slack_bot_id,
SlackChannelConfig.is_default == True, # noqa: E712
)

View File

@@ -1066,6 +1066,17 @@ class InformationContentClassificationModel:
self,
queries: list[str],
) -> list[ContentClassificationPrediction]:
if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
logger.info(
"DISABLE_MODEL_SERVER is set, returning default classifications"
)
return [
ContentClassificationPrediction(
predicted_label=1, content_boost_factor=1.0
)
for _ in queries
]
response = requests.post(self.content_server_endpoint, json=queries)
response.raise_for_status()
@@ -1092,6 +1103,14 @@ class ConnectorClassificationModel:
query: str,
available_connectors: list[str],
) -> list[str]:
# Check if model server is disabled
if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
logger.info(
"DISABLE_MODEL_SERVER is set, returning all available connectors"
)
# Return all available connectors when model server is disabled
return available_connectors
connector_classification_request = ConnectorClassificationRequest(
available_connectors=available_connectors,
query=query,

View File

@@ -117,12 +117,14 @@ def handle_regular_answer(
# This way slack flow always has a persona
persona = slack_channel_config.persona
if not persona:
logger.warning("No persona found for channel config, using default persona")
with get_session_with_current_tenant() as db_session:
persona = get_persona_by_id(DEFAULT_PERSONA_ID, user, db_session)
document_set_names = [
document_set.name for document_set in persona.document_sets
]
else:
logger.info(f"Using persona {persona.name} for channel config")
document_set_names = [
document_set.name for document_set in persona.document_sets
]

View File

@@ -1,3 +1,4 @@
import os
from functools import lru_cache
import requests
@@ -13,16 +14,14 @@ logger = setup_logger()
def _get_gpu_status_from_model_server(indexing: bool) -> bool:
if os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true":
logger.info("DISABLE_MODEL_SERVER is set, assuming no GPU available")
return False
if indexing:
model_server_url = f"{INDEXING_MODEL_SERVER_HOST}:{INDEXING_MODEL_SERVER_PORT}"
else:
model_server_url = f"{MODEL_SERVER_HOST}:{MODEL_SERVER_PORT}"
# If model server is disabled, return False (no GPU available)
if model_server_url in ["disabled", "disabled:9000"]:
logger.info("Model server is disabled, assuming no GPU available")
return False
if "http" not in model_server_url:
model_server_url = f"http://{model_server_url}"

View File

@@ -10,14 +10,24 @@ SLACK_CHANNEL_ID = "channel_id"
# Default to True (skip warmup) if not set, otherwise respect the value
SKIP_WARM_UP = os.environ.get("SKIP_WARM_UP", "true").lower() == "true"
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
# Check if model server is disabled
DISABLE_MODEL_SERVER = os.environ.get("DISABLE_MODEL_SERVER", "").lower() == "true"
# If model server is disabled, use "disabled" as host to trigger proper handling
if DISABLE_MODEL_SERVER:
MODEL_SERVER_HOST = "disabled"
MODEL_SERVER_ALLOWED_HOST = "disabled"
INDEXING_MODEL_SERVER_HOST = "disabled"
else:
MODEL_SERVER_HOST = os.environ.get("MODEL_SERVER_HOST") or "localhost"
MODEL_SERVER_ALLOWED_HOST = os.environ.get("MODEL_SERVER_HOST") or "0.0.0.0"
INDEXING_MODEL_SERVER_HOST = (
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
MODEL_SERVER_PORT = int(os.environ.get("MODEL_SERVER_PORT") or "9000")
# Model server for indexing should use a separate one to not allow indexing to introduce delay
# for inference
INDEXING_MODEL_SERVER_HOST = (
os.environ.get("INDEXING_MODEL_SERVER_HOST") or MODEL_SERVER_HOST
)
INDEXING_MODEL_SERVER_PORT = int(
os.environ.get("INDEXING_MODEL_SERVER_PORT") or MODEL_SERVER_PORT
)

View File

@@ -6,9 +6,11 @@ from unittest.mock import patch
from uuid import uuid4
# Set environment variables to disable model server for testing
os.environ["DISABLE_MODEL_SERVER"] = "true"
os.environ["MODEL_SERVER_HOST"] = "disabled"
os.environ["MODEL_SERVER_PORT"] = "9000"
from sqlalchemy import inspect
from sqlalchemy.orm import Session
from slack_sdk.errors import SlackApiError
@@ -760,3 +762,76 @@ def test_multiple_missing_scopes_resilience(
# Should still return available channels
assert len(result) == 1, f"Expected 1 channel, got {len(result)}"
assert result["C1234567890"]["name"] == "general"
def test_slack_channel_config_eager_loads_persona(db_session: Session) -> None:
"""Test that fetch_slack_channel_config_for_channel_or_default eagerly loads persona.
This prevents lazy loading failures when the session context changes later
in the request handling flow (e.g., in handle_regular_answer).
"""
from onyx.db.slack_channel_config import (
fetch_slack_channel_config_for_channel_or_default,
)
unique_id = str(uuid4())[:8]
# Create a persona (using same fields as _create_test_persona_with_slack_config)
persona = Persona(
name=f"test_eager_load_persona_{unique_id}",
description="Test persona for eager loading test",
chunks_above=0,
chunks_below=0,
llm_relevance_filter=True,
llm_filter_extraction=True,
recency_bias=RecencyBiasSetting.AUTO,
system_prompt="You are a helpful assistant.",
task_prompt="Answer the user's question.",
)
db_session.add(persona)
db_session.flush()
# Create a slack bot
slack_bot = SlackBot(
name=f"Test Bot {unique_id}",
bot_token=f"xoxb-test-{unique_id}",
app_token=f"xapp-test-{unique_id}",
enabled=True,
)
db_session.add(slack_bot)
db_session.flush()
# Create slack channel config with persona
channel_name = f"test-channel-{unique_id}"
slack_channel_config = SlackChannelConfig(
slack_bot_id=slack_bot.id,
persona_id=persona.id,
channel_config={"channel_name": channel_name, "disabled": False},
enable_auto_filters=False,
is_default=False,
)
db_session.add(slack_channel_config)
db_session.commit()
# Fetch the config using the function under test
fetched_config = fetch_slack_channel_config_for_channel_or_default(
db_session=db_session,
slack_bot_id=slack_bot.id,
channel_name=channel_name,
)
assert fetched_config is not None, "Should find the channel config"
# Check that persona relationship is already loaded (not pending lazy load)
insp = inspect(fetched_config)
assert insp is not None, "Should be able to inspect the config"
assert "persona" not in insp.unloaded, (
"Persona should be eagerly loaded, not pending lazy load. "
"This is required to prevent fallback to default persona when "
"session context changes in handle_regular_answer."
)
# Verify the persona is correct
assert fetched_config.persona is not None, "Persona should not be None"
assert fetched_config.persona.id == persona.id, "Should load the correct persona"
assert fetched_config.persona.name == persona.name

View File

@@ -1,3 +1,4 @@
import os
from collections.abc import AsyncGenerator
from typing import List
from unittest.mock import AsyncMock
@@ -9,6 +10,12 @@ from httpx import AsyncClient
from litellm.exceptions import RateLimitError
from onyx.natural_language_processing.search_nlp_models import CloudEmbedding
from onyx.natural_language_processing.search_nlp_models import (
ConnectorClassificationModel,
)
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
@@ -81,3 +88,95 @@ async def test_rate_limit_handling() -> None:
model_name="fake-model",
text_type=EmbedTextType.QUERY,
)
class TestInformationContentClassificationModel:
"""Test cases for InformationContentClassificationModel with DISABLE_MODEL_SERVER"""
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
def test_predict_with_disable_model_server(self) -> None:
"""Test that predict returns default classifications when DISABLE_MODEL_SERVER is true"""
model = InformationContentClassificationModel()
queries = ["What is AI?", "How does Python work?"]
results = model.predict(queries)
assert len(results) == 2
for result in results:
assert result.predicted_label == 1 # 1 indicates informational content
assert result.content_boost_factor == 1.0 # Default boost factor
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
@patch("requests.post")
def test_predict_with_model_server_enabled(self, mock_post: MagicMock) -> None:
"""Test that predict makes request when DISABLE_MODEL_SERVER is false"""
mock_response = MagicMock()
mock_response.json.return_value = [
{"predicted_label": 1, "content_boost_factor": 1.0},
{"predicted_label": 0, "content_boost_factor": 0.8},
]
mock_post.return_value = mock_response
model = InformationContentClassificationModel()
queries = ["test1", "test2"]
results = model.predict(queries)
assert len(results) == 2
assert results[0].predicted_label == 1
assert results[0].content_boost_factor == 1.0
assert results[1].predicted_label == 0
assert results[1].content_boost_factor == 0.8
mock_post.assert_called_once()
class TestConnectorClassificationModel:
"""Test cases for ConnectorClassificationModel with DISABLE_MODEL_SERVER"""
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
def test_predict_with_disable_model_server(self) -> None:
"""Test that predict returns all connectors when DISABLE_MODEL_SERVER is true"""
model = ConnectorClassificationModel()
query = "Search for documentation"
available_connectors = ["confluence", "slack", "github"]
results = model.predict(query, available_connectors)
assert results == available_connectors
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
@patch("requests.post")
def test_predict_with_model_server_enabled(self, mock_post: MagicMock) -> None:
"""Test that predict makes request when DISABLE_MODEL_SERVER is false"""
mock_response = MagicMock()
mock_response.json.return_value = {"connectors": ["confluence", "github"]}
mock_post.return_value = mock_response
model = ConnectorClassificationModel()
query = "Search for documentation"
available_connectors = ["confluence", "slack", "github"]
results = model.predict(query, available_connectors)
assert results == ["confluence", "github"]
mock_post.assert_called_once()
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "1"})
@patch("requests.post")
def test_predict_with_disable_model_server_numeric(
self, mock_post: MagicMock
) -> None:
"""Test that predict makes request when DISABLE_MODEL_SERVER is 1 (not 'true')"""
# "1" should NOT trigger disable (only "true" should)
mock_response = MagicMock()
mock_response.json.return_value = {"connectors": ["github"]}
mock_post.return_value = mock_response
model = ConnectorClassificationModel()
query = "Find issues"
available_connectors = ["jira", "github"]
results = model.predict(query, available_connectors)
assert results == ["github"]
mock_post.assert_called_once()

View File

@@ -0,0 +1,103 @@
"""
Test cases for onyx/utils/gpu_utils.py with DISABLE_MODEL_SERVER environment variable
"""
import os
from unittest import TestCase
from unittest.mock import MagicMock
from unittest.mock import patch
import requests
from onyx.utils.gpu_utils import _get_gpu_status_from_model_server
class TestGPUUtils(TestCase):
"""Test cases for GPU utilities with DISABLE_MODEL_SERVER support"""
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
def test_disable_model_server_true(self) -> None:
"""Test that GPU status returns False when DISABLE_MODEL_SERVER is true"""
result = _get_gpu_status_from_model_server(indexing=False)
assert result is False
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "True"})
def test_disable_model_server_capital_true(self) -> None:
"""Test that GPU status returns False when DISABLE_MODEL_SERVER is True (capital)"""
# "True" WILL trigger disable because .lower() is called
result = _get_gpu_status_from_model_server(indexing=False)
assert result is False
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "1"})
@patch("requests.get")
def test_disable_model_server_one(self, mock_get: MagicMock) -> None:
"""Test that GPU status makes request when DISABLE_MODEL_SERVER is 1"""
# "1" should NOT trigger disable (only "true" should)
mock_response = MagicMock()
mock_response.json.return_value = {"gpu_available": True}
mock_get.return_value = mock_response
result = _get_gpu_status_from_model_server(indexing=False)
assert result is True
mock_get.assert_called_once()
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "yes"})
@patch("requests.get")
def test_disable_model_server_yes(self, mock_get: MagicMock) -> None:
"""Test that GPU status makes request when DISABLE_MODEL_SERVER is yes"""
# "yes" should NOT trigger disable (only "true" should)
mock_response = MagicMock()
mock_response.json.return_value = {"gpu_available": False}
mock_get.return_value = mock_response
result = _get_gpu_status_from_model_server(indexing=True)
assert result is False
mock_get.assert_called_once()
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
@patch("requests.get")
def test_disable_model_server_false(self, mock_get: MagicMock) -> None:
"""Test that GPU status makes request when DISABLE_MODEL_SERVER is false"""
mock_response = MagicMock()
mock_response.json.return_value = {"gpu_available": True}
mock_get.return_value = mock_response
result = _get_gpu_status_from_model_server(indexing=True)
assert result is True
mock_get.assert_called_once()
@patch.dict(os.environ, {}, clear=True)
@patch("requests.get")
def test_disable_model_server_not_set(self, mock_get: MagicMock) -> None:
"""Test that GPU status makes request when DISABLE_MODEL_SERVER is not set"""
mock_response = MagicMock()
mock_response.json.return_value = {"gpu_available": False}
mock_get.return_value = mock_response
result = _get_gpu_status_from_model_server(indexing=False)
assert result is False
mock_get.assert_called_once()
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
def test_disabled_host_fallback(self) -> None:
"""Test that disabled host is handled correctly via environment variable"""
result = _get_gpu_status_from_model_server(indexing=True)
assert result is False
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "false"})
@patch("requests.get")
def test_request_exception_handling(self, mock_get: MagicMock) -> None:
"""Test that exceptions are properly raised when GPU status request fails"""
mock_get.side_effect = requests.RequestException("Connection error")
with self.assertRaises(requests.RequestException):
_get_gpu_status_from_model_server(indexing=False)
@patch.dict(os.environ, {"DISABLE_MODEL_SERVER": "true"})
@patch("requests.get")
def test_gpu_status_request_with_disable(self, mock_get: MagicMock) -> None:
"""Test that no request is made when DISABLE_MODEL_SERVER is true"""
result = _get_gpu_status_from_model_server(indexing=True)
assert result is False
# Verify that no HTTP request was made
mock_get.assert_not_called()

View File

@@ -101,7 +101,7 @@ export default function AIMessage({
// Toggle logic
if (currentFeedback === clickedFeedback) {
// Clicking same button - remove feedback
await handleFeedbackChange(nodeId, null);
await handleFeedbackChange(messageId, null);
}
// Clicking like (will automatically clear dislike if it was active).
@@ -113,12 +113,12 @@ export default function AIMessage({
// Open modal for positive feedback
setFeedbackModalProps({
feedbackType: "like",
messageId: nodeId,
messageId,
});
modal.toggle(true);
} else {
// No modal needed - just submit like (this replaces any existing feedback)
await handleFeedbackChange(nodeId, "like");
await handleFeedbackChange(messageId, "like");
}
}

View File

@@ -3,24 +3,14 @@ import {
FederatedConnectorDetail,
FederatedConnectorConfig,
federatedSourceToRegularSource,
ValidSources,
} from "@/lib/types";
import { SourceIcon } from "@/components/SourceIcon";
import SvgX from "@/icons/x";
import SvgSettings from "@/icons/settings";
import { Label } from "@/components/ui/label";
import { ErrorMessage } from "formik";
import Text from "@/refresh-components/texts/Text";
import IconButton from "@/refresh-components/buttons/IconButton";
import { Input } from "@/components/ui/input";
import InputTypeIn from "@/refresh-components/inputs/InputTypeIn";
import {
Dialog,
DialogContent,
DialogHeader,
DialogTitle,
} from "@/components/ui/dialog";
import Button from "@/refresh-components/buttons/Button";
interface FederatedConnectorSelectorProps {
name: string;
@@ -33,194 +23,6 @@ interface FederatedConnectorSelectorProps {
showError?: boolean;
}
interface EntityConfigDialogProps {
connectorId: number;
connectorName: string;
connectorSource: ValidSources | null;
currentEntities: Record<string, any>;
onSave: (entities: Record<string, any>) => void;
onClose: () => void;
isOpen: boolean;
}
const EntityConfigDialog = ({
connectorId,
connectorName,
connectorSource,
currentEntities,
onSave,
onClose,
isOpen,
}: EntityConfigDialogProps) => {
const [entities, setEntities] =
useState<Record<string, any>>(currentEntities);
const [entitySchema, setEntitySchema] = useState<Record<string, any> | null>(
null
);
const [isLoading, setIsLoading] = useState(false);
const [error, setError] = useState<string | null>(null);
useEffect(() => {
if (isOpen) {
setEntities(currentEntities || {});
}
}, [currentEntities, isOpen]);
useEffect(() => {
if (isOpen && connectorId) {
const fetchEntitySchema = async () => {
setIsLoading(true);
setError(null);
try {
const response = await fetch(
`/api/federated/${connectorId}/entities`
);
if (!response.ok) {
throw new Error(
`Failed to fetch entity schema: ${response.statusText}`
);
}
const data = await response.json();
setEntitySchema(data.entities);
} catch (err) {
setError(
err instanceof Error ? err.message : "Failed to load entity schema"
);
} finally {
setIsLoading(false);
}
};
fetchEntitySchema();
}
}, [isOpen, connectorId]);
const handleSave = () => {
onSave(entities);
onClose();
};
const handleEntityChange = (key: string, value: any) => {
setEntities((prev) => ({
...prev,
[key]: value,
}));
};
if (!connectorSource) {
return null;
}
return (
<Dialog open={isOpen} onOpenChange={onClose}>
<DialogContent className="max-w-md">
<DialogHeader>
<DialogTitle className="flex items-center gap-2">
<SourceIcon
sourceType={federatedSourceToRegularSource(connectorSource)}
iconSize={20}
/>
Configure {connectorName}
</DialogTitle>
</DialogHeader>
<div className="space-y-4">
{isLoading && (
<div className="text-center py-4">
<div className="animate-spin h-6 w-6 border-2 border-blue-500 border-t-transparent rounded-full mx-auto mb-2"></div>
<p className="text-sm text-muted-foreground">
Loading configuration...
</p>
</div>
)}
{error && (
<div className="text-red-500 text-sm p-3 bg-red-50 rounded-md">
{error}
</div>
)}
{entitySchema && !isLoading && (
<div className="space-y-3">
<p className="text-sm text-muted-foreground">
Configure which entities to include from this connector:
</p>
{Object.entries(entitySchema).map(
([key, field]: [string, any]) => (
<div key={key} className="space-y-2">
<Label className="text-sm font-medium">
{field.description || key}
{field.required && (
<span className="text-red-500 ml-1">*</span>
)}
</Label>
{field.type === "list" ? (
<div className="space-y-2">
<Input
type="text"
placeholder={
field.example || `Enter ${key} (comma-separated)`
}
value={
Array.isArray(entities[key])
? entities[key].join(", ")
: ""
}
onChange={(e) => {
const value = e.target.value;
const list = value
? value
.split(",")
.map((item) => item.trim())
.filter(Boolean)
: [];
handleEntityChange(key, list);
}}
/>
<p className="text-xs text-muted-foreground">
{field.description && field.description !== key
? field.description
: `Enter ${key} separated by commas`}
</p>
</div>
) : (
<div className="space-y-2">
<Input
type="text"
placeholder={field.example || `Enter ${key}`}
value={entities[key] || ""}
onChange={(e) =>
handleEntityChange(key, e.target.value)
}
/>
{field.description && field.description !== key && (
<p className="text-xs text-muted-foreground">
{field.description}
</p>
)}
</div>
)}
</div>
)
)}
</div>
)}
<div className="flex justify-end gap-2 pt-4">
<Button secondary onClick={onClose}>
Cancel
</Button>
<Button onClick={handleSave} disabled={isLoading}>
Save Configuration
</Button>
</div>
</div>
</DialogContent>
</Dialog>
);
};
export const FederatedConnectorSelector = ({
name,
label,
@@ -233,19 +35,6 @@ export const FederatedConnectorSelector = ({
}: FederatedConnectorSelectorProps) => {
const [open, setOpen] = useState(false);
const [searchQuery, setSearchQuery] = useState("");
const [configDialogState, setConfigDialogState] = useState<{
isOpen: boolean;
connectorId: number | null;
connectorName: string;
connectorSource: ValidSources | null;
currentEntities: Record<string, any>;
}>({
isOpen: false,
connectorId: null,
connectorName: "",
connectorSource: null,
currentEntities: {},
});
const dropdownRef = useRef<HTMLDivElement>(null);
const inputRef = useRef<HTMLInputElement>(null);
@@ -307,36 +96,6 @@ export const FederatedConnectorSelector = ({
);
};
const openConfigDialog = (connectorId: number) => {
const connector = federatedConnectors.find((c) => c.id === connectorId);
const config = selectedConfigs.find(
(c) => c.federated_connector_id === connectorId
);
if (connector) {
setConfigDialogState({
isOpen: true,
connectorId,
connectorName: connector.name,
connectorSource: connector.source,
currentEntities: config?.entities || {},
});
}
};
const saveEntityConfig = (entities: Record<string, any>) => {
const updatedConfigs = selectedConfigs.map((config) => {
if (config.federated_connector_id === configDialogState.connectorId) {
return {
...config,
entities,
};
}
return config;
});
onChange(updatedConfigs);
};
useEffect(() => {
const handleClickOutside = (event: MouseEvent) => {
if (
@@ -472,14 +231,6 @@ export const FederatedConnectorSelector = ({
)}
</div>
<div className="flex items-center ml-2 gap-1">
<IconButton
internal
type="button"
tooltip="Configure entities"
aria-label="Configure entities"
onClick={() => openConfigDialog(connector.id)}
icon={SvgSettings}
/>
<IconButton
internal
type="button"
@@ -500,18 +251,6 @@ export const FederatedConnectorSelector = ({
</div>
)}
<EntityConfigDialog
connectorId={configDialogState.connectorId!}
connectorName={configDialogState.connectorName}
connectorSource={configDialogState.connectorSource}
currentEntities={configDialogState.currentEntities}
onSave={saveEntityConfig}
onClose={() =>
setConfigDialogState((prev) => ({ ...prev, isOpen: false }))
}
isOpen={configDialogState.isOpen}
/>
{showError && (
<ErrorMessage
name={name}

View File

@@ -93,19 +93,8 @@ export const NEXT_PUBLIC_INCLUDE_ERROR_POPUP_SUPPORT_LINK =
export const NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY =
process.env.NEXT_PUBLIC_STRIPE_PUBLISHABLE_KEY;
// Add support for custom URL protocols in markdown links
export const ALLOWED_URL_PROTOCOLS = [
"http:",
"https:",
"mailto:",
"tel:",
"slack:",
"vscode:",
"file:",
"sms:",
"spotify:",
"zoommtg:",
];
// Restrict markdown links to safe protocols
export const ALLOWED_URL_PROTOCOLS = ["http:", "https:", "mailto:"] as const;
export const MAX_CHARACTERS_PERSONA_DESCRIPTION = 5000000;
export const MAX_STARTER_MESSAGES = 4;

View File

@@ -11,28 +11,33 @@ export const truncateString = (str: string, maxLength: number) => {
};
/**
* Custom URL transformer function for ReactMarkdown
* Allows specific protocols to be used in markdown links
* We use this with the urlTransform prop in ReactMarkdown
* Custom URL transformer function for ReactMarkdown.
* Only allows a small, safe set of protocols and strips everything else.
* Returning null removes the href attribute entirely.
*/
export function transformLinkUri(href: string) {
if (!href) return href;
export function transformLinkUri(href: string): string | null {
if (!href) return null;
const trimmedHref = href.trim();
if (!trimmedHref) return null;
const url = href.trim();
try {
const parsedUrl = new URL(url);
if (
ALLOWED_URL_PROTOCOLS.some((protocol) =>
parsedUrl.protocol.startsWith(protocol)
)
) {
return url;
const parsedUrl = new URL(trimmedHref);
const protocol = parsedUrl.protocol.toLowerCase();
if (ALLOWED_URL_PROTOCOLS.some((allowed) => allowed === protocol)) {
return trimmedHref;
}
return null;
} catch {
// If it's not a valid URL with protocol, return the original href
return href;
// Allow relative URLs, but drop anything that looks like a protocol-prefixed link
if (/^[a-zA-Z][a-zA-Z\d+.-]*:\S*/.test(trimmedHref)) {
return null;
}
return trimmedHref;
}
return href;
}
export function isSubset(parent: string[], child: string[]): boolean {