Compare commits

...

11 Commits

131 changed files with 3170 additions and 737 deletions

View File

@@ -40,6 +40,9 @@ jobs:
- name: Generate OpenAPI schema and Python client
shell: bash
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
env:
LICENSE_ENFORCEMENT_ENABLED: "false"
run: |
ods openapi all

View File

@@ -302,6 +302,8 @@ jobs:
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
POSTGRES_POOL_PRE_PING=true
POSTGRES_USE_NULL_POOL=true
@@ -478,6 +480,7 @@ jobs:
run: |
cd deployment/docker_compose
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
LICENSE_ENFORCEMENT_ENABLED=false \
MULTI_TENANT=true \
AUTH_TYPE=cloud \
REQUIRE_EMAIL_VERIFICATION=false \

View File

@@ -291,6 +291,8 @@ jobs:
cat <<EOF > deployment/docker_compose/.env
COMPOSE_PROFILES=s3-filestore
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED=false
AUTH_TYPE=basic
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
EXA_API_KEY=${EXA_API_KEY_VALUE}

View File

@@ -42,6 +42,9 @@ jobs:
- name: Generate OpenAPI schema and Python client
shell: bash
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
env:
LICENSE_ENFORCEMENT_ENABLED: "false"
run: |
ods openapi all

View File

@@ -27,6 +27,8 @@ jobs:
PYTHONPATH: ./backend
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
DISABLE_TELEMETRY: "true"
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
LICENSE_ENFORCEMENT_ENABLED: "false"
steps:
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2

View File

@@ -0,0 +1,27 @@
"""add_user_preferences

Adds a nullable ``user_preferences`` text column to the ``user`` table
(free-form preference text surfaced to the LLM system prompt).

Revision ID: 175ea04c7087
Revises: d56ffa94ca32
Create Date: 2026-02-04 18:16:24.830873
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "175ea04c7087"
down_revision = "d56ffa94ca32"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Nullable so existing user rows need no backfill.
    op.add_column(
        "user",
        sa.Column("user_preferences", sa.Text(), nullable=True),
    )


def downgrade() -> None:
    # Reverses upgrade(); drops the column (and any stored preferences).
    op.drop_column("user", "user_preferences")

View File

@@ -134,7 +134,7 @@ GATED_TENANTS_KEY = "gated_tenants"
# License enforcement - when True, blocks API access for gated/expired licenses
LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
)
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints

View File

@@ -50,7 +50,12 @@ def github_doc_sync(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
logger.info("GitHub connector credentials loaded successfully")
if not github_connector.github_client:

View File

@@ -18,7 +18,12 @@ def github_group_sync(
github_connector: GithubConnector = GithubConnector(
**cc_pair.connector.connector_specific_config
)
github_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
github_connector.load_credentials(credential_json)
if not github_connector.github_client:
raise ValueError("github_client is required")

View File

@@ -50,7 +50,12 @@ def gmail_doc_sync(
already populated.
"""
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
gmail_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
gmail_connector.load_credentials(credential_json)
slim_doc_generator = _get_slim_doc_generator(
cc_pair, gmail_connector, callback=callback

View File

@@ -295,7 +295,12 @@ def gdrive_doc_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)

View File

@@ -391,7 +391,12 @@ def gdrive_group_sync(
google_drive_connector = GoogleDriveConnector(
**cc_pair.connector.connector_specific_config
)
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
google_drive_connector.load_credentials(credential_json)
admin_service = get_admin_service(
google_drive_connector.creds, google_drive_connector.primary_admin_email
)

View File

@@ -24,7 +24,12 @@ def jira_doc_sync(
jira_connector = JiraConnector(
**cc_pair.connector.connector_specific_config,
)
jira_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -119,8 +119,13 @@ def jira_group_sync(
if not jira_base_url:
raise ValueError("No jira_base_url found in connector config")
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
jira_client = build_jira_client(
credentials=cc_pair.credential.credential_json,
credentials=credential_json,
jira_base=jira_base_url,
scoped_token=scoped_token,
)

View File

@@ -30,7 +30,11 @@ def get_any_salesforce_client_for_doc_id(
if _ANY_SALESFORCE_CLIENT is None:
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
first_cc_pair = cc_pairs[0]
credential_json = first_cc_pair.credential.credential_json
credential_json = (
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
if first_cc_pair.credential.credential_json
else {}
)
_ANY_SALESFORCE_CLIENT = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],
@@ -158,7 +162,11 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
)
if cc_pair is None:
raise ValueError(f"CC pair {cc_pair_id} not found")
credential_json = cc_pair.credential.credential_json
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
username=credential_json["sf_username"],
password=credential_json["sf_password"],

View File

@@ -24,7 +24,12 @@ def sharepoint_doc_sync(
sharepoint_connector = SharepointConnector(
**cc_pair.connector.connector_specific_config,
)
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
sharepoint_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -25,7 +25,12 @@ def sharepoint_group_sync(
# Create SharePoint connector instance and load credentials
connector = SharepointConnector(**connector_config)
connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
connector.load_credentials(credential_json)
if not connector.msal_app:
raise RuntimeError("MSAL app not initialized in connector")

View File

@@ -151,9 +151,14 @@ def slack_doc_sync(
tenant_id = get_current_tenant_id()
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
cc_pair.credential.credential_json["slack_bot_token"],
credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -63,9 +63,14 @@ def slack_group_sync(
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
r = get_redis_client(tenant_id=tenant_id)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
slack_client = SlackConnector.make_slack_web_client(
provider.get_provider_key(),
cc_pair.credential.credential_json["slack_bot_token"],
credential_json["slack_bot_token"],
SlackConnector.MAX_RETRIES,
r,
)

View File

@@ -25,7 +25,12 @@ def teams_doc_sync(
teams_connector = TeamsConnector(
**cc_pair.connector.connector_specific_config,
)
teams_connector.load_credentials(cc_pair.credential.credential_json)
credential_json = (
cc_pair.credential.credential_json.get_value(apply_mask=False)
if cc_pair.credential.credential_json
else {}
)
teams_connector.load_credentials(credential_json)
yield from generic_doc_sync(
cc_pair=cc_pair,

View File

@@ -270,7 +270,11 @@ def confluence_oauth_accessible_resources(
if not credential:
raise HTTPException(400, f"Credential {credential_id} not found.")
credential_dict = credential.credential_json
credential_dict = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
access_token = credential_dict["confluence_access_token"]
try:
@@ -337,7 +341,12 @@ def confluence_oauth_finalize(
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
)
new_credential_json: dict[str, Any] = dict(credential.credential_json)
existing_credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
new_credential_json: dict[str, Any] = dict(existing_credential_json)
new_credential_json["cloud_id"] = cloud_id
new_credential_json["cloud_name"] = cloud_name
new_credential_json["wiki_base"] = cloud_url

View File

@@ -11,6 +11,7 @@ from onyx.db.models import OAuthUserToken
from onyx.db.oauth_config import get_user_oauth_token
from onyx.db.oauth_config import upsert_user_oauth_token
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
logger = setup_logger()
@@ -33,7 +34,10 @@ class OAuthTokenManager:
if not user_token:
return None
token_data = user_token.token_data
if not user_token.token_data:
return None
token_data = self._unwrap_token_data(user_token.token_data)
# Check if token is expired
if OAuthTokenManager.is_token_expired(token_data):
@@ -51,7 +55,10 @@ class OAuthTokenManager:
def refresh_token(self, user_token: OAuthUserToken) -> str:
"""Refresh access token using refresh token"""
token_data = user_token.token_data
if not user_token.token_data:
raise ValueError("No token data available for refresh")
token_data = self._unwrap_token_data(user_token.token_data)
response = requests.post(
self.oauth_config.token_url,
@@ -153,3 +160,11 @@ class OAuthTokenManager:
separator = "&" if "?" in oauth_config.authorization_url else "?"
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
@staticmethod
def _unwrap_token_data(
    token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
) -> dict[str, Any]:
    """Return the raw token dict, unwrapping a SensitiveValue if present.

    token_data may arrive either as a plain dict (legacy/raw callers) or
    wrapped in SensitiveValue (as produced by the encrypted column type);
    in the wrapped case the unmasked value is extracted.
    """
    if isinstance(token_data, SensitiveValue):
        # apply_mask=False: internal use needs the real token, not a mask.
        return token_data.get_value(apply_mask=False)
    return token_data

View File

@@ -27,6 +27,7 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.db.models import Persona
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMUserIdentity
@@ -370,7 +371,7 @@ def run_llm_loop(
custom_agent_prompt: str | None,
project_files: ExtractedProjectFiles,
persona: Persona | None,
memories: list[str] | None,
user_memory_context: UserMemoryContext | None,
llm: LLM,
token_counter: Callable[[str], int],
db_session: Session,
@@ -483,7 +484,7 @@ def run_llm_loop(
system_prompt_str = build_system_prompt(
base_system_prompt=default_base_system_prompt,
datetime_aware=persona.datetime_aware if persona else True,
memories=memories,
user_memory_context=user_memory_context,
tools=tools,
should_cite_documents=should_cite_documents
or always_cite_documents,
@@ -637,7 +638,7 @@ def run_llm_loop(
tool_calls=tool_calls,
tools=final_tools,
message_history=truncated_message_history,
memories=memories,
user_memory_context=user_memory_context,
user_info=None, # TODO, this is part of memories right now, might want to separate it out
citation_mapping=citation_mapping,
next_citation_num=citation_processor.get_next_citation_number(),

View File

@@ -471,7 +471,7 @@ def handle_stream_message_objects(
# Filter chat_history to only messages after the cutoff
chat_history = [m for m in chat_history if m.id > cutoff_id]
memories = get_memories(user, db_session)
user_memory_context = get_memories(user, db_session)
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
@@ -480,7 +480,7 @@ def handle_stream_message_objects(
persona_system_prompt=custom_agent_prompt or "",
token_counter=token_counter,
files=new_msg_req.file_descriptors,
memories=memories,
user_memory_context=user_memory_context,
)
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
@@ -667,7 +667,7 @@ def handle_stream_message_objects(
custom_agent_prompt=custom_agent_prompt,
project_files=extracted_project_files,
persona=persona,
memories=memories,
user_memory_context=user_memory_context,
llm=llm,
token_counter=token_counter,
db_session=db_session,

View File

@@ -4,6 +4,7 @@ from uuid import UUID
from sqlalchemy.orm import Session
from onyx.db.memory import UserMemoryContext
from onyx.db.persona import get_default_behavior_persona
from onyx.db.user_file import calculate_user_files_token_count
from onyx.file_store.models import FileDescriptor
@@ -12,7 +13,6 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
from onyx.prompts.chat_prompts import USER_INFO_HEADER
from onyx.prompts.prompt_utils import get_company_context
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
@@ -25,6 +25,7 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
from onyx.prompts.user_info import USER_INFORMATION_HEADER
from onyx.tools.interface import Tool
from onyx.tools.tool_implementations.images.image_generation_tool import (
ImageGenerationTool,
@@ -52,7 +53,7 @@ def calculate_reserved_tokens(
persona_system_prompt: str,
token_counter: Callable[[str], int],
files: list[FileDescriptor] | None = None,
memories: list[str] | None = None,
user_memory_context: UserMemoryContext | None = None,
) -> int:
"""
Calculate reserved token count for system prompt and user files.
@@ -66,7 +67,7 @@ def calculate_reserved_tokens(
persona_system_prompt: Custom agent system prompt (can be empty string)
token_counter: Function that counts tokens in text
files: List of file descriptors from the chat message (optional)
memories: List of memory strings (optional)
user_memory_context: User memory context (optional)
Returns:
Total reserved token count
@@ -77,7 +78,7 @@ def calculate_reserved_tokens(
fake_system_prompt = build_system_prompt(
base_system_prompt=base_system_prompt,
datetime_aware=True,
memories=memories,
user_memory_context=user_memory_context,
tools=None,
should_cite_documents=True,
include_all_guidance=True,
@@ -133,7 +134,7 @@ def build_reminder_message(
def build_system_prompt(
base_system_prompt: str,
datetime_aware: bool = False,
memories: list[str] | None = None,
user_memory_context: UserMemoryContext | None = None,
tools: Sequence[Tool] | None = None,
should_cite_documents: bool = False,
include_all_guidance: bool = False,
@@ -157,14 +158,15 @@ def build_system_prompt(
)
company_context = get_company_context()
if company_context or memories:
system_prompt += USER_INFO_HEADER
formatted_user_context = (
user_memory_context.as_formatted_prompt() if user_memory_context else ""
)
if company_context or formatted_user_context:
system_prompt += USER_INFORMATION_HEADER
if company_context:
system_prompt += company_context
if memories:
system_prompt += "\n".join(
"- " + memory.strip() for memory in memories if memory.strip()
)
if formatted_user_context:
system_prompt += formatted_user_context
# Append citation guidance after company context if placeholder was not present
# This maintains backward compatibility and ensures citations are always enforced when needed

View File

@@ -65,7 +65,9 @@ class OnyxDBCredentialsProvider(
f"No credential found: credential={self._credential_id}"
)
return credential.credential_json
if credential.credential_json is None:
return {}
return credential.credential_json.get_value(apply_mask=False)
def set_credentials(self, credential_json: dict[str, Any]) -> None:
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
@@ -81,7 +83,7 @@ class OnyxDBCredentialsProvider(
f"No credential found: credential={self._credential_id}"
)
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
except Exception:
db_session.rollback()

View File

@@ -118,7 +118,12 @@ def instantiate_connector(
)
connector.set_credentials_provider(provider)
else:
new_credentials = connector.load_credentials(credential.credential_json)
credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
new_credentials = connector.load_credentials(credential_json)
if new_credentials is not None:
backend_update_credential_json(credential, new_credentials, db_session)

View File

@@ -270,6 +270,8 @@ def create_credential(
)
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -297,14 +299,21 @@ def alter_credential(
credential.name = name
# Assign a new dictionary to credential.credential_json
credential.credential_json = {
**credential.credential_json,
# Get existing credential_json and merge with new values
existing_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
credential.credential_json = { # type: ignore[assignment]
**existing_json,
**credential_json,
}
credential.user_id = user.id
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -318,10 +327,12 @@ def update_credential(
if credential is None:
return None
credential.credential_json = credential_data.credential_json
credential.user_id = user.id
credential.credential_json = credential_data.credential_json # type: ignore[assignment]
credential.user_id = user.id if user is not None else None
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -335,8 +346,10 @@ def update_credential_json(
if credential is None:
return None
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
db_session.expire(credential)
return credential
@@ -346,7 +359,7 @@ def backend_update_credential_json(
db_session: Session,
) -> None:
"""This should not be used in any flows involving the frontend or users"""
credential.credential_json = credential_json
credential.credential_json = credential_json # type: ignore[assignment]
db_session.commit()
@@ -441,7 +454,12 @@ def create_initial_public_credential(db_session: Session) -> None:
)
if first_credential is not None:
if first_credential.credential_json != {} or first_credential.user is not None:
credential_json_value = (
first_credential.credential_json.get_value(apply_mask=False)
if first_credential.credential_json
else {}
)
if credential_json_value != {} or first_credential.user is not None:
raise ValueError(error_msg)
return
@@ -477,8 +495,13 @@ def delete_service_account_credentials(
) -> None:
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
for credential in credentials:
credential_json = (
credential.credential_json.get_value(apply_mask=False)
if credential.credential_json
else {}
)
if (
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
and credential.source == source
):
db_session.delete(credential)

View File

@@ -111,7 +111,7 @@ def update_federated_connector_oauth_token(
if existing_token:
# Update existing token
existing_token.token = token
existing_token.token = token # type: ignore[assignment]
existing_token.expires_at = expires_at
db_session.commit()
return existing_token
@@ -267,7 +267,13 @@ def update_federated_connector(
# Use provided credentials if updating them, otherwise use existing credentials
# This is needed to instantiate the connector for config validation when only config is being updated
creds_to_use = (
credentials if credentials is not None else federated_connector.credentials
credentials
if credentials is not None
else (
federated_connector.credentials.get_value(apply_mask=False)
if federated_connector.credentials
else {}
)
)
if credentials is not None:
@@ -278,7 +284,7 @@ def update_federated_connector(
raise ValueError(
f"Invalid credentials for federated connector source: {federated_connector.source}"
)
federated_connector.credentials = credentials
federated_connector.credentials = credentials # type: ignore[assignment]
if config is not None:
# Validate config using connector-specific validation

View File

@@ -232,7 +232,8 @@ def upsert_llm_provider(
custom_config = custom_config or None
existing_llm_provider.provider = llm_provider_upsert_request.provider
existing_llm_provider.api_key = llm_provider_upsert_request.api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
existing_llm_provider.custom_config = custom_config

View File

@@ -19,6 +19,7 @@ from onyx.db.models import Tool
from onyx.db.models import User
from onyx.server.features.mcp.models import MCPConnectionData
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
logger = setup_logger()
@@ -204,6 +205,21 @@ def remove_user_from_mcp_server(
# MCPConnectionConfig operations
def extract_connection_data(
    config: MCPConnectionConfig | None, apply_mask: bool = False
) -> MCPConnectionData:
    """Extract MCPConnectionData from a connection config, with proper typing.

    This helper encapsulates the cast from the JSON column's dict[str, Any]
    to the typed MCPConnectionData structure.

    Args:
        config: Stored connection-config row, or None when absent.
        apply_mask: When True, sensitive fields are masked in the returned
            data (only relevant when the payload is a SensitiveValue).

    Returns:
        The typed connection data; an empty-headers structure when there is
        no config row or no stored payload.
    """
    if config is None or config.config is None:
        return MCPConnectionData(headers={})
    # Encrypted payloads come back wrapped in SensitiveValue by the column type.
    if isinstance(config.config, SensitiveValue):
        return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask))
    return cast(MCPConnectionData, config.config)
def get_connection_config_by_id(
config_id: int, db_session: Session
) -> MCPConnectionConfig:
@@ -269,7 +285,7 @@ def update_connection_config(
config = get_connection_config_by_id(config_id, db_session)
if config_data is not None:
config.config = config_data
config.config = config_data # type: ignore[assignment]
# Force SQLAlchemy to detect the change by marking the field as modified
flag_modified(config, "config")
@@ -287,7 +303,7 @@ def upsert_user_connection_config(
existing_config = get_user_connection_config(server_id, user_email, db_session)
if existing_config:
existing_config.config = config_data
existing_config.config = config_data # type: ignore[assignment]
db_session.flush() # Don't commit yet, let caller decide when to commit
return existing_config
else:

View File

@@ -1,22 +1,111 @@
from pydantic import BaseModel
from pydantic import ConfigDict
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.db.models import Memory
from onyx.db.models import User
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
from onyx.prompts.user_info import USER_ROLE_PROMPT
def get_memories(user: User, db_session: Session) -> list[str]:
class UserInfo(BaseModel):
    """Basic identity fields for a user; all optional."""

    name: str | None = None
    role: str | None = None
    email: str | None = None

    def to_dict(self) -> dict:
        """Return a plain dict of the three identity fields (None-preserving)."""
        return {
            "name": self.name,
            "role": self.role,
            "email": self.email,
        }
class UserMemoryContext(BaseModel):
    """Aggregated user context (identity, preferences, memories) for prompts."""

    # Frozen so instances are immutable and safe to share across a request.
    model_config = ConfigDict(frozen=True)

    user_info: UserInfo
    user_preferences: str | None = None
    # Tuple rather than list to keep the frozen model immutable.
    memories: tuple[str, ...] = ()

    def as_formatted_list(self) -> list[str]:
        """Returns combined list of user info, preferences, and memories."""
        result = []
        if self.user_info.name:
            result.append(f"User's name: {self.user_info.name}")
        if self.user_info.role:
            result.append(f"User's role: {self.user_info.role}")
        if self.user_info.email:
            result.append(f"User's email: {self.user_info.email}")
        if self.user_preferences:
            result.append(f"User preferences: {self.user_preferences}")
        result.extend(self.memories)
        return result

    def as_formatted_prompt(self) -> str:
        """Returns structured prompt sections for the system prompt.

        Empty string when there is nothing to show; otherwise the
        concatenation of the basic-info, preferences, and memories
        sections (templates from onyx.prompts.user_info).
        """
        has_basic_info = (
            self.user_info.name or self.user_info.email or self.user_info.role
        )
        if not has_basic_info and not self.user_preferences and not self.memories:
            return ""
        sections: list[str] = []
        if has_basic_info:
            # Role gets its own template line, prefixed with a newline so it
            # slots under the name/email lines only when present.
            role_line = (
                USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
                if self.user_info.role
                else ""
            )
            if role_line:
                role_line = "\n" + role_line
            sections.append(
                BASIC_INFORMATION_PROMPT.format(
                    user_name=self.user_info.name or "",
                    user_email=self.user_info.email or "",
                    user_role=role_line,
                )
            )
        if self.user_preferences:
            sections.append(
                USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
            )
        if self.memories:
            # One "- " bullet per memory inside the memories template.
            formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
            sections.append(
                USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
            )
        return "".join(sections)
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
if not user.use_memories:
return []
return UserMemoryContext(user_info=UserInfo())
user_info = [
f"User's name: {user.personal_name}" if user.personal_name else "",
f"User's role: {user.personal_role}" if user.personal_role else "",
f"User's email: {user.email}" if user.email else "",
]
user_info = UserInfo(
name=user.personal_name,
role=user.personal_role,
email=user.email,
)
user_preferences = None
if user.user_preferences:
user_preferences = user.user_preferences
memory_rows = db_session.scalars(
select(Memory).where(Memory.user_id == user.id)
select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc())
).all()
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
return user_info + memories
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
return UserMemoryContext(
user_info=user_info,
user_preferences=user_preferences,
memories=memories,
)

View File

@@ -95,10 +95,10 @@ from onyx.file_store.models import FileDescriptor
from onyx.llm.override_models import LLMOverride
from onyx.llm.override_models import PromptOverride
from onyx.kg.models import KGStage
from onyx.server.features.mcp.models import MCPConnectionData
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
from onyx.utils.encryption import decrypt_bytes_to_string
from onyx.utils.encryption import encrypt_string_to_bytes
from onyx.utils.sensitive import SensitiveValue
from onyx.utils.headers import HeaderItemDict
from shared_configs.enums import EmbeddingProvider
from onyx.context.search.enums import RecencyBiasSetting
@@ -122,18 +122,35 @@ class EncryptedString(TypeDecorator):
cache_ok = True
def process_bind_param(
self, value: str | None, dialect: Dialect # noqa: ARG002
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw strings and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
return encrypt_string_to_bytes(value)
return value
def process_result_value(
self, value: bytes | None, dialect: Dialect # noqa: ARG002
) -> str | None:
) -> SensitiveValue[str] | None:
if value is not None:
return decrypt_bytes_to_string(value)
return value
return SensitiveValue(
encrypted_bytes=value,
decrypt_fn=decrypt_bytes_to_string,
is_json=False,
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class EncryptedJson(TypeDecorator):
@@ -142,20 +159,38 @@ class EncryptedJson(TypeDecorator):
cache_ok = True
def process_bind_param(
self, value: dict | None, dialect: Dialect # noqa: ARG002
self,
value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None,
dialect: Dialect, # noqa: ARG002
) -> bytes | None:
if value is not None:
# Handle both raw dicts and SensitiveValue wrappers
if isinstance(value, SensitiveValue):
# Get raw value for storage
value = value.get_value(apply_mask=False)
json_str = json.dumps(value)
return encrypt_string_to_bytes(json_str)
return value
def process_result_value(
self, value: bytes | None, dialect: Dialect # noqa: ARG002
) -> dict | None:
) -> SensitiveValue[dict[str, Any]] | None:
if value is not None:
json_str = decrypt_bytes_to_string(value)
return json.loads(json_str)
return value
return SensitiveValue(
encrypted_bytes=value,
decrypt_fn=decrypt_bytes_to_string,
is_json=True,
)
return None
def compare_values(self, x: Any, y: Any) -> bool:
if x is None or y is None:
return x == y
if isinstance(x, SensitiveValue):
x = x.get_value(apply_mask=False)
if isinstance(y, SensitiveValue):
y = y.get_value(apply_mask=False)
return x == y
class NullFilteredString(TypeDecorator):
@@ -216,6 +251,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
chosen_assistants: Mapped[list[int] | None] = mapped_column(
postgresql.JSONB(), nullable=True, default=None
@@ -1755,7 +1791,9 @@ class Credential(Base):
)
id: Mapped[int] = mapped_column(primary_key=True)
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson()
)
user_id: Mapped[UUID | None] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
)
@@ -1793,7 +1831,9 @@ class FederatedConnector(Base):
source: Mapped[FederatedConnectorSource] = mapped_column(
Enum(FederatedConnectorSource, native_enum=False)
)
credentials: Mapped[dict[str, str]] = mapped_column(EncryptedJson(), nullable=False)
credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False
)
config: Mapped[dict[str, Any]] = mapped_column(
postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
)
@@ -1820,7 +1860,9 @@ class FederatedConnectorOAuthToken(Base):
user_id: Mapped[UUID] = mapped_column(
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
)
token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
expires_at: Mapped[datetime.datetime | None] = mapped_column(
DateTime, nullable=True
)
@@ -1964,7 +2006,9 @@ class SearchSettings(Base):
@property
def api_key(self) -> str | None:
return self.cloud_provider.api_key if self.cloud_provider is not None else None
if self.cloud_provider is None or self.cloud_provider.api_key is None:
return None
return self.cloud_provider.api_key.get_value(apply_mask=False)
@property
def large_chunks_enabled(self) -> bool:
@@ -2726,7 +2770,9 @@ class LLMProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True)
provider: Mapped[str] = mapped_column(String)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
# custom configs that should be passed to the LLM provider at inference time
@@ -2879,7 +2925,7 @@ class CloudEmbeddingProvider(Base):
Enum(EmbeddingProvider), primary_key=True
)
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
api_key: Mapped[str | None] = mapped_column(EncryptedString())
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
@@ -2898,7 +2944,9 @@ class InternetSearchProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
config: Mapped[dict[str, str] | None] = mapped_column(
postgresql.JSONB(), nullable=True
)
@@ -2920,7 +2968,9 @@ class InternetContentProvider(Base):
id: Mapped[int] = mapped_column(Integer, primary_key=True)
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
provider_type: Mapped[str] = mapped_column(String, nullable=False)
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
config: Mapped[WebContentProviderConfig | None] = mapped_column(
PydanticType(WebContentProviderConfig), nullable=True
)
@@ -3064,8 +3114,12 @@ class OAuthConfig(Base):
token_url: Mapped[str] = mapped_column(Text, nullable=False)
# Client credentials (encrypted)
client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
# Optional configurations
scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
@@ -3112,7 +3166,9 @@ class OAuthUserToken(Base):
# "expires_at": 1234567890, # Unix timestamp, optional
# "scope": "repo user" # Optional
# }
token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)
token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False
)
# Metadata
created_at: Mapped[datetime.datetime] = mapped_column(
@@ -3445,9 +3501,15 @@ class SlackBot(Base):
name: Mapped[str] = mapped_column(String)
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), unique=True
)
app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), unique=True
)
user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=True
)
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
"SlackChannelConfig",
@@ -3468,7 +3530,9 @@ class DiscordBotConfig(Base):
id: Mapped[str] = mapped_column(
String, primary_key=True, server_default=text("'SINGLETON'")
)
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
EncryptedString(), nullable=False
)
created_at: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True), server_default=func.now(), nullable=False
)
@@ -3624,7 +3688,9 @@ class KVStore(Base):
key: Mapped[str] = mapped_column(String, primary_key=True)
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
encrypted_value: Mapped[JSON_ro] = mapped_column(EncryptedJson(), nullable=True)
encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=True
)
class FileRecord(Base):
@@ -4344,7 +4410,7 @@ class MCPConnectionConfig(Base):
# "registration_access_token": "<token>", # For managing registration
# "registration_client_uri": "<uri>", # For managing registration
# }
config: Mapped[MCPConnectionData] = mapped_column(
config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
EncryptedJson(), nullable=False, default=dict
)

View File

@@ -87,13 +87,13 @@ def update_oauth_config(
if token_url is not None:
oauth_config.token_url = token_url
if clear_client_id:
oauth_config.client_id = ""
oauth_config.client_id = "" # type: ignore[assignment]
elif client_id is not None:
oauth_config.client_id = client_id
oauth_config.client_id = client_id # type: ignore[assignment]
if clear_client_secret:
oauth_config.client_secret = ""
oauth_config.client_secret = "" # type: ignore[assignment]
elif client_secret is not None:
oauth_config.client_secret = client_secret
oauth_config.client_secret = client_secret # type: ignore[assignment]
if scopes is not None:
oauth_config.scopes = scopes
if additional_params is not None:
@@ -154,7 +154,7 @@ def upsert_user_oauth_token(
if existing_token:
# Update existing token
existing_token.token_data = token_data
existing_token.token_data = token_data # type: ignore[assignment]
db_session.commit()
return existing_token
else:

View File

@@ -43,9 +43,9 @@ def update_slack_bot(
# update the app
slack_bot.name = name
slack_bot.enabled = enabled
slack_bot.bot_token = bot_token
slack_bot.app_token = app_token
slack_bot.user_token = user_token
slack_bot.bot_token = bot_token # type: ignore[assignment]
slack_bot.app_token = app_token # type: ignore[assignment]
slack_bot.user_token = user_token # type: ignore[assignment]
db_session.commit()

View File

@@ -160,6 +160,7 @@ def update_user_personalization(
personal_role: str | None,
use_memories: bool,
memories: list[str],
user_preferences: str | None,
db_session: Session,
) -> None:
db_session.execute(
@@ -169,6 +170,7 @@ def update_user_personalization(
personal_name=personal_name,
personal_role=personal_role,
use_memories=use_memories,
user_preferences=user_preferences,
)
)

View File

@@ -73,7 +73,8 @@ def _apply_search_provider_updates(
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
provider.api_key = api_key # type: ignore[assignment]
def upsert_web_search_provider(
@@ -228,7 +229,8 @@ def _apply_content_provider_updates(
provider.provider_type = provider_type.value
provider.config = config
if api_key_changed or provider.api_key is None:
provider.api_key = api_key
# EncryptedString accepts str for writes, returns SensitiveValue for reads
provider.api_key = api_key # type: ignore[assignment]
def upsert_web_content_provider(

View File

@@ -119,7 +119,16 @@ def get_federated_retrieval_functions(
federated_retrieval_infos_slack = []
# Use user_token if available, otherwise fall back to bot_token
access_token = tenant_slack_bot.user_token or tenant_slack_bot.bot_token
# Unwrap SensitiveValue for backend API calls
access_token = (
tenant_slack_bot.user_token.get_value(apply_mask=False)
if tenant_slack_bot.user_token
else (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else ""
)
)
if not tenant_slack_bot.user_token:
logger.warning(
f"Using bot_token for Slack search (limited functionality): {tenant_slack_bot.name}"
@@ -138,7 +147,12 @@ def get_federated_retrieval_functions(
)
# Capture variables by value to avoid lambda closure issues
bot_token = tenant_slack_bot.bot_token
# Unwrap SensitiveValue for backend API calls
bot_token = (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else ""
)
# Use connector config for channel filtering (guaranteed to exist at this point)
connector_entities = slack_federated_connector_config
@@ -252,11 +266,11 @@ def get_federated_retrieval_functions(
connector = get_federated_connector(
oauth_token.federated_connector.source,
oauth_token.federated_connector.credentials,
oauth_token.federated_connector.credentials.get_value(apply_mask=False),
)
# Capture variables by value to avoid lambda closure issues
access_token = oauth_token.token
access_token = oauth_token.token.get_value(apply_mask=False)
def create_retrieval_function(
conn: FederatedConnector,

View File

@@ -43,7 +43,7 @@ class PgRedisKVStore(KeyValueStore):
obj = db_session.query(KVStore).filter_by(key=key).first()
if obj:
obj.value = plain_val
obj.encrypted_value = encrypted_val
obj.encrypted_value = encrypted_val # type: ignore[assignment]
else:
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
db_session.query(KVStore).filter_by(key=key).delete() # just in case
@@ -73,7 +73,8 @@ class PgRedisKVStore(KeyValueStore):
if obj.value is not None:
value = obj.value
elif obj.encrypted_value is not None:
value = obj.encrypted_value
# Unwrap SensitiveValue - this is internal backend use
value = obj.encrypted_value.get_value(apply_mask=False)
else:
value = None

View File

@@ -1,4 +1,8 @@
import os
import threading
from collections.abc import Iterator
from contextlib import contextmanager
from contextlib import nullcontext
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -44,11 +48,13 @@ from onyx.llm.well_known_providers.constants import (
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT,
)
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
from onyx.server.utils import mask_string
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
_env_lock = threading.Lock()
if TYPE_CHECKING:
from litellm import CustomStreamWrapper
from litellm import HTTPHandler
@@ -378,23 +384,29 @@ class LitellmLLM(LLM):
if "api_key" not in passthrough_kwargs:
passthrough_kwargs["api_key"] = self._api_key or None
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
env_ctx = (
temporary_env_and_lock(self._custom_config)
if self._custom_config
else nullcontext()
)
with env_ctx:
response = litellm.completion(
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
model=model,
base_url=self._api_base or None,
api_version=self._api_version or None,
custom_llm_provider=self._custom_llm_provider or None,
messages=_prompt_to_dicts(prompt),
tools=tools,
tool_choice=tool_choice,
stream=stream,
temperature=temperature,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
client=client,
**optional_kwargs,
**passthrough_kwargs,
)
return response
except Exception as e:
# for break pointing
@@ -475,22 +487,53 @@ class LitellmLLM(LLM):
client = HTTPHandler(timeout=timeout_override or self._timeout)
try:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
if self._custom_config:
# When custom_config is set, env vars are temporarily injected
# under a global lock. Using stream=True here means the lock is
# only held during connection setup (not the full inference).
# The chunks are then collected outside the lock and reassembled
# into a single ModelResponse via stream_chunk_builder.
from litellm import stream_chunk_builder
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
stream_response = cast(
LiteLLMCustomStreamWrapper,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
chunks = list(stream_response)
response = cast(
LiteLLMModelResponse,
stream_chunk_builder(chunks),
)
else:
response = cast(
LiteLLMModelResponse,
self._completion(
prompt=prompt,
tools=tools,
tool_choice=tool_choice,
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
parallel_tool_calls=True,
reasoning_effort=reasoning_effort,
user_identity=user_identity,
client=client,
),
)
model_response = from_litellm_model_response(response)
@@ -581,3 +624,29 @@ class LitellmLLM(LLM):
finally:
if client is not None:
client.close()
@contextmanager
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
"""
Temporarily sets the environment variables to the given values.
Code path is locked while the environment variables are set.
Then cleans up the environment and frees the lock.
"""
with _env_lock:
logger.debug("Acquired lock in temporary_env_and_lock")
# Store original values (None if key didn't exist)
original_values: dict[str, str | None] = {
key: os.environ.get(key) for key in env_variables
}
try:
os.environ.update(env_variables)
yield
finally:
for key, original_value in original_values.items():
if original_value is None:
os.environ.pop(key, None) # Remove if it didn't exist before
else:
os.environ[key] = original_value # Restore original value
logger.debug("Released lock in temporary_env_and_lock")

View File

@@ -4,6 +4,7 @@ from onyx.configs.constants import AuthType
from onyx.db.discord_bot import get_discord_bot_config
from onyx.db.engine.sql_engine import get_session_with_tenant
from onyx.utils.logger import setup_logger
from onyx.utils.sensitive import SensitiveValue
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
logger = setup_logger()
@@ -36,4 +37,8 @@ def get_bot_token() -> str | None:
except Exception as e:
logger.error(f"Failed to get bot token from database: {e}")
return None
return config.bot_token if config else None
if config and config.bot_token:
if isinstance(config.bot_token, SensitiveValue):
return config.bot_token.get_value(apply_mask=False)
return config.bot_token
return None

View File

@@ -216,14 +216,10 @@ class SlackbotHandler:
- If the tokens have changed, close the existing socket client and reconnect.
- If the tokens are new, warm up the model and start a new socket client.
"""
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token,
app_token=bot.app_token,
)
tenant_bot_pair = (tenant_id, bot.id)
# If the tokens are missing or empty, close the socket client and remove them.
if not slack_bot_tokens:
if not bot.bot_token or not bot.app_token:
logger.debug(
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
)
@@ -233,6 +229,11 @@ class SlackbotHandler:
del self.slack_bot_tokens[tenant_bot_pair]
return
slack_bot_tokens = SlackBotTokens(
bot_token=bot.bot_token.get_value(apply_mask=False),
app_token=bot.app_token.get_value(apply_mask=False),
)
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
tokens_changed = (
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]

View File

@@ -25,9 +25,6 @@ You can use Markdown tables to format your responses for data, lists, and other
""".lstrip()
# Section for information about the user if provided such as their name, role, memories, etc.
USER_INFO_HEADER = "\n\n# User Information\n"
COMPANY_NAME_BLOCK = """
The user is at an organization called `{company_name}`.
"""

View File

@@ -403,12 +403,13 @@ def check_drive_tokens(
db_session: Session = Depends(get_session),
) -> AuthStatus:
db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session)
if (
not db_credentials
or DB_CREDENTIALS_DICT_TOKEN_KEY not in db_credentials.credential_json
):
if not db_credentials or not db_credentials.credential_json:
return AuthStatus(authenticated=False)
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
credential_json = db_credentials.credential_json.get_value(apply_mask=False)
if DB_CREDENTIALS_DICT_TOKEN_KEY not in credential_json:
return AuthStatus(authenticated=False)
token_json_str = str(credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
google_drive_creds = get_google_oauth_creds(
token_json_str=token_json_str,
source=DocumentSource.GOOGLE_DRIVE,

View File

@@ -346,10 +346,17 @@ def update_credential_from_model(
detail=f"Credential {credential_id} does not exist or does not belong to user",
)
# Get credential_json value - use masking for API responses
credential_json_value = (
updated_credential.credential_json.get_value(apply_mask=True)
if updated_credential.credential_json
else {}
)
return CredentialSnapshot(
source=updated_credential.source,
id=updated_credential.id,
credential_json=updated_credential.credential_json,
credential_json=credential_json_value,
user_id=updated_credential.user_id,
name=updated_credential.name,
admin_public=updated_credential.admin_public,

View File

@@ -28,7 +28,6 @@ from onyx.db.models import IndexAttempt
from onyx.db.models import IndexingStatus
from onyx.db.models import TaskStatus
from onyx.server.federated.models import FederatedConnectorStatus
from onyx.server.utils import mask_credential_dict
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
@@ -145,13 +144,21 @@ class CredentialSnapshot(CredentialBase):
@classmethod
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
# Get the credential_json value with appropriate masking
if credential.credential_json is None:
credential_json_value: dict[str, Any] = {}
elif MASK_CREDENTIAL_PREFIX:
credential_json_value = credential.credential_json.get_value(
apply_mask=True
)
else:
credential_json_value = credential.credential_json.get_value(
apply_mask=False
)
return CredentialSnapshot(
id=credential.id,
credential_json=(
mask_credential_dict(credential.credential_json)
if MASK_CREDENTIAL_PREFIX and credential.credential_json
else credential.credential_json
),
credential_json=credential_json_value,
user_id=credential.user_id,
user_email=credential.user.email if credential.user else None,
admin_public=credential.admin_public,

View File

@@ -88,7 +88,7 @@ SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes")
# Container image for sandbox pods
# Should include Next.js template, opencode CLI, and demo_data zip
SANDBOX_CONTAINER_IMAGE = os.environ.get(
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.2"
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.3"
)
# S3 bucket for sandbox file storage (snapshots, knowledge files, uploads)

View File

@@ -0,0 +1,119 @@
# Sandbox Container Image
This directory contains the Dockerfile and resources for building the Onyx Craft sandbox container image.
## Directory Structure
```
docker/
├── Dockerfile # Main container image definition
├── demo_data.zip # Demo data (extracted to /workspace/demo_data)
├── templates/
│ └── outputs/ # Web app scaffold template (Next.js)
├── initial-requirements.txt # Python packages pre-installed in sandbox
├── generate_agents_md.py # Script to generate AGENTS.md for sessions
└── README.md # This file
```
## Building the Image
The sandbox image must be built for **amd64** architecture since our Kubernetes cluster runs on x86_64 nodes.
### Build for amd64 only (fastest)
```bash
cd backend/onyx/server/features/build/sandbox/kubernetes/docker
docker build --platform linux/amd64 -t onyxdotapp/sandbox:v0.1.x .
docker push onyxdotapp/sandbox:v0.1.x
```
### Build multi-arch (recommended for flexibility)
```bash
docker buildx build --platform linux/amd64,linux/arm64 \
-t onyxdotapp/sandbox:v0.1.x \
--push .
```
### Update the `latest` tag
After pushing a versioned tag, update `latest`:
```bash
docker tag onyxdotapp/sandbox:v0.1.x onyxdotapp/sandbox:latest
docker push onyxdotapp/sandbox:latest
```
Or with buildx:
```bash
docker buildx build --platform linux/amd64,linux/arm64 \
-t onyxdotapp/sandbox:v0.1.x \
-t onyxdotapp/sandbox:latest \
--push .
```
## Deploying a New Version
1. **Build and push** the new image (see above)
2. **Update the ConfigMap** in `cloud-deployment-yamls/danswer/configmap/env-configmap.yaml`:
```yaml
SANDBOX_CONTAINER_IMAGE: "onyxdotapp/sandbox:v0.1.x"
```
3. **Apply the ConfigMap**:
```bash
kubectl apply -f configmap/env-configmap.yaml
```
4. **Restart the API server** to pick up the new config:
```bash
kubectl rollout restart deployment/api-server -n danswer
```
5. **Delete existing sandbox pods** (they will be recreated with the new image):
```bash
kubectl delete pods -n onyx-sandboxes -l app.kubernetes.io/component=sandbox
```
## What's Baked Into the Image
- **Base**: `node:20-slim` (Debian-based)
- **Demo data**: `/workspace/demo_data/` - sample files for demo sessions
- **Templates**: `/workspace/templates/outputs/` - Next.js web app scaffold
- **Python venv**: `/workspace/.venv/` with packages from `initial-requirements.txt`
- **OpenCode CLI**: Installed in `/home/sandbox/.opencode/bin/`
## Runtime Directory Structure
When a session is created, the following structure is set up in the pod:
```
/workspace/
├── demo_data/ # Baked into image
├── files/ # Mounted volume, synced from S3
├── templates/ # Baked into image
└── sessions/
└── $session_id/
├── files/ # Symlink to /workspace/demo_data or /workspace/files
├── outputs/ # Copied from templates, contains web app
├── attachments/ # User-uploaded files
├── org_info/ # Demo persona info (if demo mode)
├── AGENTS.md # Instructions for the AI agent
└── opencode.json # OpenCode configuration
```
## Troubleshooting
### Verify image exists on Docker Hub
```bash
curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags" | jq '.results[].name'
```
### Check what image a pod is using
```bash
kubectl get pod <pod-name> -n onyx-sandboxes -o jsonpath='{.spec.containers[?(@.name=="sandbox")].image}'
```

View File

@@ -349,7 +349,11 @@ class SessionManager:
return LLMProviderConfig(
provider=default_model.llm_provider.provider,
model_name=default_model.name,
api_key=default_model.llm_provider.api_key,
api_key=(
default_model.llm_provider.api_key.get_value(apply_mask=False)
if default_model.llm_provider.api_key
else None
),
api_base=default_model.llm_provider.api_base,
)

View File

@@ -41,6 +41,7 @@ from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
from onyx.db.mcp import delete_connection_config
from onyx.db.mcp import delete_mcp_server
from onyx.db.mcp import delete_user_connection_configs_for_server
from onyx.db.mcp import extract_connection_data
from onyx.db.mcp import get_all_mcp_servers
from onyx.db.mcp import get_connection_config_by_id
from onyx.db.mcp import get_mcp_server_by_id
@@ -79,6 +80,7 @@ from onyx.server.features.tool.models import ToolSnapshot
from onyx.tools.tool_implementations.mcp.mcp_client import discover_mcp_tools
from onyx.tools.tool_implementations.mcp.mcp_client import initialize_mcp_client
from onyx.tools.tool_implementations.mcp.mcp_client import log_exception_group
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -143,7 +145,8 @@ class OnyxTokenStorage(TokenStorage):
async def get_tokens(self) -> OAuthToken | None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
tokens_raw = config.config.get(MCPOAuthKeys.TOKENS.value)
config_data = extract_connection_data(config)
tokens_raw = config_data.get(MCPOAuthKeys.TOKENS.value)
if tokens_raw:
return OAuthToken.model_validate(tokens_raw)
return None
@@ -151,14 +154,14 @@ class OnyxTokenStorage(TokenStorage):
async def set_tokens(self, tokens: OAuthToken) -> None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
config.config[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
cfg_headers = {
config_data = extract_connection_data(config)
config_data[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
config_data["headers"] = {
"Authorization": f"{tokens.token_type} {tokens.access_token}"
}
config.config["headers"] = cfg_headers
update_connection_config(config.id, db_session, config.config)
update_connection_config(config.id, db_session, config_data)
if self.alt_config_id:
update_connection_config(self.alt_config_id, db_session, config.config)
update_connection_config(self.alt_config_id, db_session, config_data)
# signal the oauth callback that token exchange is complete
r = get_redis_client()
@@ -168,19 +171,21 @@ class OnyxTokenStorage(TokenStorage):
async def get_client_info(self) -> OAuthClientInformationFull | None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
client_info_raw = config.config.get(MCPOAuthKeys.CLIENT_INFO.value)
config_data = extract_connection_data(config)
client_info_raw = config_data.get(MCPOAuthKeys.CLIENT_INFO.value)
if client_info_raw:
return OAuthClientInformationFull.model_validate(client_info_raw)
if self.alt_config_id:
alt_config = get_connection_config_by_id(self.alt_config_id, db_session)
if alt_config:
alt_client_info = alt_config.config.get(
alt_config_data = extract_connection_data(alt_config)
alt_client_info = alt_config_data.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if alt_client_info:
# Cache the admin client info on the user config for future calls
config.config[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
update_connection_config(config.id, db_session, config.config)
config_data[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
update_connection_config(config.id, db_session, config_data)
return OAuthClientInformationFull.model_validate(
alt_client_info
)
@@ -189,10 +194,11 @@ class OnyxTokenStorage(TokenStorage):
async def set_client_info(self, info: OAuthClientInformationFull) -> None:
with get_session_with_current_tenant() as db_session:
config = self._ensure_connection_config(db_session)
config.config[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
update_connection_config(config.id, db_session, config.config)
config_data = extract_connection_data(config)
config_data[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
update_connection_config(config.id, db_session, config_data)
if self.alt_config_id:
update_connection_config(self.alt_config_id, db_session, config.config)
update_connection_config(self.alt_config_id, db_session, config_data)
def make_oauth_provider(
@@ -436,9 +442,12 @@ async def _connect_oauth(
db.commit()
connection_config_dict = extract_connection_data(
connection_config, apply_mask=False
)
is_connected = (
MCPOAuthKeys.CLIENT_INFO.value in connection_config.config
and connection_config.config.get("headers")
MCPOAuthKeys.CLIENT_INFO.value in connection_config_dict
and connection_config_dict.get("headers")
)
# Step 1: make unauthenticated request and parse returned www authenticate header
# Ensure we have a trailing slash for the MCP endpoint
@@ -471,7 +480,7 @@ async def _connect_oauth(
try:
x = await initialize_mcp_client(
probe_url,
connection_headers=connection_config.config.get("headers", {}),
connection_headers=connection_config_dict.get("headers", {}),
transport=transport,
auth=oauth_auth,
)
@@ -684,15 +693,18 @@ def save_user_credentials(
# Use template to create the full connection config
try:
# TODO: fix and/or type correctly w/base model
auth_template_dict = extract_connection_data(
auth_template, apply_mask=False
)
config_data = MCPConnectionData(
headers=auth_template.config.get("headers", {}),
headers=auth_template_dict.get("headers", {}),
header_substitutions=request.credentials,
)
for oauth_field_key in MCPOAuthKeys:
field_key: Literal["client_info", "tokens", "metadata"] = (
oauth_field_key.value
)
if field_val := auth_template.config.get(field_key):
if field_val := auth_template_dict.get(field_key):
config_data[field_key] = field_val
except Exception as e:
@@ -839,18 +851,20 @@ def _db_mcp_server_to_api_mcp_server(
and db_server.admin_connection_config is not None
and include_auth_config
):
admin_config_dict = extract_connection_data(
db_server.admin_connection_config, apply_mask=False
)
if db_server.auth_type == MCPAuthenticationType.API_TOKEN:
raw_api_key = admin_config_dict["headers"]["Authorization"].split(" ")[
-1
]
admin_credentials = {
"api_key": db_server.admin_connection_config.config["headers"][
"Authorization"
].split(" ")[-1]
"api_key": mask_string(raw_api_key),
}
elif db_server.auth_type == MCPAuthenticationType.OAUTH:
user_authenticated = False
client_info = None
client_info_raw = db_server.admin_connection_config.config.get(
MCPOAuthKeys.CLIENT_INFO.value
)
client_info_raw = admin_config_dict.get(MCPOAuthKeys.CLIENT_INFO.value)
if client_info_raw:
client_info = OAuthClientInformationFull.model_validate(
client_info_raw
@@ -861,8 +875,8 @@ def _db_mcp_server_to_api_mcp_server(
"Stored client info had empty client ID or secret"
)
admin_credentials = {
"client_id": client_info.client_id,
"client_secret": client_info.client_secret,
"client_id": mask_string(client_info.client_id),
"client_secret": mask_string(client_info.client_secret),
}
else:
admin_credentials = {}
@@ -879,14 +893,18 @@ def _db_mcp_server_to_api_mcp_server(
include_auth_config
and db_server.auth_type != MCPAuthenticationType.OAUTH
):
user_credentials = user_config.config.get(HEADER_SUBSTITUTIONS, {})
user_config_dict = extract_connection_data(user_config, apply_mask=True)
user_credentials = user_config_dict.get(HEADER_SUBSTITUTIONS, {})
if (
db_server.auth_type == MCPAuthenticationType.OAUTH
and db_server.admin_connection_config
):
client_info = None
client_info_raw = db_server.admin_connection_config.config.get(
oauth_admin_config_dict = extract_connection_data(
db_server.admin_connection_config, apply_mask=False
)
client_info_raw = oauth_admin_config_dict.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if client_info_raw:
@@ -896,8 +914,8 @@ def _db_mcp_server_to_api_mcp_server(
raise ValueError("Stored client info had empty client ID or secret")
if can_view_admin_credentials:
admin_credentials = {
"client_id": client_info.client_id,
"client_secret": client_info.client_secret,
"client_id": mask_string(client_info.client_id),
"client_secret": mask_string(client_info.client_secret),
}
elif can_view_admin_credentials:
admin_credentials = {}
@@ -909,7 +927,10 @@ def _db_mcp_server_to_api_mcp_server(
try:
template_config = db_server.admin_connection_config
if template_config:
headers = template_config.config.get("headers", {})
template_config_dict = extract_connection_data(
template_config, apply_mask=False
)
headers = template_config_dict.get("headers", {})
auth_template = MCPAuthTemplate(
headers=headers,
required_fields=[], # would need to regex, not worth it
@@ -1232,7 +1253,10 @@ def _list_mcp_tools_by_id(
)
if connection_config:
headers.update(connection_config.config.get("headers", {}))
connection_config_dict = extract_connection_data(
connection_config, apply_mask=False
)
headers.update(connection_config_dict.get("headers", {}))
import time
@@ -1320,7 +1344,10 @@ def _upsert_mcp_server(
_ensure_mcp_server_owner_or_admin(mcp_server, user)
client_info = None
if mcp_server.admin_connection_config:
client_info_raw = mcp_server.admin_connection_config.config.get(
existing_admin_config_dict = extract_connection_data(
mcp_server.admin_connection_config, apply_mask=False
)
client_info_raw = existing_admin_config_dict.get(
MCPOAuthKeys.CLIENT_INFO.value
)
if client_info_raw:

View File

@@ -32,11 +32,16 @@ def get_user_oauth_token_status(
and whether their tokens are expired.
"""
user_tokens = get_all_user_oauth_tokens(user.id, db_session)
return [
OAuthTokenStatus(
oauth_config_id=token.oauth_config_id,
expires_at=OAuthTokenManager.token_expiration_time(token.token_data),
is_expired=OAuthTokenManager.is_token_expired(token.token_data),
result = []
for token in user_tokens:
token_data = (
token.token_data.get_value(apply_mask=False) if token.token_data else {}
)
for token in user_tokens
]
result.append(
OAuthTokenStatus(
oauth_config_id=token.oauth_config_id,
expires_at=OAuthTokenManager.token_expiration_time(token_data),
is_expired=OAuthTokenManager.is_token_expired(token_data),
)
)
return result

View File

@@ -75,7 +75,7 @@ def _get_active_search_provider(
has_api_key=bool(provider_model.api_key),
)
if not provider_model.api_key:
if provider_model.api_key is None:
raise HTTPException(
status_code=400,
detail="Web search provider requires an API key.",
@@ -84,7 +84,7 @@ def _get_active_search_provider(
try:
provider: WebSearchProvider = build_search_provider_from_config(
provider_type=provider_view.provider_type,
api_key=provider_model.api_key,
api_key=provider_model.api_key.get_value(apply_mask=False),
config=provider_model.config or {},
)
except ValueError as exc:
@@ -121,7 +121,7 @@ def _get_active_content_provider(
provider: WebContentProvider | None = build_content_provider_from_config(
provider_type=provider_type,
api_key=provider_model.api_key,
api_key=provider_model.api_key.get_value(apply_mask=False),
config=config,
)
except ValueError as exc:

View File

@@ -114,9 +114,14 @@ def get_entities(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
entities_spec = connector_instance.configuration_schema()
@@ -151,9 +156,14 @@ def get_credentials_schema(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
credentials_spec = connector_instance.credentials_schema()
@@ -275,6 +285,8 @@ def validate_entities(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
return Response(status_code=400)
# For HEAD requests, we'll expect entities as query parameters
# since HEAD requests shouldn't have request bodies
@@ -288,7 +300,8 @@ def validate_entities(
return Response(status_code=400)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
is_valid = connector_instance.validate_entities(entities_dict)
@@ -318,9 +331,15 @@ def get_authorize_url(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
# Update credentials to include the correct redirect URI with the connector ID
updated_credentials = federated_connector.credentials.copy()
updated_credentials = federated_connector.credentials.get_value(
apply_mask=False
).copy()
if "redirect_uri" in updated_credentials and updated_credentials["redirect_uri"]:
# Replace the {id} placeholder with the actual federated connector ID
updated_credentials["redirect_uri"] = updated_credentials[
@@ -391,9 +410,14 @@ def handle_oauth_callback_generic(
)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
connector_instance = _get_federated_connector_instance(
federated_connector.source, federated_connector.credentials
federated_connector.source,
federated_connector.credentials.get_value(apply_mask=False),
)
oauth_result = connector_instance.callback(callback_data, get_oauth_callback_uri())
@@ -460,9 +484,9 @@ def get_user_oauth_status(
# Generate authorize URL if needed
authorize_url = None
if not oauth_token:
if not oauth_token and fc.credentials is not None:
connector_instance = _get_federated_connector_instance(
fc.source, fc.credentials
fc.source, fc.credentials.get_value(apply_mask=False)
)
base_authorize_url = connector_instance.authorize(get_oauth_callback_uri())
@@ -496,6 +520,10 @@ def get_federated_connector_detail(
federated_connector = fetch_federated_connector_by_id(id, db_session)
if not federated_connector:
raise HTTPException(status_code=404, detail="Federated connector not found")
if federated_connector.credentials is None:
raise HTTPException(
status_code=400, detail="Federated connector has no credentials"
)
# Get OAuth token information for the current user
oauth_token = None
@@ -521,7 +549,9 @@ def get_federated_connector_detail(
id=federated_connector.id,
source=federated_connector.source,
name=f"{federated_connector.source.replace('_', ' ').title()}",
credentials=FederatedConnectorCredentials(**federated_connector.credentials),
credentials=FederatedConnectorCredentials(
**federated_connector.credentials.get_value(apply_mask=True)
),
config=federated_connector.config,
oauth_token_exists=oauth_token is not None,
oauth_token_expires_at=oauth_token.expires_at if oauth_token else None,

View File

@@ -16,7 +16,7 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.embedding.models import TestEmbeddingRequest
from onyx.server.utils import mask_string
from onyx.utils.encryption import mask_string
from onyx.utils.logger import setup_logger
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT

View File

@@ -37,7 +37,11 @@ class CloudEmbeddingProvider(BaseModel):
) -> "CloudEmbeddingProvider":
return cls(
provider_type=cloud_provider_model.provider_type,
api_key=cloud_provider_model.api_key,
api_key=(
cloud_provider_model.api_key.get_value(apply_mask=True)
if cloud_provider_model.api_key
else None
),
api_url=cloud_provider_model.api_url,
api_version=cloud_provider_model.api_version,
deployment_name=cloud_provider_model.deployment_name,

View File

@@ -90,7 +90,11 @@ def _build_llm_provider_request(
return LLMProviderUpsertRequest(
name=f"Image Gen - {image_provider_id}",
provider=source_provider.provider,
api_key=source_provider.api_key, # Only this from source
api_key=(
source_provider.api_key.get_value(apply_mask=False)
if source_provider.api_key
else None
), # Only this from source
api_base=api_base, # From request
api_version=api_version, # From request
default_model_name=model_name,
@@ -227,7 +231,11 @@ def test_image_generation(
api_key_changed=False, # Using stored key from source provider
)
api_key = source_provider.api_key
api_key = (
source_provider.api_key.get_value(apply_mask=False)
if source_provider.api_key
else None
)
provider = source_provider.provider
if provider is None:
@@ -431,7 +439,11 @@ def update_config(
api_key_changed=False,
)
# Preserve existing API key when user didn't change it
actual_api_key = old_provider.api_key
actual_api_key = (
old_provider.api_key.get_value(apply_mask=False)
if old_provider.api_key
else None
)
# 3. Build and create new LLM provider
provider_request = _build_llm_provider_request(

View File

@@ -140,7 +140,11 @@ class ImageGenerationCredentials(BaseModel):
"""
llm_provider = config.model_configuration.llm_provider
return cls(
api_key=_mask_api_key(llm_provider.api_key),
api_key=_mask_api_key(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,
@@ -168,7 +172,11 @@ class DefaultImageGenerationConfig(BaseModel):
model_configuration_id=config.model_configuration_id,
model_name=config.model_configuration.name,
provider=llm_provider.provider,
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -203,7 +203,11 @@ def test_llm_configuration(
new_custom_config=test_llm_request.custom_config,
api_key_changed=False,
)
test_api_key = existing_provider.api_key
test_api_key = (
existing_provider.api_key.get_value(apply_mask=False)
if existing_provider.api_key
else None
)
if existing_provider and not test_llm_request.custom_config_changed:
test_custom_config = existing_provider.custom_config
@@ -351,7 +355,11 @@ def put_llm_provider(
# the llm api key is sanitized when returned to clients, so the only time we
# should get a real key is when it is explicitly changed
if existing_provider and not llm_provider_upsert_request.api_key_changed:
llm_provider_upsert_request.api_key = existing_provider.api_key
llm_provider_upsert_request.api_key = (
existing_provider.api_key.get_value(apply_mask=False)
if existing_provider.api_key
else None
)
if existing_provider and not llm_provider_upsert_request.custom_config_changed:
llm_provider_upsert_request.custom_config = existing_provider.custom_config
@@ -646,7 +654,11 @@ def get_provider_contextual_cost(
provider=provider.provider,
model=model_configuration.name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_key=(
provider.api_key.get_value(apply_mask=False)
if provider.api_key
else None
),
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
@@ -926,6 +938,11 @@ def get_ollama_available_models(
)
)
sorted_results = sorted(
all_models_with_context_size_and_vision,
key=lambda m: m.name.lower(),
)
# Sync new models to DB if provider_name is specified
if request.provider_name:
try:
@@ -936,7 +953,7 @@ def get_ollama_available_models(
"max_input_tokens": r.max_input_tokens,
"supports_image_input": r.supports_image_input,
}
for r in all_models_with_context_size_and_vision
for r in sorted_results
]
new_count = sync_model_configurations(
db_session=db_session,
@@ -950,7 +967,7 @@ def get_ollama_available_models(
except ValueError as e:
logger.warning(f"Failed to sync Ollama models to DB: {e}")
return all_models_with_context_size_and_vision
return sorted_results
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:

View File

@@ -190,7 +190,11 @@ class LLMProviderView(LLMProvider):
id=llm_provider_model.id,
name=llm_provider_model.name,
provider=provider,
api_key=llm_provider_model.api_key,
api_key=(
llm_provider_model.api_key.get_value(apply_mask=False)
if llm_provider_model.api_key
else None
),
api_base=llm_provider_model.api_base,
api_version=llm_provider_model.api_version,
custom_config=llm_provider_model.custom_config,

View File

@@ -79,6 +79,7 @@ class UserPersonalization(BaseModel):
role: str = ""
use_memories: bool = True
memories: list[str] = Field(default_factory=list)
user_preferences: str = ""
class TenantSnapshot(BaseModel):
@@ -160,6 +161,7 @@ class UserInfo(BaseModel):
role=user.personal_role or "",
use_memories=user.use_memories,
memories=[memory.memory_text for memory in (user.memories or [])],
user_preferences=user.user_preferences or "",
),
)
@@ -213,6 +215,7 @@ class PersonalizationUpdateRequest(BaseModel):
role: str | None = None
use_memories: bool | None = None
memories: list[str] | None = None
user_preferences: str | None = Field(default=None, max_length=500)
class SlackBotCreationRequest(BaseModel):
@@ -341,9 +344,21 @@ class SlackBot(BaseModel):
name=slack_bot_model.name,
enabled=slack_bot_model.enabled,
configs_count=len(slack_bot_model.slack_channel_configs),
bot_token=slack_bot_model.bot_token,
app_token=slack_bot_model.app_token,
user_token=slack_bot_model.user_token,
bot_token=(
slack_bot_model.bot_token.get_value(apply_mask=True)
if slack_bot_model.bot_token
else ""
),
app_token=(
slack_bot_model.app_token.get_value(apply_mask=True)
if slack_bot_model.app_token
else ""
),
user_token=(
slack_bot_model.user_token.get_value(apply_mask=True)
if slack_bot_model.user_token
else None
),
)

View File

@@ -844,6 +844,11 @@ def update_user_personalization_api(
new_memories = (
request.memories if request.memories is not None else existing_memories
)
new_user_preferences = (
request.user_preferences
if request.user_preferences is not None
else user.user_preferences
)
update_user_personalization(
user.id,
@@ -851,6 +856,7 @@ def update_user_personalization_api(
personal_role=new_role,
use_memories=new_use_memories,
memories=new_memories,
user_preferences=new_user_preferences,
db_session=db_session,
)

View File

@@ -194,7 +194,7 @@ def test_search_provider(
status_code=400,
detail="No stored API key found for this provider type.",
)
api_key = existing_provider.api_key
api_key = existing_provider.api_key.get_value(apply_mask=False)
if provider_requires_api_key and not api_key:
raise HTTPException(
@@ -391,7 +391,7 @@ def test_content_provider(
detail="Base URL cannot differ from stored provider when using stored API key",
)
api_key = existing_provider.api_key
api_key = existing_provider.api_key.get_value(apply_mask=False)
if not api_key:
raise HTTPException(

View File

@@ -8,10 +8,6 @@ from uuid import UUID
from fastapi import HTTPException
from fastapi import status
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
@@ -45,42 +41,6 @@ def get_json_line(
return json.dumps(json_dict, cls=encoder) + "\n"
def mask_string(sensitive_str: str) -> str:
return "****...**" + sensitive_str[-4:]
MASK_CREDENTIALS_WHITELIST = {
DB_CREDENTIALS_AUTHENTICATION_METHOD,
"wiki_base",
"cloud_name",
"cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
masked_creds = {}
for key, val in credential_dict.items():
if isinstance(val, str):
# we want to pass the authentication_method field through so the frontend
# can disambiguate credentials created by different methods
if key in MASK_CREDENTIALS_WHITELIST:
masked_creds[key] = val
else:
masked_creds[key] = mask_string(val)
continue
if isinstance(val, int):
masked_creds[key] = "*****"
continue
raise ValueError(
f"Unable to mask credentials of type other than string or int, cannot process request."
f"Received type: {type(val)}"
)
return masked_creds
def make_short_id() -> str:
"""Fast way to generate a random 8 character id ... useful for tagging data
to trace it through a flow. This is definitely not guaranteed to be unique and is

View File

@@ -446,7 +446,7 @@ def run_research_agent_call(
tool_calls=tool_calls,
tools=current_tools,
message_history=msg_history,
memories=None,
user_memory_context=None,
user_info=None,
citation_mapping=citation_mapping,
next_citation_num=citation_processor.get_next_citation_number(),

View File

@@ -16,6 +16,7 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDoc
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.server.query_and_chat.placement import Placement
from onyx.server.query_and_chat.streaming_models import GeneratedImage
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
@@ -165,7 +166,7 @@ class SearchToolOverrideKwargs(BaseModel):
# without help and a specific custom prompt for this
original_query: str | None = None
message_history: list[ChatMinimalTextMessage] | None = None
memories: list[str] | None = None
user_memory_context: UserMemoryContext | None = None
user_info: str | None = None
# Used for tool calls after the first one but in the same chat turn. The reason for this is that if the initial pass through

View File

@@ -82,7 +82,11 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
model_provider=llm_provider.provider,
model_name=default_config.model_configuration.name,
temperature=GEN_AI_TEMPERATURE,
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -94,7 +94,11 @@ class ImageGenerationTool(Tool[None]):
llm_provider = config.model_configuration.llm_provider
credentials = ImageGenerationProviderCredentials(
api_key=llm_provider.api_key,
api_key=(
llm_provider.api_key.get_value(apply_mask=False)
if llm_provider.api_key
else None
),
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
deployment_name=llm_provider.deployment_name,

View File

@@ -142,8 +142,9 @@ class MCPTool(Tool[None]):
)
# Priority 2: Base headers from connection config (DB) - overrides request
if self.connection_config:
headers.update(self.connection_config.config.get("headers", {}))
if self.connection_config and self.connection_config.config:
config_dict = self.connection_config.config.get_value(apply_mask=False)
headers.update(config_dict.get("headers", {}))
# Priority 3: For pass-through OAuth, use the user's login OAuth token
if self._user_oauth_token:

View File

@@ -352,10 +352,17 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
if tenant_slack_bot:
bot_token = tenant_slack_bot.bot_token
access_token = (
tenant_slack_bot.user_token or tenant_slack_bot.bot_token
bot_token = (
tenant_slack_bot.bot_token.get_value(apply_mask=False)
if tenant_slack_bot.bot_token
else None
)
user_token = (
tenant_slack_bot.user_token.get_value(apply_mask=False)
if tenant_slack_bot.user_token
else None
)
access_token = user_token or bot_token
except Exception as e:
logger.warning(f"Could not fetch Slack bot tokens: {e}")
@@ -375,8 +382,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
None,
)
if slack_oauth_token:
access_token = slack_oauth_token.token
if slack_oauth_token and slack_oauth_token.token:
access_token = slack_oauth_token.token.get_value(
apply_mask=False
)
entities = slack_oauth_token.federated_connector.config or {}
except Exception as e:
logger.warning(f"Could not fetch Slack OAuth token: {e}")
@@ -550,7 +559,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if override_kwargs.message_history
else []
)
memories = override_kwargs.memories
memories = (
override_kwargs.user_memory_context.as_formatted_list()
if override_kwargs.user_memory_context
else []
)
user_info = override_kwargs.user_info
# Skip query expansion if this is a repeat search call

View File

@@ -77,7 +77,11 @@ def build_search_provider_from_config(
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
return build_search_provider_from_config(
provider_type=WebSearchProviderType(provider_model.provider_type),
api_key=provider_model.api_key or "",
api_key=(
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else ""
),
config=provider_model.config or {},
)
@@ -129,7 +133,11 @@ def get_default_content_provider() -> WebContentProvider:
if provider_model:
provider = build_content_provider_from_config(
provider_type=WebContentProviderType(provider_model.provider_type),
api_key=provider_model.api_key or "",
api_key=(
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else ""
),
config=provider_model.config or WebContentProviderConfig(),
)
if provider:

View File

@@ -69,7 +69,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
if provider_model is None:
raise RuntimeError("No web search provider configured.")
provider_type = WebSearchProviderType(provider_model.provider_type)
api_key = provider_model.api_key
api_key = (
provider_model.api_key.get_value(apply_mask=False)
if provider_model.api_key
else None
)
config = provider_model.config
# TODO - This should just be enforced at the DB level

View File

@@ -6,6 +6,7 @@ import onyx.tracing.framework._error_tracing as _error_tracing
from onyx.chat.models import ChatMessageSimple
from onyx.configs.constants import MessageType
from onyx.context.search.models import SearchDocsResponse
from onyx.db.memory import UserMemoryContext
from onyx.server.query_and_chat.streaming_models import Packet
from onyx.server.query_and_chat.streaming_models import PacketException
from onyx.server.query_and_chat.streaming_models import SectionEnd
@@ -220,7 +221,7 @@ def run_tool_calls(
tools: list[Tool],
# The stuff below is needed for the different individual built-in tools
message_history: list[ChatMessageSimple],
memories: list[str] | None,
user_memory_context: UserMemoryContext | None,
user_info: str | None,
citation_mapping: dict[int, str],
next_citation_num: int,
@@ -252,7 +253,7 @@ def run_tool_calls(
tools: List of available tool instances.
message_history: Chat message history (used to find the most recent user query
for `SearchTool` override kwargs).
memories: User memories, if available (passed through to `SearchTool`).
user_memory_context: User memory context, if available (passed through to `SearchTool`).
user_info: User information string, if available (passed through to `SearchTool`).
citation_mapping: Current citation number to URL mapping. May be updated with
new citations produced by search tools.
@@ -342,7 +343,7 @@ def run_tool_calls(
starting_citation_num=starting_citation_num,
original_query=last_user_message,
message_history=minimal_history,
memories=memories,
user_memory_context=user_memory_context,
user_info=user_info,
skip_query_expansion=skip_search_query_expansion,
)

View File

@@ -1,22 +1,91 @@
from typing import Any
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
from onyx.connectors.google_utils.shared_constants import (
DB_CREDENTIALS_AUTHENTICATION_METHOD,
)
from onyx.utils.logger import setup_logger
from onyx.utils.variable_functionality import fetch_versioned_implementation
logger = setup_logger()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _encrypt_string(input_str: str) -> bytes:
    """Encode a string to bytes for storage.

    MIT build: no actual encryption is performed. If an encryption key is
    configured, warn that it cannot be honored in this edition.
    """
    if ENCRYPTION_KEY_SECRET:
        # A key is configured but this build cannot use it — surface that fact.
        logger.warning("MIT version of Onyx does not support encryption of secrets.")
    return input_str.encode()
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
def _decrypt_bytes(input_bytes: bytes) -> str:
    """Decode stored bytes back into a string (no real decryption in MIT build)."""
    # No need to double warn. If you wish to learn more about encryption features
    # refer to the Onyx EE code
    return input_bytes.decode()
def mask_string(sensitive_str: str) -> str:
    """Masks a sensitive string, showing the first and last few characters.

    Strings too short to mask safely (fewer than 14 characters) are replaced
    entirely by a fixed-width placeholder so no portion of them leaks.
    """
    prefix_len = 4
    suffix_len = 4
    # Require at least this many hidden characters; otherwise showing
    # 8 characters of a short secret would reveal most of it.
    min_hidden = 6

    if len(sensitive_str) >= prefix_len + suffix_len + min_hidden:
        return f"{sensitive_str[:prefix_len]}...{sensitive_str[-suffix_len:]}"
    return "••••••••••••"
# Credential keys whose string values pass through mask_credential_dict
# unmasked. DB_CREDENTIALS_AUTHENTICATION_METHOD is needed verbatim so the
# frontend can disambiguate credentials created by different methods; the
# others are presumably non-secret identifiers — confirm before extending.
MASK_CREDENTIALS_WHITELIST = {
    DB_CREDENTIALS_AUTHENTICATION_METHOD,
    "wiki_base",
    "cloud_name",
    "cloud_id",
}
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, Any]:
    """Return a copy of a credential dict with sensitive values masked.

    Whitelisted keys pass through untouched; other strings are masked via
    mask_string; nested dicts and lists are masked recursively; booleans and
    None are kept as-is; everything else (ints, floats, unknown types)
    collapses to "*****".
    """
    masked: dict[str, Any] = {}
    for key, value in credential_dict.items():
        if isinstance(value, str):
            # we want to pass the authentication_method field through so the frontend
            # can disambiguate credentials created by different methods
            if key in MASK_CREDENTIALS_WHITELIST:
                masked[key] = value
            else:
                masked[key] = mask_string(value)
        elif isinstance(value, dict):
            masked[key] = mask_credential_dict(value)
        elif isinstance(value, list):
            masked[key] = _mask_list(value)
        elif isinstance(value, (bool, type(None))):
            # bool must be handled before the numeric fallback: bool is an
            # int subclass and carries no maskable content.
            masked[key] = value
        else:
            # ints, floats, and any unrecognized type are fully hidden.
            masked[key] = "*****"
    return masked
def _mask_list(items: list[Any]) -> list[Any]:
masked: list[Any] = []
for item in items:
if isinstance(item, dict):
masked.append(mask_credential_dict(item))
elif isinstance(item, str):
masked.append(mask_string(item))
elif isinstance(item, list):
masked.append(_mask_list(item))
elif isinstance(item, (bool, type(None))):
masked.append(item)
else:
masked.append("*****")
return masked
def encrypt_string_to_bytes(intput_str: str) -> bytes:
versioned_encryption_fn = fetch_versioned_implementation(
"onyx.utils.encryption", "_encrypt_string"

View File

@@ -0,0 +1,205 @@
"""
Wrapper class for sensitive values that require explicit masking decisions.
This module provides a wrapper for encrypted values that forces developers to
make an explicit decision about whether to mask the value when accessing it.
This prevents accidental exposure of sensitive data in API responses.
"""
from __future__ import annotations
import json
from collections.abc import Callable
from typing import Any
from typing import Generic
from typing import NoReturn
from typing import TypeVar
from unittest.mock import MagicMock
from onyx.utils.encryption import mask_credential_dict
from onyx.utils.encryption import mask_string
T = TypeVar("T", str, dict[str, Any])
def make_mock_sensitive_value(value: dict[str, Any] | str | None) -> MagicMock:
"""
Create a mock SensitiveValue for use in tests.
This helper makes it easy to create mock objects that behave like
SensitiveValue for testing code that uses credentials.
Args:
value: The value to return from get_value(). Can be a dict, string, or None.
Returns:
A MagicMock configured to behave like a SensitiveValue.
Example:
>>> mock_credential = MagicMock()
>>> mock_credential.credential_json = make_mock_sensitive_value({"api_key": "secret"})
>>> # Now mock_credential.credential_json.get_value(apply_mask=False) returns {"api_key": "secret"}
"""
if value is None:
return None # type: ignore[return-value]
mock = MagicMock(spec=SensitiveValue)
mock.get_value.return_value = value
mock.__bool__ = lambda self: True # noqa: ARG005
return mock
class SensitiveAccessError(Exception):
    """Raised when attempting to access a SensitiveValue without explicit masking decision.

    Triggered by implicit access paths — str()/iteration/subscripting,
    JSON or Pydantic serialization, comparison with non-SensitiveValue
    objects. Callers should use SensitiveValue.get_value(apply_mask=...)
    instead.
    """
class SensitiveValue(Generic[T]):
"""
Wrapper requiring explicit masking decisions for sensitive data.
This class wraps encrypted data and forces callers to make an explicit
decision about whether to mask the value when accessing it. This prevents
accidental exposure of sensitive data.
Usage:
# Get raw value (for internal use like connectors)
raw_value = sensitive.get_value(apply_mask=False)
# Get masked value (for API responses)
masked_value = sensitive.get_value(apply_mask=True)
Raises SensitiveAccessError when:
- Attempting to convert to string via str() or repr()
- Attempting to iterate over the value
- Attempting to subscript the value (e.g., value["key"])
- Attempting to serialize to JSON without explicit get_value()
"""
def __init__(
self,
*,
encrypted_bytes: bytes,
decrypt_fn: Callable[[bytes], str],
is_json: bool = False,
) -> None:
"""
Initialize a SensitiveValue wrapper.
Args:
encrypted_bytes: The encrypted bytes to wrap
decrypt_fn: Function to decrypt bytes to string
is_json: If True, the decrypted value is JSON and will be parsed to dict
"""
self._encrypted_bytes = encrypted_bytes
self._decrypt_fn = decrypt_fn
self._is_json = is_json
# Cache for decrypted value to avoid repeated decryption
self._decrypted_value: T | None = None
def _decrypt(self) -> T:
"""Lazily decrypt and cache the value."""
if self._decrypted_value is None:
decrypted_str = self._decrypt_fn(self._encrypted_bytes)
if self._is_json:
self._decrypted_value = json.loads(decrypted_str)
else:
self._decrypted_value = decrypted_str # type: ignore[assignment]
# The return type should always match T based on is_json flag
return self._decrypted_value # type: ignore[return-value]
def get_value(
self,
*,
apply_mask: bool,
mask_fn: Callable[[T], T] | None = None,
) -> T:
"""
Get the value with explicit masking decision.
Args:
apply_mask: Required. True = return masked value, False = return raw value
mask_fn: Optional custom masking function. Defaults to mask_string for
strings and mask_credential_dict for dicts.
Returns:
The value, either masked or raw depending on apply_mask.
"""
value = self._decrypt()
if not apply_mask:
return value
# Apply masking
if mask_fn is not None:
return mask_fn(value)
# Use default masking based on type
# Type narrowing doesn't work well here due to the generic T,
# but at runtime the types will match
if isinstance(value, dict):
return mask_credential_dict(value)
elif isinstance(value, str):
return mask_string(value)
else:
raise ValueError(f"Cannot mask value of type {type(value)}")
def __bool__(self) -> bool:
"""Allow truthiness checks without exposing the value."""
return True
def __str__(self) -> NoReturn:
"""Prevent accidental string conversion."""
raise SensitiveAccessError(
"Cannot convert SensitiveValue to string. "
"Use .get_value(apply_mask=True/False) to access the value."
)
def __repr__(self) -> str:
"""Prevent accidental repr exposure."""
return "<SensitiveValue: use .get_value(apply_mask=True/False) to access>"
def __iter__(self) -> NoReturn:
"""Prevent iteration over the value."""
raise SensitiveAccessError(
"Cannot iterate over SensitiveValue. "
"Use .get_value(apply_mask=True/False) to access the value."
)
def __getitem__(self, key: Any) -> NoReturn:
"""Prevent subscript access."""
raise SensitiveAccessError(
"Cannot subscript SensitiveValue. "
"Use .get_value(apply_mask=True/False) to access the value."
)
def __eq__(self, other: Any) -> bool:
    """Allow equality only between SensitiveValue instances.

    Comparing ciphertext avoids decrypting either side; comparing against
    any other type raises so the secret cannot be matched by probing.
    NOTE(review): equality via ciphertext assumes identical plaintexts
    encrypt to identical bytes — confirm the encryption is deterministic.
    """
    if not isinstance(other, SensitiveValue):
        raise SensitiveAccessError(
            "Cannot compare SensitiveValue with non-SensitiveValue. "
            "Use .get_value(apply_mask=True/False) to access the value for comparison."
        )
    return self._encrypted_bytes == other._encrypted_bytes
def __hash__(self) -> int:
"""Allow hashing based on encrypted bytes."""
return hash(self._encrypted_bytes)
# Prevent JSON serialization
def __json__(self) -> Any:
    """Refuse JSON serialization attempts (custom encoders that probe __json__)."""
    message = (
        "Cannot serialize SensitiveValue to JSON. "
        "Use .get_value(apply_mask=True/False) to access the value."
    )
    raise SensitiveAccessError(message)
# For Pydantic compatibility
@classmethod
def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
    """Refuse to participate in Pydantic schema generation.

    Raising here makes any model embedding a SensitiveValue fail loudly
    instead of silently serializing the secret.
    """
    message = (
        "Cannot serialize SensitiveValue in Pydantic model. "
        "Use .get_value(apply_mask=True/False) to access the value before serialization."
    )
    raise SensitiveAccessError(message)

View File

@@ -36,7 +36,7 @@ global_version = OnyxVersion()
# Eventually, ENABLE_PAID_ENTERPRISE_EDITION_FEATURES will be removed
# and license enforcement will be the only mechanism for EE features.
_LICENSE_ENFORCEMENT_ENABLED = (
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
)

View File

@@ -112,7 +112,10 @@ def test_gdrive_perm_sync_with_real_data(
mock_cc_pair.connector = MagicMock()
mock_cc_pair.connector.connector_specific_config = {}
mock_cc_pair.credential_id = 1
mock_cc_pair.credential.credential_json = {}
# Import and use the mock helper
from onyx.utils.sensitive import make_mock_sensitive_value
mock_cc_pair.credential.credential_json = make_mock_sensitive_value({})
mock_cc_pair.last_time_perm_sync = None
mock_cc_pair.last_time_external_group_sync = None

View File

@@ -43,6 +43,8 @@ def ensure_default_llm_provider(db_session: Session) -> None:
)
update_default_provider(provider.id, db_session)
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
# Rollback to clear the pending transaction state
db_session.rollback()
print(f"Note: Could not create LLM provider: {exc}")

View File

@@ -129,6 +129,8 @@ def test_confluence_group_sync(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -53,6 +53,8 @@ def _create_test_connector_credential_pair(
)
db_session.add(credential)
db_session.flush() # To get the credential ID
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -80,6 +80,8 @@ def test_jira_doc_sync(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,
@@ -176,6 +178,8 @@ def test_jira_doc_sync_with_specific_permissions(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -114,6 +114,8 @@ def test_jira_group_sync(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -57,6 +57,8 @@ def _create_test_connector_credential_pair(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -67,6 +67,8 @@ def _create_test_connector_credential_pair(
)
db_session.add(credential)
db_session.flush()
# Expire the credential so it reloads from DB with SensitiveValue wrapper
db_session.expire(credential)
cc_pair = ConnectorCredentialPair(
connector_id=connector.id,

View File

@@ -254,6 +254,8 @@ class TestSlackBotFederatedSearch:
)
db_session.add(federated_connector)
db_session.flush()
# Expire to ensure credentials is reloaded as SensitiveValue from DB
db_session.expire(federated_connector)
# Associate the federated connector with the persona's document sets
# This is required for Slack federated search to be enabled
@@ -276,6 +278,8 @@ class TestSlackBotFederatedSearch:
)
db_session.add(slack_bot)
db_session.flush()
# Expire to ensure tokens are reloaded as SensitiveValue from DB
db_session.expire(slack_bot)
slack_channel_config = SlackChannelConfig(
slack_bot_id=slack_bot.id,

View File

@@ -211,9 +211,11 @@ class TestOAuthTokenManagerRefresh:
# Verify token was updated in DB
db_session.refresh(user_token)
assert user_token.token_data["access_token"] == "new_token"
assert user_token.token_data["refresh_token"] == "new_refresh"
assert "expires_at" in user_token.token_data
assert user_token.token_data is not None
token_data = user_token.token_data.get_value(apply_mask=False)
assert token_data["access_token"] == "new_token"
assert token_data["refresh_token"] == "new_refresh"
assert "expires_at" in token_data
@patch("onyx.auth.oauth_token_manager.requests.post")
def test_refresh_token_preserves_refresh_token(
@@ -249,7 +251,9 @@ class TestOAuthTokenManagerRefresh:
# Verify old refresh_token was preserved
db_session.refresh(user_token)
assert user_token.token_data["refresh_token"] == "old_refresh"
assert user_token.token_data is not None
token_data = user_token.token_data.get_value(apply_mask=False)
assert token_data["refresh_token"] == "old_refresh"
@patch("onyx.auth.oauth_token_manager.requests.post")
def test_refresh_token_http_error(

View File

@@ -17,7 +17,8 @@ COPY ./tests/* /app/tests/
FROM base AS openapi-schema
COPY ./scripts/onyx_openapi_schema.py /app/scripts/onyx_openapi_schema.py
RUN python scripts/onyx_openapi_schema.py --filename openapi.json
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
RUN LICENSE_ENFORCEMENT_ENABLED=false python scripts/onyx_openapi_schema.py --filename openapi.json
FROM openapitools/openapi-generator-cli:latest AS openapi-client
WORKDIR /local

View File

@@ -40,6 +40,8 @@ def test_github_private_repo_permission_sync(
) = github_test_env_setup
# Create GitHub client from credential
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
# so credential_json is already a plain dict
github_access_token = github_credential.credential_json["github_access_token"]
github_client = Github(github_access_token)
github_manager = GitHubManager(github_client)
@@ -158,6 +160,8 @@ def test_github_public_repo_permission_sync(
) = github_test_env_setup
# Create GitHub client from credential
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
# so credential_json is already a plain dict
github_access_token = github_credential.credential_json["github_access_token"]
github_client = Github(github_access_token)
github_manager = GitHubManager(github_client)
@@ -262,6 +266,8 @@ def test_github_internal_repo_permission_sync(
) = github_test_env_setup
# Create GitHub client from credential
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
# so credential_json is already a plain dict
github_access_token = github_credential.credential_json["github_access_token"]
github_client = Github(github_access_token)
github_manager = GitHubManager(github_client)

View File

@@ -1,21 +0,0 @@
import os
import pytest
from tests.integration.common_utils.managers.user import UserManager
from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.test_models import DATestUser
@pytest.mark.skipif(
os.getenv("PYTEST_IGNORE_SKIP") is None,
reason="Skipped by default unless env var exists",
)
def test_playwright_setup() -> None:
"""Not really a test, just using this to automate setup for playwright tests."""
if not os.getenv("PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET", "").lower() == "true":
reset_all()
# Creating an admin user (first user created is automatically an admin)
admin_user: DATestUser = UserManager.create(name="admin_user")
assert admin_user

View File

@@ -7,6 +7,7 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
from onyx.connectors.jira.connector import JiraConnector
from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
from onyx.db.models import ConnectorCredentialPair
from onyx.utils.sensitive import make_mock_sensitive_value
@pytest.fixture
@@ -18,10 +19,12 @@ def mock_jira_cc_pair(
) -> MagicMock:
mock_cc_pair = MagicMock(spec=ConnectorCredentialPair)
mock_cc_pair.connector = MagicMock()
mock_cc_pair.credential.credential_json = {
"jira_user_email": user_email,
"jira_api_token": mock_jira_api_token,
}
mock_cc_pair.credential.credential_json = make_mock_sensitive_value(
{
"jira_user_email": user_email,
"jira_api_token": mock_jira_api_token,
}
)
mock_cc_pair.connector.connector_specific_config = {
"jira_base_url": jira_base_url,
"project_key": project_key,

View File

@@ -247,11 +247,13 @@ class TestInstantiateConnectorIntegration:
def test_instantiate_connector_loads_class_lazily(self) -> None:
"""Test that instantiate_connector triggers lazy loading."""
from onyx.utils.sensitive import make_mock_sensitive_value
# Mock the database session and credential
mock_session = MagicMock()
mock_credential = MagicMock()
mock_credential.id = 123
mock_credential.credential_json = {"test": "data"}
mock_credential.credential_json = make_mock_sensitive_value({"test": "data"})
# This should trigger lazy loading but will fail on actual instantiation
# due to missing real configuration - that's expected

View File

@@ -1,3 +1,7 @@
import os
import threading
import time
from typing import Any
from unittest.mock import ANY
from unittest.mock import patch
@@ -828,3 +832,274 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
mock_completion.assert_called_once()
kwargs = mock_completion.call_args.kwargs
assert isinstance(kwargs["client"], HTTPHandler)
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify custom_config env vars are live only during the LLM call.

    Pre-existing env vars must be restored afterwards and keys introduced
    solely by custom_config must be removed.
    """
    # Assign some environment variables
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }
    # Overlaps EXPECTED_ENV_VARS on two keys and adds one brand-new key.
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }
    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)
    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )
    # When custom_config is set, invoke() internally uses stream=True and
    # reassembles via stream_chunk_builder, so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hello"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model="gpt-3.5-turbo",
        ),
    ]

    def on_litellm_completion(
        **kwargs: dict[str, Any],  # noqa: ARG001
    ) -> list[litellm.ModelResponse]:
        # Validate that the environment variables are those in custom config
        # (i.e. overrides are in effect exactly while litellm.completion runs).
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value
        return mock_stream_chunks

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
        llm.invoke(messages, user_identity=identity)

        mock_completion.assert_called_once()
        kwargs = mock_completion.call_args.kwargs
        # Streaming path is forced by custom_config; user metadata is patched off.
        assert kwargs["stream"] is True
        assert "user" not in kwargs
        assert kwargs["metadata"]["foo"] == "bar"

    # Check that the environment variables are back to the original values
    for env_var, value in EXPECTED_ENV_VARS.items():
        assert env_var in os.environ
        assert os.environ[env_var] == value
    # Check that temporary env var from CUSTOM_CONFIG is no longer set
    assert "THIS_IS_RANDOM" not in os.environ
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify env vars are restored even when an exception occurs during LLM invocation."""
    # Assign some environment variables
    EXPECTED_ENV_VARS = {
        "TEST_ENV_VAR": "test_value",
        "ANOTHER_ONE": "1",
        "THIRD_ONE": "2",
    }
    # Overlaps EXPECTED_ENV_VARS on two keys and adds one brand-new key.
    CUSTOM_CONFIG = {
        "TEST_ENV_VAR": "fdsfsdf",
        "ANOTHER_ONE": "3",
        "THIS_IS_RANDOM": "123213",
    }
    for env_var, value in EXPECTED_ENV_VARS.items():
        monkeypatch.setenv(env_var, value)
    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"
    llm = LitellmLLM(
        api_key="test_key",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        model_kwargs={"metadata": {"foo": "bar"}},
        custom_config=CUSTOM_CONFIG,
    )

    def on_litellm_completion_raises(**kwargs: dict[str, Any]) -> None:  # noqa: ARG001
        # Validate that the environment variables are those in custom config
        for env_var, value in CUSTOM_CONFIG.items():
            assert env_var in os.environ
            assert os.environ[env_var] == value
        # Simulate an error during LLM call
        raise RuntimeError("Simulated LLM API failure")

    with (
        patch("litellm.completion") as mock_completion,
        patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
    ):
        mock_completion.side_effect = on_litellm_completion_raises
        messages: LanguageModelInput = [UserMessage(content="Hi")]
        identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
        # The error must propagate, but cleanup must still run.
        with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
            llm.invoke(messages, user_identity=identity)
        mock_completion.assert_called_once()

    # Check that the environment variables are back to the original values
    for env_var, value in EXPECTED_ENV_VARS.items():
        assert env_var in os.environ
        assert os.environ[env_var] == value
    # Check that temporary env var from CUSTOM_CONFIG is no longer set
    assert "THIS_IS_RANDOM" not in os.environ
def test_multithreaded_custom_config_isolation(monkeypatch: pytest.MonkeyPatch) -> None:
    """Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.

    Two LitellmLLM instances with different custom_config dicts invoke concurrently.
    The _env_lock in temporary_env_and_lock serializes their access so each call only
    ever sees its own env vars—never the other's.
    """
    # Ensure these keys start unset
    monkeypatch.delenv("SHARED_KEY", raising=False)
    monkeypatch.delenv("LLM_A_ONLY", raising=False)
    monkeypatch.delenv("LLM_B_ONLY", raising=False)
    CONFIG_A = {
        "SHARED_KEY": "value_from_A",
        "LLM_A_ONLY": "a_secret",
    }
    CONFIG_B = {
        "SHARED_KEY": "value_from_B",
        "LLM_B_ONLY": "b_secret",
    }
    all_env_keys = list(set(list(CONFIG_A.keys()) + list(CONFIG_B.keys())))
    model_provider = LlmProviderNames.OPENAI
    model_name = "gpt-3.5-turbo"
    llm_a = LitellmLLM(
        api_key="key_a",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_A,
    )
    llm_b = LitellmLLM(
        api_key="key_b",
        timeout=30,
        model_provider=model_provider,
        model_name=model_name,
        max_input_tokens=get_max_input_tokens(
            model_provider=model_provider,
            model_name=model_name,
        ),
        custom_config=CONFIG_B,
    )
    # invoke() uses stream=True internally when custom_config is set,
    # so the mock must return stream chunks.
    mock_stream_chunks = [
        litellm.ModelResponse(
            id="chatcmpl-123",
            choices=[
                litellm.Choices(
                    delta=_create_delta(role="assistant", content="Hi"),
                    finish_reason="stop",
                    index=0,
                )
            ],
            model=model_name,
        ),
    ]
    # Track what each call observed inside litellm.completion.
    # Keyed by api_key so we can identify which LLM instance made the call.
    observed_envs: dict[str, dict[str, str | None]] = {}

    def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
        time.sleep(0.1)  # We expect someone to get caught on the lock
        api_key = kwargs.get("api_key", "")
        label = "A" if api_key == "key_a" else "B"
        # Snapshot every key of interest so cross-contamination is detectable.
        snapshot: dict[str, str | None] = {}
        for key in all_env_keys:
            snapshot[key] = os.environ.get(key)
        observed_envs[label] = snapshot
        return mock_stream_chunks

    errors: list[Exception] = []

    def run_llm(llm: LitellmLLM) -> None:
        # Collect exceptions instead of raising inside the thread, so the
        # main thread can report them after join().
        try:
            messages: LanguageModelInput = [UserMessage(content="Hi")]
            llm.invoke(messages)
        except Exception as e:
            errors.append(e)

    with patch("litellm.completion", side_effect=fake_completion):
        t_a = threading.Thread(target=run_llm, args=(llm_a,))
        t_b = threading.Thread(target=run_llm, args=(llm_b,))
        t_a.start()
        t_b.start()
        t_a.join(timeout=10)
        t_b.join(timeout=10)

    assert not errors, f"Thread errors: {errors}"
    assert "A" in observed_envs and "B" in observed_envs
    # Thread A must have seen its own config for SHARED_KEY, not B's
    assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
    assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
    # A must NOT see B's exclusive key
    assert observed_envs["A"]["LLM_B_ONLY"] is None
    # Thread B must have seen its own config for SHARED_KEY, not A's
    assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
    assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
    # B must NOT see A's exclusive key
    assert observed_envs["B"]["LLM_A_ONLY"] is None
    # After both calls, env should be clean
    assert os.environ.get("SHARED_KEY") is None
    assert os.environ.get("LLM_A_ONLY") is None
    assert os.environ.get("LLM_B_ONLY") is None

View File

@@ -73,6 +73,10 @@ class TestGetOllamaAvailableModels:
# Check display names are generated
assert any("Llama" in r.display_name for r in results)
assert any("Mistral" in r.display_name for r in results)
# Results should be alphabetically sorted by model name
assert [r.name for r in results] == sorted(
[r.name for r in results], key=str.lower
)
def test_syncs_to_db_when_provider_name_specified(
self, mock_ollama_tags_response: dict, mock_ollama_show_response: dict

View File

@@ -0,0 +1,239 @@
"""Tests for SensitiveValue wrapper class."""
import json
from typing import Any
import pytest
from onyx.utils.sensitive import SensitiveAccessError
from onyx.utils.sensitive import SensitiveValue
def _encrypt_string(value: str) -> bytes:
"""Simple mock encryption (just encoding for tests)."""
return value.encode("utf-8")
def _decrypt_string(value: bytes) -> str:
"""Simple mock decryption (just decoding for tests)."""
return value.decode("utf-8")
class TestSensitiveValueString:
    """Tests for SensitiveValue wrapping plain string secrets."""

    @staticmethod
    def _wrap(secret: str) -> "SensitiveValue[str]":
        """Build a string-mode SensitiveValue via the mock crypto helpers."""
        return SensitiveValue(
            encrypted_bytes=_encrypt_string(secret),
            decrypt_fn=_decrypt_string,
            is_json=False,
        )

    def test_get_value_raw(self) -> None:
        """Raw access returns the exact original string."""
        wrapped = self._wrap("my-secret-token")
        assert wrapped.get_value(apply_mask=False) == "my-secret-token"

    def test_get_value_masked(self) -> None:
        """Default masking keeps the first 4 and last 4 characters."""
        wrapped = self._wrap("my-very-long-secret-token-here")
        assert wrapped.get_value(apply_mask=True) == "my-v...here"

    def test_get_value_masked_short_string(self) -> None:
        """Strings too short to partially reveal are fully masked."""
        wrapped = self._wrap("short")
        assert wrapped.get_value(apply_mask=True) == "••••••••••••"

    def test_get_value_custom_mask_fn(self) -> None:
        """A caller-supplied mask_fn overrides the default masking."""
        wrapped = self._wrap("secret")
        masked = wrapped.get_value(
            apply_mask=True,
            mask_fn=lambda x: "REDACTED",  # noqa: ARG005
        )
        assert masked == "REDACTED"

    def test_str_raises_error(self) -> None:
        """str() must refuse to expose the value."""
        with pytest.raises(SensitiveAccessError):
            str(self._wrap("secret"))

    def test_repr_is_safe(self) -> None:
        """repr() shows guidance text, never the secret itself."""
        text = repr(self._wrap("secret"))
        assert "secret" not in text
        assert "SensitiveValue" in text
        assert "get_value" in text

    def test_iter_raises_error(self) -> None:
        """Iterating a wrapped value is blocked."""
        with pytest.raises(SensitiveAccessError):
            for _ in self._wrap("secret"):  # type: ignore[attr-defined]
                pass

    def test_getitem_raises_error(self) -> None:
        """Subscript access on a wrapped value is blocked."""
        with pytest.raises(SensitiveAccessError):
            _ = self._wrap("secret")[0]

    def test_bool_returns_true(self) -> None:
        """Truthiness checks work without decryption."""
        assert bool(self._wrap("secret")) is True

    def test_equality_with_same_value(self) -> None:
        """Two wrappers over identical ciphertext compare equal."""
        ciphertext = _encrypt_string("secret")
        first = SensitiveValue(
            encrypted_bytes=ciphertext,
            decrypt_fn=_decrypt_string,
            is_json=False,
        )
        second = SensitiveValue(
            encrypted_bytes=ciphertext,
            decrypt_fn=_decrypt_string,
            is_json=False,
        )
        assert first == second

    def test_equality_with_different_value(self) -> None:
        """Wrappers over different ciphertext compare unequal."""
        assert self._wrap("secret1") != self._wrap("secret2")

    def test_equality_with_non_sensitive_raises(self) -> None:
        """Comparing against a plain value raises instead of leaking."""
        with pytest.raises(SensitiveAccessError):
            _ = self._wrap("secret") == "secret"
class TestSensitiveValueJson:
    """Tests for SensitiveValue wrapping JSON-encoded dict secrets."""

    @staticmethod
    def _wrap(data: dict[str, Any]) -> "SensitiveValue[dict[str, Any]]":
        """Build a JSON-mode SensitiveValue via the mock crypto helpers."""
        return SensitiveValue(
            encrypted_bytes=_encrypt_string(json.dumps(data)),
            decrypt_fn=_decrypt_string,
            is_json=True,
        )

    def test_get_value_raw_dict(self) -> None:
        """Raw access round-trips the original dict exactly."""
        data: dict[str, Any] = {"api_key": "secret-key", "username": "user123"}
        assert self._wrap(data).get_value(apply_mask=False) == data

    def test_get_value_masked_dict(self) -> None:
        """Default masking hides every sensitive value in the dict."""
        data = {"api_key": "my-very-long-api-key-value", "username": "user123456789"}
        masked = self._wrap(data).get_value(apply_mask=True)
        assert "my-very-long-api-key-value" not in str(masked)
        assert "user123456789" not in str(masked)

    def test_getitem_raises_error_for_dict(self) -> None:
        """Subscripting the wrapper by key is blocked."""
        with pytest.raises(SensitiveAccessError):
            _ = self._wrap({"api_key": "secret"})["api_key"]

    def test_iter_raises_error_for_dict(self) -> None:
        """Iterating the wrapper (e.g. over keys) is blocked."""
        with pytest.raises(SensitiveAccessError):
            for _ in self._wrap({"api_key": "secret"}):  # type: ignore[attr-defined]
                pass
class TestSensitiveValueCaching:
    """Tests for lazy, cached decryption."""

    def test_decryption_is_cached(self) -> None:
        """The decrypt function runs exactly once across repeated accesses."""
        calls = {"count": 0}

        def counting_decrypt(value: bytes) -> str:
            calls["count"] += 1
            return value.decode("utf-8")

        sensitive = SensitiveValue(
            encrypted_bytes=_encrypt_string("secret"),
            decrypt_fn=counting_decrypt,
            is_json=False,
        )
        # First access performs the one and only decryption.
        sensitive.get_value(apply_mask=False)
        assert calls["count"] == 1
        # A second raw read reuses the cached plaintext.
        sensitive.get_value(apply_mask=False)
        assert calls["count"] == 1
        # Masked access also reuses the cached plaintext.
        sensitive.get_value(apply_mask=True)
        assert calls["count"] == 1

View File

@@ -0,0 +1,78 @@
"""
Tests demonstrating static type checking for SensitiveValue.
Run with: mypy tests/unit/onyx/utils/test_sensitive_typing.py --ignore-missing-imports
These tests show what mypy will catch when SensitiveValue is misused.
"""
from typing import Any
# This file demonstrates what mypy will catch.
# The commented-out code below would produce type errors.
def demonstrate_correct_usage() -> None:
    """Demonstrate SensitiveValue access patterns that pass type checking."""
    from onyx.utils.sensitive import SensitiveValue
    from onyx.utils.encryption import encrypt_string_to_bytes, decrypt_bytes_to_string

    # Create a SensitiveValue around a JSON payload.
    sensitive: SensitiveValue[dict[str, Any]] = SensitiveValue(
        encrypted_bytes=encrypt_string_to_bytes('{"api_key": "secret"}'),
        decrypt_fn=decrypt_bytes_to_string,
        is_json=True,
    )

    # CORRECT: explicit unmasked access yields the underlying dict.
    raw_dict: dict[str, Any] = sensitive.get_value(apply_mask=False)
    assert raw_dict["api_key"] == "secret"

    # CORRECT: masked access never contains the plaintext secret.
    masked_dict: dict[str, Any] = sensitive.get_value(apply_mask=True)
    assert "secret" not in str(masked_dict)

    # CORRECT: truthiness checks are allowed without decryption.
    if sensitive:
        print("Value exists")
# The code below demonstrates what mypy would catch.
# Uncomment to see the type errors.
"""
def demonstrate_incorrect_usage() -> None:
'''Shows patterns that mypy will flag as errors.'''
from onyx.utils.sensitive import SensitiveValue
from onyx.utils.encryption import encrypt_string_to_bytes, decrypt_bytes_to_string
encrypted = encrypt_string_to_bytes('{"api_key": "secret"}')
sensitive: SensitiveValue[dict[str, Any]] = SensitiveValue(
encrypted_bytes=encrypted,
decrypt_fn=decrypt_bytes_to_string,
is_json=True,
)
# ERROR: SensitiveValue doesn't support subscript access
# mypy error: Value of type "SensitiveValue[dict[str, Any]]" is not indexable
api_key = sensitive["api_key"]
# ERROR: SensitiveValue doesn't support iteration
# mypy error: "SensitiveValue[dict[str, Any]]" has no attribute "__iter__"
for key in sensitive:
print(key)
# ERROR: Can't pass SensitiveValue where dict is expected
# mypy error: Argument 1 has incompatible type "SensitiveValue[dict[str, Any]]"; expected "dict[str, Any]"
def process_dict(d: dict[str, Any]) -> None:
pass
process_dict(sensitive)
# ERROR: Can't use .get() on SensitiveValue
# mypy error: "SensitiveValue[dict[str, Any]]" has no attribute "get"
value = sensitive.get("api_key")
"""
def test_correct_usage_passes() -> None:
    """Run the correct-usage demonstration end to end; it must not raise."""
    demonstrate_correct_usage()

View File

@@ -29,6 +29,8 @@ services:
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
- ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
- LICENSE_ENFORCEMENT_ENABLED=false
# MinIO configuration
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
@@ -66,6 +68,8 @@ services:
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
- ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-}
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
- LICENSE_ENFORCEMENT_ENABLED=false
# MinIO configuration
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}

View File

@@ -65,22 +65,7 @@ Bring up the entire application.
npx playwright install
```
1. Reset the instance
```cd backend
export PYTEST_IGNORE_SKIP=true
pytest -s tests/integration/tests/playwright/test_playwright.py
```
If you don't want to reset your local instance, you can still run playwright tests
with SKIP_AUTH=true. This is convenient but slightly different from what happens
in CI so tests might pass locally and fail in CI.
```cd web
SKIP_AUTH=true npx playwright test create_and_edit_assistant.spec.ts --project=admin
```
2. Run playwright
1. Run playwright
```
cd web
@@ -101,7 +86,7 @@ npx playwright test --ui
npx playwright test --headed
```
3. Inspect results
2. Inspect results
By default, playwright.config.ts is configured to output the results to:
@@ -109,7 +94,7 @@ By default, playwright.config.ts is configured to output the results to:
web/test-results
```
4. Upload results to Chromatic (Optional)
3. Upload results to Chromatic (Optional)
This step would normally not be run by third party developers, but first party devs
may use this for local troubleshooting and testing.

View File

@@ -1,6 +1,6 @@
# Opal Components
High-level UI components built on the [`@opal/core`](../core/) primitives. Every component in this directory delegates state styling (hover, active, disabled, selected) to `Interactive.Base` via CSS data-attributes and the `--interactive-foreground` custom property — no duplicated Tailwind class maps.
High-level UI components built on the [`@opal/core`](../core/) primitives. Every component in this directory delegates state styling (hover, active, disabled, transient) to `Interactive.Base` via CSS data-attributes and the `--interactive-foreground` custom property — no duplicated Tailwind class maps.
## Package export

View File

@@ -7,34 +7,36 @@ A single component that handles both labeled buttons and icon-only buttons. It r
## Architecture
```
Interactive.Base <- variant/subvariant, selected, disabled, href, onClick
└─ Interactive.Container <- height, rounding, padding (derived from `size`)
Interactive.Base <- variant/subvariant, transient, disabled, href, onClick
└─ Interactive.Container <- height, rounding, padding (derived from `size`), border (auto for secondary)
└─ div.opal-button.interactive-foreground <- flexbox row layout
├─ Icon? .opal-button-icon (1rem x 1rem, shrink-0)
├─ <span>? .opal-button-label (whitespace-nowrap, font)
└─ RightIcon? .opal-button-icon
├─ div.p-0.5 > Icon? (compact: 12px, default: 16px, shrink-0)
├─ <span>? .opal-button-label (whitespace-nowrap, font)
└─ div.p-0.5 > RightIcon? (compact: 12px, default: 16px, shrink-0)
```
- **Colors are not in the Button.** `Interactive.Base` sets `background-color` and `--interactive-foreground` per variant/subvariant/state. The `.interactive-foreground` utility class on the content div sets `color: var(--interactive-foreground)`, which both the `<span>` text and `stroke="currentColor"` SVG icons inherit automatically.
- **Layout is in `styles.css`.** The CSS classes (`.opal-button`, `.opal-button-icon`, `.opal-button-label`) handle flexbox alignment, gap, icon sizing, and text styling. A `[data-size="compact"]` selector tightens the gap and reduces font size.
- **Layout is in `styles.css`.** The CSS classes (`.opal-button`, `.opal-button-label`) handle flexbox alignment, gap, and text styling. Default labels use `font-main-ui-action` (14px/600); compact labels use `font-secondary-action` (12px/600) via a `[data-size="compact"]` selector.
- **Sizing is delegated to `Interactive.Container` presets.** The `size` prop maps to Container height/rounding/padding presets:
- `"default"` -> height 2.25rem, rounding 12px, padding 8px
- `"compact"` -> height 1.75rem, rounding 8px, padding 4px
- **Icon-only buttons render as squares** because `Interactive.Container` enforces `min-width >= height` for every height preset.
- **Border is automatic for `subvariant="secondary"`.** The Container receives `border={subvariant === "secondary"}` internally — there is no external `border` prop.
## Props
| Prop | Type | Default | Description |
|------|------|---------|-------------|
| `variant` | `"default" \| "action" \| "danger" \| "none" \| "select"` | `"default"` | Top-level color variant (maps to `Interactive.Base`) |
| `subvariant` | Depends on `variant` | `"primary"` | Color subvariant -- e.g. `"primary"`, `"secondary"`, `"ghost"` for default/action/danger |
| `subvariant` | Depends on `variant` | `"primary"` | Color subvariant -- e.g. `"primary"`, `"secondary"`, `"ghost"` for default/action/danger. `"secondary"` automatically renders a border. |
| `icon` | `IconFunctionComponent` | -- | Left icon component |
| `children` | `string` | -- | Button label text. Omit for icon-only buttons |
| `rightIcon` | `IconFunctionComponent` | -- | Right icon component |
| `size` | `SizeVariant` | `"default"` | Size preset controlling height, rounding, padding, gap, and font size |
| `size` | `SizeVariant` | `"default"` | Size preset controlling height, rounding, padding, icon size, and font style |
| `tooltip` | `string` | -- | Tooltip text shown on hover |
| `tooltipSide` | `TooltipSide` | `"top"` | Which side the tooltip appears on |
| `selected` | `boolean` | `false` | Forces the selected visual state (data-selected) |
| `selected` | `boolean` | `false` | Switches foreground to action-link colours (only available with `variant="select"`) |
| `transient` | `boolean` | `false` | Forces the transient (hover) visual state (data-transient) |
| `disabled` | `boolean` | `false` | Disables the button (data-disabled, aria-disabled) |
| `href` | `string` | -- | URL; renders an `<a>` wrapper instead of Radix Slot |
| `onClick` | `MouseEventHandler<HTMLElement>` | -- | Click handler |
@@ -59,7 +61,7 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
Add item
</Button>
// Labeled button with right icon
// Secondary button (automatically renders a border)
<Button rightIcon={SvgArrowRight} variant="default" subvariant="secondary">
Continue
</Button>
@@ -74,8 +76,8 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
Settings
</Button>
// Selected state (e.g. inside a popover trigger)
<Button icon={SvgFilter} subvariant="ghost" selected={isOpen} />
// Transient state (e.g. inside a popover trigger)
<Button icon={SvgFilter} subvariant="ghost" transient={isOpen} />
// With tooltip
<Button icon={SvgPlus} subvariant="ghost" tooltip="Add item" />
@@ -91,7 +93,7 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
| `primary` | `subvariant="primary"` (default, can be omitted) |
| `secondary` | `subvariant="secondary"` |
| `tertiary` | `subvariant="ghost"` |
| `transient={x}` | `selected={x}` |
| `transient={x}` | `transient={x}` |
| `size="md"` | `size="compact"` |
| `size="lg"` | `size="default"` (default, can be omitted) |
| `leftIcon={X}` | `icon={X}` |

View File

@@ -4,6 +4,7 @@ import { Interactive, type InteractiveBaseProps } from "@opal/core";
import type { SizeVariant, TooltipSide } from "@opal/components";
import type { IconFunctionComponent } from "@opal/types";
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
import { cn } from "@opal/utils";
// ---------------------------------------------------------------------------
// Types
@@ -29,6 +30,25 @@ type ButtonProps = InteractiveBaseProps & {
tooltipSide?: TooltipSide;
};
// Renders a padded, size-matched icon slot for Button; when no icon is
// supplied, emits a tiny spacer div so horizontal rhythm stays consistent.
function iconWrapper(
  Icon: IconFunctionComponent | undefined,
  isCompact: boolean
) {
  // No icon: keep a 2px placeholder in place of the icon cell.
  if (!Icon) {
    return <div className="w-[0.125rem]" />;
  }

  // Compact buttons use 12px icons; default buttons use 16px.
  const sizeClass = isCompact
    ? "h-[0.75rem] w-[0.75rem]"
    : "h-[1rem] w-[1rem]";

  return (
    <div className="p-0.5">
      <Icon
        className={cn("shrink-0", sizeClass)}
        size={isCompact ? 12 : 16}
      />
    </div>
  );
}
// ---------------------------------------------------------------------------
// Button
// ---------------------------------------------------------------------------
@@ -40,29 +60,31 @@ function Button({
size = "default",
tooltip,
tooltipSide = "top",
variant,
subvariant,
...baseProps
...interactiveBaseProps
}: ButtonProps) {
const isCompact = size === "compact";
const button = (
<Interactive.Base
{...({ variant, subvariant } as InteractiveBaseProps)}
{...baseProps}
>
<Interactive.Base {...interactiveBaseProps}>
<Interactive.Container
border={interactiveBaseProps.subvariant === "secondary"}
heightVariant={isCompact ? "compact" : "default"}
roundingVariant={isCompact ? "compact" : "default"}
paddingVariant={isCompact ? "thin" : "default"}
>
<div
className="opal-button interactive-foreground"
data-size={isCompact ? "compact" : undefined}
>
{Icon && <Icon className="opal-button-icon" />}
{children && <span className="opal-button-label">{children}</span>}
{RightIcon && <RightIcon className="opal-button-icon" />}
<div className="opal-button interactive-foreground">
{iconWrapper(Icon, isCompact)}
{children && (
<span
className={cn(
"opal-button-label",
isCompact ? "font-secondary-action" : "font-main-ui-action"
)}
>
{children}
</span>
)}
{iconWrapper(RightIcon, isCompact)}
</div>
</Interactive.Container>
</Interactive.Base>

Some files were not shown because too many files have changed in this diff Show More