mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-03-03 14:45:46 +00:00
Compare commits
11 Commits
v2.12.2
...
litellm_co
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5115b621c8 | ||
|
|
a924b49405 | ||
|
|
2d2d998811 | ||
|
|
0925b5fbd4 | ||
|
|
a02d8414ee | ||
|
|
c8abc4a115 | ||
|
|
cec37bff6a | ||
|
|
06d5d3971b | ||
|
|
ed287a2fc0 | ||
|
|
60857d1e73 | ||
|
|
bb5c22104e |
3
.github/workflows/pr-database-tests.yml
vendored
3
.github/workflows/pr-database-tests.yml
vendored
@@ -40,6 +40,9 @@ jobs:
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
env:
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
|
||||
3
.github/workflows/pr-integration-tests.yml
vendored
3
.github/workflows/pr-integration-tests.yml
vendored
@@ -302,6 +302,8 @@ jobs:
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
POSTGRES_POOL_PRE_PING=true
|
||||
POSTGRES_USE_NULL_POOL=true
|
||||
@@ -478,6 +480,7 @@ jobs:
|
||||
run: |
|
||||
cd deployment/docker_compose
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true \
|
||||
LICENSE_ENFORCEMENT_ENABLED=false \
|
||||
MULTI_TENANT=true \
|
||||
AUTH_TYPE=cloud \
|
||||
REQUIRE_EMAIL_VERIFICATION=false \
|
||||
|
||||
2
.github/workflows/pr-playwright-tests.yml
vendored
2
.github/workflows/pr-playwright-tests.yml
vendored
@@ -291,6 +291,8 @@ jobs:
|
||||
cat <<EOF > deployment/docker_compose/.env
|
||||
COMPOSE_PROFILES=s3-filestore
|
||||
ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=true
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED=false
|
||||
AUTH_TYPE=basic
|
||||
GEN_AI_API_KEY=${OPENAI_API_KEY_VALUE}
|
||||
EXA_API_KEY=${EXA_API_KEY_VALUE}
|
||||
|
||||
3
.github/workflows/pr-python-checks.yml
vendored
3
.github/workflows/pr-python-checks.yml
vendored
@@ -42,6 +42,9 @@ jobs:
|
||||
|
||||
- name: Generate OpenAPI schema and Python client
|
||||
shell: bash
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
env:
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
run: |
|
||||
ods openapi all
|
||||
|
||||
|
||||
2
.github/workflows/pr-python-tests.yml
vendored
2
.github/workflows/pr-python-tests.yml
vendored
@@ -27,6 +27,8 @@ jobs:
|
||||
PYTHONPATH: ./backend
|
||||
REDIS_CLOUD_PYTEST_PASSWORD: ${{ secrets.REDIS_CLOUD_PYTEST_PASSWORD }}
|
||||
DISABLE_TELEMETRY: "true"
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
LICENSE_ENFORCEMENT_ENABLED: "false"
|
||||
|
||||
steps:
|
||||
- uses: runs-on/action@cd2b598b0515d39d78c38a02d529db87d2196d1e # ratchet:runs-on/action@v2
|
||||
|
||||
@@ -0,0 +1,27 @@
|
||||
"""add_user_preferences
|
||||
|
||||
Revision ID: 175ea04c7087
|
||||
Revises: d56ffa94ca32
|
||||
Create Date: 2026-02-04 18:16:24.830873
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "175ea04c7087"
|
||||
down_revision = "d56ffa94ca32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"user",
|
||||
sa.Column("user_preferences", sa.Text(), nullable=True),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("user", "user_preferences")
|
||||
@@ -134,7 +134,7 @@ GATED_TENANTS_KEY = "gated_tenants"
|
||||
|
||||
# License enforcement - when True, blocks API access for gated/expired licenses
|
||||
LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
|
||||
)
|
||||
|
||||
# Cloud data plane URL - self-hosted instances call this to reach cloud proxy endpoints
|
||||
|
||||
@@ -50,7 +50,12 @@ def github_doc_sync(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
logger.info("GitHub connector credentials loaded successfully")
|
||||
|
||||
if not github_connector.github_client:
|
||||
|
||||
@@ -18,7 +18,12 @@ def github_group_sync(
|
||||
github_connector: GithubConnector = GithubConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
github_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
github_connector.load_credentials(credential_json)
|
||||
if not github_connector.github_client:
|
||||
raise ValueError("github_client is required")
|
||||
|
||||
|
||||
@@ -50,7 +50,12 @@ def gmail_doc_sync(
|
||||
already populated.
|
||||
"""
|
||||
gmail_connector = GmailConnector(**cc_pair.connector.connector_specific_config)
|
||||
gmail_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
gmail_connector.load_credentials(credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(
|
||||
cc_pair, gmail_connector, callback=callback
|
||||
|
||||
@@ -295,7 +295,12 @@ def gdrive_doc_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
|
||||
slim_doc_generator = _get_slim_doc_generator(cc_pair, google_drive_connector)
|
||||
|
||||
|
||||
@@ -391,7 +391,12 @@ def gdrive_group_sync(
|
||||
google_drive_connector = GoogleDriveConnector(
|
||||
**cc_pair.connector.connector_specific_config
|
||||
)
|
||||
google_drive_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
google_drive_connector.load_credentials(credential_json)
|
||||
admin_service = get_admin_service(
|
||||
google_drive_connector.creds, google_drive_connector.primary_admin_email
|
||||
)
|
||||
|
||||
@@ -24,7 +24,12 @@ def jira_doc_sync(
|
||||
jira_connector = JiraConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
jira_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -119,8 +119,13 @@ def jira_group_sync(
|
||||
if not jira_base_url:
|
||||
raise ValueError("No jira_base_url found in connector config")
|
||||
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
jira_client = build_jira_client(
|
||||
credentials=cc_pair.credential.credential_json,
|
||||
credentials=credential_json,
|
||||
jira_base=jira_base_url,
|
||||
scoped_token=scoped_token,
|
||||
)
|
||||
|
||||
@@ -30,7 +30,11 @@ def get_any_salesforce_client_for_doc_id(
|
||||
if _ANY_SALESFORCE_CLIENT is None:
|
||||
cc_pairs = get_cc_pairs_for_document(db_session, doc_id)
|
||||
first_cc_pair = cc_pairs[0]
|
||||
credential_json = first_cc_pair.credential.credential_json
|
||||
credential_json = (
|
||||
first_cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if first_cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
_ANY_SALESFORCE_CLIENT = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
@@ -158,7 +162,11 @@ def _get_salesforce_client_for_doc_id(db_session: Session, doc_id: str) -> Sales
|
||||
)
|
||||
if cc_pair is None:
|
||||
raise ValueError(f"CC pair {cc_pair_id} not found")
|
||||
credential_json = cc_pair.credential.credential_json
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
_CC_PAIR_ID_SALESFORCE_CLIENT_MAP[cc_pair_id] = Salesforce(
|
||||
username=credential_json["sf_username"],
|
||||
password=credential_json["sf_password"],
|
||||
|
||||
@@ -24,7 +24,12 @@ def sharepoint_doc_sync(
|
||||
sharepoint_connector = SharepointConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
sharepoint_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
sharepoint_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -25,7 +25,12 @@ def sharepoint_group_sync(
|
||||
|
||||
# Create SharePoint connector instance and load credentials
|
||||
connector = SharepointConnector(**connector_config)
|
||||
connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
connector.load_credentials(credential_json)
|
||||
|
||||
if not connector.msal_app:
|
||||
raise RuntimeError("MSAL app not initialized in connector")
|
||||
|
||||
@@ -151,9 +151,14 @@ def slack_doc_sync(
|
||||
tenant_id = get_current_tenant_id()
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -63,9 +63,14 @@ def slack_group_sync(
|
||||
|
||||
provider = OnyxDBCredentialsProvider(tenant_id, "slack", cc_pair.credential.id)
|
||||
r = get_redis_client(tenant_id=tenant_id)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
slack_client = SlackConnector.make_slack_web_client(
|
||||
provider.get_provider_key(),
|
||||
cc_pair.credential.credential_json["slack_bot_token"],
|
||||
credential_json["slack_bot_token"],
|
||||
SlackConnector.MAX_RETRIES,
|
||||
r,
|
||||
)
|
||||
|
||||
@@ -25,7 +25,12 @@ def teams_doc_sync(
|
||||
teams_connector = TeamsConnector(
|
||||
**cc_pair.connector.connector_specific_config,
|
||||
)
|
||||
teams_connector.load_credentials(cc_pair.credential.credential_json)
|
||||
credential_json = (
|
||||
cc_pair.credential.credential_json.get_value(apply_mask=False)
|
||||
if cc_pair.credential.credential_json
|
||||
else {}
|
||||
)
|
||||
teams_connector.load_credentials(credential_json)
|
||||
|
||||
yield from generic_doc_sync(
|
||||
cc_pair=cc_pair,
|
||||
|
||||
@@ -270,7 +270,11 @@ def confluence_oauth_accessible_resources(
|
||||
if not credential:
|
||||
raise HTTPException(400, f"Credential {credential_id} not found.")
|
||||
|
||||
credential_dict = credential.credential_json
|
||||
credential_dict = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
access_token = credential_dict["confluence_access_token"]
|
||||
|
||||
try:
|
||||
@@ -337,7 +341,12 @@ def confluence_oauth_finalize(
|
||||
detail=f"Confluence Cloud OAuth failed - credential {credential_id} not found.",
|
||||
)
|
||||
|
||||
new_credential_json: dict[str, Any] = dict(credential.credential_json)
|
||||
existing_credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credential_json: dict[str, Any] = dict(existing_credential_json)
|
||||
new_credential_json["cloud_id"] = cloud_id
|
||||
new_credential_json["cloud_name"] = cloud_name
|
||||
new_credential_json["wiki_base"] = cloud_url
|
||||
|
||||
@@ -11,6 +11,7 @@ from onyx.db.models import OAuthUserToken
|
||||
from onyx.db.oauth_config import get_user_oauth_token
|
||||
from onyx.db.oauth_config import upsert_user_oauth_token
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -33,7 +34,10 @@ class OAuthTokenManager:
|
||||
if not user_token:
|
||||
return None
|
||||
|
||||
token_data = user_token.token_data
|
||||
if not user_token.token_data:
|
||||
return None
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
# Check if token is expired
|
||||
if OAuthTokenManager.is_token_expired(token_data):
|
||||
@@ -51,7 +55,10 @@ class OAuthTokenManager:
|
||||
|
||||
def refresh_token(self, user_token: OAuthUserToken) -> str:
|
||||
"""Refresh access token using refresh token"""
|
||||
token_data = user_token.token_data
|
||||
if not user_token.token_data:
|
||||
raise ValueError("No token data available for refresh")
|
||||
|
||||
token_data = self._unwrap_token_data(user_token.token_data)
|
||||
|
||||
response = requests.post(
|
||||
self.oauth_config.token_url,
|
||||
@@ -153,3 +160,11 @@ class OAuthTokenManager:
|
||||
separator = "&" if "?" in oauth_config.authorization_url else "?"
|
||||
|
||||
return f"{oauth_config.authorization_url}{separator}{urlencode(params)}"
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_token_data(
|
||||
token_data: SensitiveValue[dict[str, Any]] | dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
if isinstance(token_data, SensitiveValue):
|
||||
return token_data.get_value(apply_mask=False)
|
||||
return token_data
|
||||
|
||||
@@ -27,6 +27,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.models import Persona
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMUserIdentity
|
||||
@@ -370,7 +371,7 @@ def run_llm_loop(
|
||||
custom_agent_prompt: str | None,
|
||||
project_files: ExtractedProjectFiles,
|
||||
persona: Persona | None,
|
||||
memories: list[str] | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
llm: LLM,
|
||||
token_counter: Callable[[str], int],
|
||||
db_session: Session,
|
||||
@@ -483,7 +484,7 @@ def run_llm_loop(
|
||||
system_prompt_str = build_system_prompt(
|
||||
base_system_prompt=default_base_system_prompt,
|
||||
datetime_aware=persona.datetime_aware if persona else True,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
tools=tools,
|
||||
should_cite_documents=should_cite_documents
|
||||
or always_cite_documents,
|
||||
@@ -637,7 +638,7 @@ def run_llm_loop(
|
||||
tool_calls=tool_calls,
|
||||
tools=final_tools,
|
||||
message_history=truncated_message_history,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
user_info=None, # TODO, this is part of memories right now, might want to separate it out
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
|
||||
@@ -471,7 +471,7 @@ def handle_stream_message_objects(
|
||||
# Filter chat_history to only messages after the cutoff
|
||||
chat_history = [m for m in chat_history if m.id > cutoff_id]
|
||||
|
||||
memories = get_memories(user, db_session)
|
||||
user_memory_context = get_memories(user, db_session)
|
||||
|
||||
custom_agent_prompt = get_custom_agent_prompt(persona, chat_session)
|
||||
|
||||
@@ -480,7 +480,7 @@ def handle_stream_message_objects(
|
||||
persona_system_prompt=custom_agent_prompt or "",
|
||||
token_counter=token_counter,
|
||||
files=new_msg_req.file_descriptors,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
)
|
||||
|
||||
# Process projects, if all of the files fit in the context, it doesn't need to use RAG
|
||||
@@ -667,7 +667,7 @@ def handle_stream_message_objects(
|
||||
custom_agent_prompt=custom_agent_prompt,
|
||||
project_files=extracted_project_files,
|
||||
persona=persona,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
llm=llm,
|
||||
token_counter=token_counter,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -4,6 +4,7 @@ from uuid import UUID
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.db.persona import get_default_behavior_persona
|
||||
from onyx.db.user_file import calculate_user_files_token_count
|
||||
from onyx.file_store.models import FileDescriptor
|
||||
@@ -12,7 +13,6 @@ from onyx.prompts.chat_prompts import CODE_BLOCK_MARKDOWN
|
||||
from onyx.prompts.chat_prompts import DEFAULT_SYSTEM_PROMPT
|
||||
from onyx.prompts.chat_prompts import LAST_CYCLE_CITATION_REMINDER
|
||||
from onyx.prompts.chat_prompts import REQUIRE_CITATION_GUIDANCE
|
||||
from onyx.prompts.chat_prompts import USER_INFO_HEADER
|
||||
from onyx.prompts.prompt_utils import get_company_context
|
||||
from onyx.prompts.prompt_utils import handle_onyx_date_awareness
|
||||
from onyx.prompts.prompt_utils import replace_citation_guidance_tag
|
||||
@@ -25,6 +25,7 @@ from onyx.prompts.tool_prompts import TOOL_DESCRIPTION_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import TOOL_SECTION_HEADER
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_GUIDANCE
|
||||
from onyx.prompts.tool_prompts import WEB_SEARCH_SITE_DISABLED_GUIDANCE
|
||||
from onyx.prompts.user_info import USER_INFORMATION_HEADER
|
||||
from onyx.tools.interface import Tool
|
||||
from onyx.tools.tool_implementations.images.image_generation_tool import (
|
||||
ImageGenerationTool,
|
||||
@@ -52,7 +53,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: str,
|
||||
token_counter: Callable[[str], int],
|
||||
files: list[FileDescriptor] | None = None,
|
||||
memories: list[str] | None = None,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
) -> int:
|
||||
"""
|
||||
Calculate reserved token count for system prompt and user files.
|
||||
@@ -66,7 +67,7 @@ def calculate_reserved_tokens(
|
||||
persona_system_prompt: Custom agent system prompt (can be empty string)
|
||||
token_counter: Function that counts tokens in text
|
||||
files: List of file descriptors from the chat message (optional)
|
||||
memories: List of memory strings (optional)
|
||||
user_memory_context: User memory context (optional)
|
||||
|
||||
Returns:
|
||||
Total reserved token count
|
||||
@@ -77,7 +78,7 @@ def calculate_reserved_tokens(
|
||||
fake_system_prompt = build_system_prompt(
|
||||
base_system_prompt=base_system_prompt,
|
||||
datetime_aware=True,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
tools=None,
|
||||
should_cite_documents=True,
|
||||
include_all_guidance=True,
|
||||
@@ -133,7 +134,7 @@ def build_reminder_message(
|
||||
def build_system_prompt(
|
||||
base_system_prompt: str,
|
||||
datetime_aware: bool = False,
|
||||
memories: list[str] | None = None,
|
||||
user_memory_context: UserMemoryContext | None = None,
|
||||
tools: Sequence[Tool] | None = None,
|
||||
should_cite_documents: bool = False,
|
||||
include_all_guidance: bool = False,
|
||||
@@ -157,14 +158,15 @@ def build_system_prompt(
|
||||
)
|
||||
|
||||
company_context = get_company_context()
|
||||
if company_context or memories:
|
||||
system_prompt += USER_INFO_HEADER
|
||||
formatted_user_context = (
|
||||
user_memory_context.as_formatted_prompt() if user_memory_context else ""
|
||||
)
|
||||
if company_context or formatted_user_context:
|
||||
system_prompt += USER_INFORMATION_HEADER
|
||||
if company_context:
|
||||
system_prompt += company_context
|
||||
if memories:
|
||||
system_prompt += "\n".join(
|
||||
"- " + memory.strip() for memory in memories if memory.strip()
|
||||
)
|
||||
if formatted_user_context:
|
||||
system_prompt += formatted_user_context
|
||||
|
||||
# Append citation guidance after company context if placeholder was not present
|
||||
# This maintains backward compatibility and ensures citations are always enforced when needed
|
||||
|
||||
@@ -65,7 +65,9 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
return credential.credential_json
|
||||
if credential.credential_json is None:
|
||||
return {}
|
||||
return credential.credential_json.get_value(apply_mask=False)
|
||||
|
||||
def set_credentials(self, credential_json: dict[str, Any]) -> None:
|
||||
with get_session_with_tenant(tenant_id=self._tenant_id) as db_session:
|
||||
@@ -81,7 +83,7 @@ class OnyxDBCredentialsProvider(
|
||||
f"No credential found: credential={self._credential_id}"
|
||||
)
|
||||
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
|
||||
@@ -118,7 +118,12 @@ def instantiate_connector(
|
||||
)
|
||||
connector.set_credentials_provider(provider)
|
||||
else:
|
||||
new_credentials = connector.load_credentials(credential.credential_json)
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
new_credentials = connector.load_credentials(credential_json)
|
||||
|
||||
if new_credentials is not None:
|
||||
backend_update_credential_json(credential, new_credentials, db_session)
|
||||
|
||||
@@ -270,6 +270,8 @@ def create_credential(
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -297,14 +299,21 @@ def alter_credential(
|
||||
|
||||
credential.name = name
|
||||
|
||||
# Assign a new dictionary to credential.credential_json
|
||||
credential.credential_json = {
|
||||
**credential.credential_json,
|
||||
# Get existing credential_json and merge with new values
|
||||
existing_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
credential.credential_json = { # type: ignore[assignment]
|
||||
**existing_json,
|
||||
**credential_json,
|
||||
}
|
||||
|
||||
credential.user_id = user.id
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -318,10 +327,12 @@ def update_credential(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_data.credential_json
|
||||
credential.user_id = user.id
|
||||
credential.credential_json = credential_data.credential_json # type: ignore[assignment]
|
||||
credential.user_id = user.id if user is not None else None
|
||||
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -335,8 +346,10 @@ def update_credential_json(
|
||||
if credential is None:
|
||||
return None
|
||||
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
# Expire to ensure credential_json is reloaded as SensitiveValue from DB
|
||||
db_session.expire(credential)
|
||||
return credential
|
||||
|
||||
|
||||
@@ -346,7 +359,7 @@ def backend_update_credential_json(
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""This should not be used in any flows involving the frontend or users"""
|
||||
credential.credential_json = credential_json
|
||||
credential.credential_json = credential_json # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -441,7 +454,12 @@ def create_initial_public_credential(db_session: Session) -> None:
|
||||
)
|
||||
|
||||
if first_credential is not None:
|
||||
if first_credential.credential_json != {} or first_credential.user is not None:
|
||||
credential_json_value = (
|
||||
first_credential.credential_json.get_value(apply_mask=False)
|
||||
if first_credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if credential_json_value != {} or first_credential.user is not None:
|
||||
raise ValueError(error_msg)
|
||||
return
|
||||
|
||||
@@ -477,8 +495,13 @@ def delete_service_account_credentials(
|
||||
) -> None:
|
||||
credentials = fetch_credentials_for_user(db_session=db_session, user=user)
|
||||
for credential in credentials:
|
||||
credential_json = (
|
||||
credential.credential_json.get_value(apply_mask=False)
|
||||
if credential.credential_json
|
||||
else {}
|
||||
)
|
||||
if (
|
||||
credential.credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
credential_json.get(DB_CREDENTIALS_DICT_SERVICE_ACCOUNT_KEY)
|
||||
and credential.source == source
|
||||
):
|
||||
db_session.delete(credential)
|
||||
|
||||
@@ -111,7 +111,7 @@ def update_federated_connector_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token = token
|
||||
existing_token.token = token # type: ignore[assignment]
|
||||
existing_token.expires_at = expires_at
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
@@ -267,7 +267,13 @@ def update_federated_connector(
|
||||
# Use provided credentials if updating them, otherwise use existing credentials
|
||||
# This is needed to instantiate the connector for config validation when only config is being updated
|
||||
creds_to_use = (
|
||||
credentials if credentials is not None else federated_connector.credentials
|
||||
credentials
|
||||
if credentials is not None
|
||||
else (
|
||||
federated_connector.credentials.get_value(apply_mask=False)
|
||||
if federated_connector.credentials
|
||||
else {}
|
||||
)
|
||||
)
|
||||
|
||||
if credentials is not None:
|
||||
@@ -278,7 +284,7 @@ def update_federated_connector(
|
||||
raise ValueError(
|
||||
f"Invalid credentials for federated connector source: {federated_connector.source}"
|
||||
)
|
||||
federated_connector.credentials = credentials
|
||||
federated_connector.credentials = credentials # type: ignore[assignment]
|
||||
|
||||
if config is not None:
|
||||
# Validate config using connector-specific validation
|
||||
|
||||
@@ -232,7 +232,8 @@ def upsert_llm_provider(
|
||||
custom_config = custom_config or None
|
||||
|
||||
existing_llm_provider.provider = llm_provider_upsert_request.provider
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
existing_llm_provider.api_key = llm_provider_upsert_request.api_key # type: ignore[assignment]
|
||||
existing_llm_provider.api_base = llm_provider_upsert_request.api_base
|
||||
existing_llm_provider.api_version = llm_provider_upsert_request.api_version
|
||||
existing_llm_provider.custom_config = custom_config
|
||||
|
||||
@@ -19,6 +19,7 @@ from onyx.db.models import Tool
|
||||
from onyx.db.models import User
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -204,6 +205,21 @@ def remove_user_from_mcp_server(
|
||||
|
||||
|
||||
# MCPConnectionConfig operations
|
||||
def extract_connection_data(
|
||||
config: MCPConnectionConfig | None, apply_mask: bool = False
|
||||
) -> MCPConnectionData:
|
||||
"""Extract MCPConnectionData from a connection config, with proper typing.
|
||||
|
||||
This helper encapsulates the cast from the JSON column's dict[str, Any]
|
||||
to the typed MCPConnectionData structure.
|
||||
"""
|
||||
if config is None or config.config is None:
|
||||
return MCPConnectionData(headers={})
|
||||
if isinstance(config.config, SensitiveValue):
|
||||
return cast(MCPConnectionData, config.config.get_value(apply_mask=apply_mask))
|
||||
return cast(MCPConnectionData, config.config)
|
||||
|
||||
|
||||
def get_connection_config_by_id(
|
||||
config_id: int, db_session: Session
|
||||
) -> MCPConnectionConfig:
|
||||
@@ -269,7 +285,7 @@ def update_connection_config(
|
||||
config = get_connection_config_by_id(config_id, db_session)
|
||||
|
||||
if config_data is not None:
|
||||
config.config = config_data
|
||||
config.config = config_data # type: ignore[assignment]
|
||||
# Force SQLAlchemy to detect the change by marking the field as modified
|
||||
flag_modified(config, "config")
|
||||
|
||||
@@ -287,7 +303,7 @@ def upsert_user_connection_config(
|
||||
existing_config = get_user_connection_config(server_id, user_email, db_session)
|
||||
|
||||
if existing_config:
|
||||
existing_config.config = config_data
|
||||
existing_config.config = config_data # type: ignore[assignment]
|
||||
db_session.flush() # Don't commit yet, let caller decide when to commit
|
||||
return existing_config
|
||||
else:
|
||||
|
||||
@@ -1,22 +1,111 @@
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import Memory
|
||||
from onyx.db.models import User
|
||||
from onyx.prompts.user_info import BASIC_INFORMATION_PROMPT
|
||||
from onyx.prompts.user_info import USER_MEMORIES_PROMPT
|
||||
from onyx.prompts.user_info import USER_PREFERENCES_PROMPT
|
||||
from onyx.prompts.user_info import USER_ROLE_PROMPT
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> list[str]:
|
||||
class UserInfo(BaseModel):
|
||||
name: str | None = None
|
||||
role: str | None = None
|
||||
email: str | None = None
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
return {
|
||||
"name": self.name,
|
||||
"role": self.role,
|
||||
"email": self.email,
|
||||
}
|
||||
|
||||
|
||||
class UserMemoryContext(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
user_info: UserInfo
|
||||
user_preferences: str | None = None
|
||||
memories: tuple[str, ...] = ()
|
||||
|
||||
def as_formatted_list(self) -> list[str]:
|
||||
"""Returns combined list of user info, preferences, and memories."""
|
||||
result = []
|
||||
if self.user_info.name:
|
||||
result.append(f"User's name: {self.user_info.name}")
|
||||
if self.user_info.role:
|
||||
result.append(f"User's role: {self.user_info.role}")
|
||||
if self.user_info.email:
|
||||
result.append(f"User's email: {self.user_info.email}")
|
||||
if self.user_preferences:
|
||||
result.append(f"User preferences: {self.user_preferences}")
|
||||
result.extend(self.memories)
|
||||
return result
|
||||
|
||||
def as_formatted_prompt(self) -> str:
|
||||
"""Returns structured prompt sections for the system prompt."""
|
||||
has_basic_info = (
|
||||
self.user_info.name or self.user_info.email or self.user_info.role
|
||||
)
|
||||
if not has_basic_info and not self.user_preferences and not self.memories:
|
||||
return ""
|
||||
|
||||
sections: list[str] = []
|
||||
|
||||
if has_basic_info:
|
||||
role_line = (
|
||||
USER_ROLE_PROMPT.format(user_role=self.user_info.role).strip()
|
||||
if self.user_info.role
|
||||
else ""
|
||||
)
|
||||
if role_line:
|
||||
role_line = "\n" + role_line
|
||||
sections.append(
|
||||
BASIC_INFORMATION_PROMPT.format(
|
||||
user_name=self.user_info.name or "",
|
||||
user_email=self.user_info.email or "",
|
||||
user_role=role_line,
|
||||
)
|
||||
)
|
||||
|
||||
if self.user_preferences:
|
||||
sections.append(
|
||||
USER_PREFERENCES_PROMPT.format(user_preferences=self.user_preferences)
|
||||
)
|
||||
|
||||
if self.memories:
|
||||
formatted_memories = "\n".join(f"- {memory}" for memory in self.memories)
|
||||
sections.append(
|
||||
USER_MEMORIES_PROMPT.format(user_memories=formatted_memories)
|
||||
)
|
||||
|
||||
return "".join(sections)
|
||||
|
||||
|
||||
def get_memories(user: User, db_session: Session) -> UserMemoryContext:
|
||||
if not user.use_memories:
|
||||
return []
|
||||
return UserMemoryContext(user_info=UserInfo())
|
||||
|
||||
user_info = [
|
||||
f"User's name: {user.personal_name}" if user.personal_name else "",
|
||||
f"User's role: {user.personal_role}" if user.personal_role else "",
|
||||
f"User's email: {user.email}" if user.email else "",
|
||||
]
|
||||
user_info = UserInfo(
|
||||
name=user.personal_name,
|
||||
role=user.personal_role,
|
||||
email=user.email,
|
||||
)
|
||||
|
||||
user_preferences = None
|
||||
if user.user_preferences:
|
||||
user_preferences = user.user_preferences
|
||||
|
||||
memory_rows = db_session.scalars(
|
||||
select(Memory).where(Memory.user_id == user.id)
|
||||
select(Memory).where(Memory.user_id == user.id).order_by(Memory.id.asc())
|
||||
).all()
|
||||
memories = [memory.memory_text for memory in memory_rows if memory.memory_text]
|
||||
return user_info + memories
|
||||
memories = tuple(memory.memory_text for memory in memory_rows if memory.memory_text)
|
||||
|
||||
return UserMemoryContext(
|
||||
user_info=user_info,
|
||||
user_preferences=user_preferences,
|
||||
memories=memories,
|
||||
)
|
||||
|
||||
@@ -95,10 +95,10 @@ from onyx.file_store.models import FileDescriptor
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.override_models import PromptOverride
|
||||
from onyx.kg.models import KGStage
|
||||
from onyx.server.features.mcp.models import MCPConnectionData
|
||||
from onyx.tools.tool_implementations.web_search.models import WebContentProviderConfig
|
||||
from onyx.utils.encryption import decrypt_bytes_to_string
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.headers import HeaderItemDict
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
from onyx.context.search.enums import RecencyBiasSetting
|
||||
@@ -122,18 +122,35 @@ class EncryptedString(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self, value: str | None, dialect: Dialect # noqa: ARG002
|
||||
self, value: str | SensitiveValue[str] | None, dialect: Dialect # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw strings and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
return encrypt_string_to_bytes(value)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> str | None:
|
||||
) -> SensitiveValue[str] | None:
|
||||
if value is not None:
|
||||
return decrypt_bytes_to_string(value)
|
||||
return value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=False,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class EncryptedJson(TypeDecorator):
|
||||
@@ -142,20 +159,38 @@ class EncryptedJson(TypeDecorator):
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(
|
||||
self, value: dict | None, dialect: Dialect # noqa: ARG002
|
||||
self,
|
||||
value: dict[str, Any] | SensitiveValue[dict[str, Any]] | None,
|
||||
dialect: Dialect, # noqa: ARG002
|
||||
) -> bytes | None:
|
||||
if value is not None:
|
||||
# Handle both raw dicts and SensitiveValue wrappers
|
||||
if isinstance(value, SensitiveValue):
|
||||
# Get raw value for storage
|
||||
value = value.get_value(apply_mask=False)
|
||||
json_str = json.dumps(value)
|
||||
return encrypt_string_to_bytes(json_str)
|
||||
return value
|
||||
|
||||
def process_result_value(
|
||||
self, value: bytes | None, dialect: Dialect # noqa: ARG002
|
||||
) -> dict | None:
|
||||
) -> SensitiveValue[dict[str, Any]] | None:
|
||||
if value is not None:
|
||||
json_str = decrypt_bytes_to_string(value)
|
||||
return json.loads(json_str)
|
||||
return value
|
||||
return SensitiveValue(
|
||||
encrypted_bytes=value,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=True,
|
||||
)
|
||||
return None
|
||||
|
||||
def compare_values(self, x: Any, y: Any) -> bool:
|
||||
if x is None or y is None:
|
||||
return x == y
|
||||
if isinstance(x, SensitiveValue):
|
||||
x = x.get_value(apply_mask=False)
|
||||
if isinstance(y, SensitiveValue):
|
||||
y = y.get_value(apply_mask=False)
|
||||
return x == y
|
||||
|
||||
|
||||
class NullFilteredString(TypeDecorator):
|
||||
@@ -216,6 +251,7 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
|
||||
personal_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
personal_role: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
use_memories: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
user_preferences: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
|
||||
chosen_assistants: Mapped[list[int] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True, default=None
|
||||
@@ -1755,7 +1791,9 @@ class Credential(Base):
|
||||
)
|
||||
|
||||
id: Mapped[int] = mapped_column(primary_key=True)
|
||||
credential_json: Mapped[dict[str, Any]] = mapped_column(EncryptedJson())
|
||||
credential_json: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson()
|
||||
)
|
||||
user_id: Mapped[UUID | None] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
@@ -1793,7 +1831,9 @@ class FederatedConnector(Base):
|
||||
source: Mapped[FederatedConnectorSource] = mapped_column(
|
||||
Enum(FederatedConnectorSource, native_enum=False)
|
||||
)
|
||||
credentials: Mapped[dict[str, str]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
credentials: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
config: Mapped[dict[str, Any]] = mapped_column(
|
||||
postgresql.JSONB(), default=dict, nullable=False, server_default="{}"
|
||||
)
|
||||
@@ -1820,7 +1860,9 @@ class FederatedConnectorOAuthToken(Base):
|
||||
user_id: Mapped[UUID] = mapped_column(
|
||||
ForeignKey("user.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
expires_at: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime, nullable=True
|
||||
)
|
||||
@@ -1964,7 +2006,9 @@ class SearchSettings(Base):
|
||||
|
||||
@property
|
||||
def api_key(self) -> str | None:
|
||||
return self.cloud_provider.api_key if self.cloud_provider is not None else None
|
||||
if self.cloud_provider is None or self.cloud_provider.api_key is None:
|
||||
return None
|
||||
return self.cloud_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
@property
|
||||
def large_chunks_enabled(self) -> bool:
|
||||
@@ -2726,7 +2770,9 @@ class LLMProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True)
|
||||
provider: Mapped[str] = mapped_column(String)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
api_base: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# custom configs that should be passed to the LLM provider at inference time
|
||||
@@ -2879,7 +2925,7 @@ class CloudEmbeddingProvider(Base):
|
||||
Enum(EmbeddingProvider), primary_key=True
|
||||
)
|
||||
api_url: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString())
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(EncryptedString())
|
||||
api_version: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
deployment_name: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
|
||||
@@ -2898,7 +2944,9 @@ class InternetSearchProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
config: Mapped[dict[str, str] | None] = mapped_column(
|
||||
postgresql.JSONB(), nullable=True
|
||||
)
|
||||
@@ -2920,7 +2968,9 @@ class InternetContentProvider(Base):
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True)
|
||||
name: Mapped[str] = mapped_column(String, unique=True, nullable=False)
|
||||
provider_type: Mapped[str] = mapped_column(String, nullable=False)
|
||||
api_key: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
api_key: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
config: Mapped[WebContentProviderConfig | None] = mapped_column(
|
||||
PydanticType(WebContentProviderConfig), nullable=True
|
||||
)
|
||||
@@ -3064,8 +3114,12 @@ class OAuthConfig(Base):
|
||||
token_url: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
|
||||
# Client credentials (encrypted)
|
||||
client_id: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
client_secret: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
client_id: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
client_secret: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
|
||||
# Optional configurations
|
||||
scopes: Mapped[list[str] | None] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
@@ -3112,7 +3166,9 @@ class OAuthUserToken(Base):
|
||||
# "expires_at": 1234567890, # Unix timestamp, optional
|
||||
# "scope": "repo user" # Optional
|
||||
# }
|
||||
token_data: Mapped[dict[str, Any]] = mapped_column(EncryptedJson(), nullable=False)
|
||||
token_data: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False
|
||||
)
|
||||
|
||||
# Metadata
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
@@ -3445,9 +3501,15 @@ class SlackBot(Base):
|
||||
name: Mapped[str] = mapped_column(String)
|
||||
enabled: Mapped[bool] = mapped_column(Boolean, default=True)
|
||||
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
app_token: Mapped[str] = mapped_column(EncryptedString(), unique=True)
|
||||
user_token: Mapped[str | None] = mapped_column(EncryptedString(), nullable=True)
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
app_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), unique=True
|
||||
)
|
||||
user_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=True
|
||||
)
|
||||
|
||||
slack_channel_configs: Mapped[list[SlackChannelConfig]] = relationship(
|
||||
"SlackChannelConfig",
|
||||
@@ -3468,7 +3530,9 @@ class DiscordBotConfig(Base):
|
||||
id: Mapped[str] = mapped_column(
|
||||
String, primary_key=True, server_default=text("'SINGLETON'")
|
||||
)
|
||||
bot_token: Mapped[str] = mapped_column(EncryptedString(), nullable=False)
|
||||
bot_token: Mapped[SensitiveValue[str] | None] = mapped_column(
|
||||
EncryptedString(), nullable=False
|
||||
)
|
||||
created_at: Mapped[datetime.datetime] = mapped_column(
|
||||
DateTime(timezone=True), server_default=func.now(), nullable=False
|
||||
)
|
||||
@@ -3624,7 +3688,9 @@ class KVStore(Base):
|
||||
|
||||
key: Mapped[str] = mapped_column(String, primary_key=True)
|
||||
value: Mapped[JSON_ro] = mapped_column(postgresql.JSONB(), nullable=True)
|
||||
encrypted_value: Mapped[JSON_ro] = mapped_column(EncryptedJson(), nullable=True)
|
||||
encrypted_value: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=True
|
||||
)
|
||||
|
||||
|
||||
class FileRecord(Base):
|
||||
@@ -4344,7 +4410,7 @@ class MCPConnectionConfig(Base):
|
||||
# "registration_access_token": "<token>", # For managing registration
|
||||
# "registration_client_uri": "<uri>", # For managing registration
|
||||
# }
|
||||
config: Mapped[MCPConnectionData] = mapped_column(
|
||||
config: Mapped[SensitiveValue[dict[str, Any]] | None] = mapped_column(
|
||||
EncryptedJson(), nullable=False, default=dict
|
||||
)
|
||||
|
||||
|
||||
@@ -87,13 +87,13 @@ def update_oauth_config(
|
||||
if token_url is not None:
|
||||
oauth_config.token_url = token_url
|
||||
if clear_client_id:
|
||||
oauth_config.client_id = ""
|
||||
oauth_config.client_id = "" # type: ignore[assignment]
|
||||
elif client_id is not None:
|
||||
oauth_config.client_id = client_id
|
||||
oauth_config.client_id = client_id # type: ignore[assignment]
|
||||
if clear_client_secret:
|
||||
oauth_config.client_secret = ""
|
||||
oauth_config.client_secret = "" # type: ignore[assignment]
|
||||
elif client_secret is not None:
|
||||
oauth_config.client_secret = client_secret
|
||||
oauth_config.client_secret = client_secret # type: ignore[assignment]
|
||||
if scopes is not None:
|
||||
oauth_config.scopes = scopes
|
||||
if additional_params is not None:
|
||||
@@ -154,7 +154,7 @@ def upsert_user_oauth_token(
|
||||
|
||||
if existing_token:
|
||||
# Update existing token
|
||||
existing_token.token_data = token_data
|
||||
existing_token.token_data = token_data # type: ignore[assignment]
|
||||
db_session.commit()
|
||||
return existing_token
|
||||
else:
|
||||
|
||||
@@ -43,9 +43,9 @@ def update_slack_bot(
|
||||
# update the app
|
||||
slack_bot.name = name
|
||||
slack_bot.enabled = enabled
|
||||
slack_bot.bot_token = bot_token
|
||||
slack_bot.app_token = app_token
|
||||
slack_bot.user_token = user_token
|
||||
slack_bot.bot_token = bot_token # type: ignore[assignment]
|
||||
slack_bot.app_token = app_token # type: ignore[assignment]
|
||||
slack_bot.user_token = user_token # type: ignore[assignment]
|
||||
|
||||
db_session.commit()
|
||||
|
||||
|
||||
@@ -160,6 +160,7 @@ def update_user_personalization(
|
||||
personal_role: str | None,
|
||||
use_memories: bool,
|
||||
memories: list[str],
|
||||
user_preferences: str | None,
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
db_session.execute(
|
||||
@@ -169,6 +170,7 @@ def update_user_personalization(
|
||||
personal_name=personal_name,
|
||||
personal_role=personal_role,
|
||||
use_memories=use_memories,
|
||||
user_preferences=user_preferences,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -73,7 +73,8 @@ def _apply_search_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
|
||||
def upsert_web_search_provider(
|
||||
@@ -228,7 +229,8 @@ def _apply_content_provider_updates(
|
||||
provider.provider_type = provider_type.value
|
||||
provider.config = config
|
||||
if api_key_changed or provider.api_key is None:
|
||||
provider.api_key = api_key
|
||||
# EncryptedString accepts str for writes, returns SensitiveValue for reads
|
||||
provider.api_key = api_key # type: ignore[assignment]
|
||||
|
||||
|
||||
def upsert_web_content_provider(
|
||||
|
||||
@@ -119,7 +119,16 @@ def get_federated_retrieval_functions(
|
||||
federated_retrieval_infos_slack = []
|
||||
|
||||
# Use user_token if available, otherwise fall back to bot_token
|
||||
access_token = tenant_slack_bot.user_token or tenant_slack_bot.bot_token
|
||||
# Unwrap SensitiveValue for backend API calls
|
||||
access_token = (
|
||||
tenant_slack_bot.user_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.user_token
|
||||
else (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else ""
|
||||
)
|
||||
)
|
||||
if not tenant_slack_bot.user_token:
|
||||
logger.warning(
|
||||
f"Using bot_token for Slack search (limited functionality): {tenant_slack_bot.name}"
|
||||
@@ -138,7 +147,12 @@ def get_federated_retrieval_functions(
|
||||
)
|
||||
|
||||
# Capture variables by value to avoid lambda closure issues
|
||||
bot_token = tenant_slack_bot.bot_token
|
||||
# Unwrap SensitiveValue for backend API calls
|
||||
bot_token = (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else ""
|
||||
)
|
||||
|
||||
# Use connector config for channel filtering (guaranteed to exist at this point)
|
||||
connector_entities = slack_federated_connector_config
|
||||
@@ -252,11 +266,11 @@ def get_federated_retrieval_functions(
|
||||
|
||||
connector = get_federated_connector(
|
||||
oauth_token.federated_connector.source,
|
||||
oauth_token.federated_connector.credentials,
|
||||
oauth_token.federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
|
||||
# Capture variables by value to avoid lambda closure issues
|
||||
access_token = oauth_token.token
|
||||
access_token = oauth_token.token.get_value(apply_mask=False)
|
||||
|
||||
def create_retrieval_function(
|
||||
conn: FederatedConnector,
|
||||
|
||||
@@ -43,7 +43,7 @@ class PgRedisKVStore(KeyValueStore):
|
||||
obj = db_session.query(KVStore).filter_by(key=key).first()
|
||||
if obj:
|
||||
obj.value = plain_val
|
||||
obj.encrypted_value = encrypted_val
|
||||
obj.encrypted_value = encrypted_val # type: ignore[assignment]
|
||||
else:
|
||||
obj = KVStore(key=key, value=plain_val, encrypted_value=encrypted_val)
|
||||
db_session.query(KVStore).filter_by(key=key).delete() # just in case
|
||||
@@ -73,7 +73,8 @@ class PgRedisKVStore(KeyValueStore):
|
||||
if obj.value is not None:
|
||||
value = obj.value
|
||||
elif obj.encrypted_value is not None:
|
||||
value = obj.encrypted_value
|
||||
# Unwrap SensitiveValue - this is internal backend use
|
||||
value = obj.encrypted_value.get_value(apply_mask=False)
|
||||
else:
|
||||
value = None
|
||||
|
||||
|
||||
@@ -1,4 +1,8 @@
|
||||
import os
|
||||
import threading
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from contextlib import nullcontext
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TYPE_CHECKING
|
||||
@@ -44,11 +48,13 @@ from onyx.llm.well_known_providers.constants import (
|
||||
VERTEX_CREDENTIALS_FILE_KWARG_ENV_VAR_FORMAT,
|
||||
)
|
||||
from onyx.llm.well_known_providers.constants import VERTEX_LOCATION_KWARG
|
||||
from onyx.server.utils import mask_string
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_env_lock = threading.Lock()
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from litellm import CustomStreamWrapper
|
||||
from litellm import HTTPHandler
|
||||
@@ -378,23 +384,29 @@ class LitellmLLM(LLM):
|
||||
if "api_key" not in passthrough_kwargs:
|
||||
passthrough_kwargs["api_key"] = self._api_key or None
|
||||
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
env_ctx = (
|
||||
temporary_env_and_lock(self._custom_config)
|
||||
if self._custom_config
|
||||
else nullcontext()
|
||||
)
|
||||
with env_ctx:
|
||||
response = litellm.completion(
|
||||
mock_response=get_llm_mock_response() or MOCK_LLM_RESPONSE,
|
||||
model=model,
|
||||
base_url=self._api_base or None,
|
||||
api_version=self._api_version or None,
|
||||
custom_llm_provider=self._custom_llm_provider or None,
|
||||
messages=_prompt_to_dicts(prompt),
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=stream,
|
||||
temperature=temperature,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
client=client,
|
||||
**optional_kwargs,
|
||||
**passthrough_kwargs,
|
||||
)
|
||||
return response
|
||||
except Exception as e:
|
||||
# for break pointing
|
||||
@@ -475,22 +487,53 @@ class LitellmLLM(LLM):
|
||||
client = HTTPHandler(timeout=timeout_override or self._timeout)
|
||||
|
||||
try:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
if self._custom_config:
|
||||
# When custom_config is set, env vars are temporarily injected
|
||||
# under a global lock. Using stream=True here means the lock is
|
||||
# only held during connection setup (not the full inference).
|
||||
# The chunks are then collected outside the lock and reassembled
|
||||
# into a single ModelResponse via stream_chunk_builder.
|
||||
from litellm import stream_chunk_builder
|
||||
from litellm import CustomStreamWrapper as LiteLLMCustomStreamWrapper
|
||||
|
||||
stream_response = cast(
|
||||
LiteLLMCustomStreamWrapper,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
chunks = list(stream_response)
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
stream_chunk_builder(chunks),
|
||||
)
|
||||
else:
|
||||
response = cast(
|
||||
LiteLLMModelResponse,
|
||||
self._completion(
|
||||
prompt=prompt,
|
||||
tools=tools,
|
||||
tool_choice=tool_choice,
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
parallel_tool_calls=True,
|
||||
reasoning_effort=reasoning_effort,
|
||||
user_identity=user_identity,
|
||||
client=client,
|
||||
),
|
||||
)
|
||||
|
||||
model_response = from_litellm_model_response(response)
|
||||
|
||||
@@ -581,3 +624,29 @@ class LitellmLLM(LLM):
|
||||
finally:
|
||||
if client is not None:
|
||||
client.close()
|
||||
|
||||
|
||||
@contextmanager
|
||||
def temporary_env_and_lock(env_variables: dict[str, str]) -> Iterator[None]:
|
||||
"""
|
||||
Temporarily sets the environment variables to the given values.
|
||||
Code path is locked while the environment variables are set.
|
||||
Then cleans up the environment and frees the lock.
|
||||
"""
|
||||
with _env_lock:
|
||||
logger.debug("Acquired lock in temporary_env_and_lock")
|
||||
# Store original values (None if key didn't exist)
|
||||
original_values: dict[str, str | None] = {
|
||||
key: os.environ.get(key) for key in env_variables
|
||||
}
|
||||
try:
|
||||
os.environ.update(env_variables)
|
||||
yield
|
||||
finally:
|
||||
for key, original_value in original_values.items():
|
||||
if original_value is None:
|
||||
os.environ.pop(key, None) # Remove if it didn't exist before
|
||||
else:
|
||||
os.environ[key] = original_value # Restore original value
|
||||
|
||||
logger.debug("Released lock in temporary_env_and_lock")
|
||||
|
||||
@@ -4,6 +4,7 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.db.discord_bot import get_discord_bot_config
|
||||
from onyx.db.engine.sql_engine import get_session_with_tenant
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -36,4 +37,8 @@ def get_bot_token() -> str | None:
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to get bot token from database: {e}")
|
||||
return None
|
||||
return config.bot_token if config else None
|
||||
if config and config.bot_token:
|
||||
if isinstance(config.bot_token, SensitiveValue):
|
||||
return config.bot_token.get_value(apply_mask=False)
|
||||
return config.bot_token
|
||||
return None
|
||||
|
||||
@@ -216,14 +216,10 @@ class SlackbotHandler:
|
||||
- If the tokens have changed, close the existing socket client and reconnect.
|
||||
- If the tokens are new, warm up the model and start a new socket client.
|
||||
"""
|
||||
slack_bot_tokens = SlackBotTokens(
|
||||
bot_token=bot.bot_token,
|
||||
app_token=bot.app_token,
|
||||
)
|
||||
tenant_bot_pair = (tenant_id, bot.id)
|
||||
|
||||
# If the tokens are missing or empty, close the socket client and remove them.
|
||||
if not slack_bot_tokens:
|
||||
if not bot.bot_token or not bot.app_token:
|
||||
logger.debug(
|
||||
f"No Slack bot tokens found for tenant={tenant_id}, bot {bot.id}"
|
||||
)
|
||||
@@ -233,6 +229,11 @@ class SlackbotHandler:
|
||||
del self.slack_bot_tokens[tenant_bot_pair]
|
||||
return
|
||||
|
||||
slack_bot_tokens = SlackBotTokens(
|
||||
bot_token=bot.bot_token.get_value(apply_mask=False),
|
||||
app_token=bot.app_token.get_value(apply_mask=False),
|
||||
)
|
||||
|
||||
tokens_exist = tenant_bot_pair in self.slack_bot_tokens
|
||||
tokens_changed = (
|
||||
tokens_exist and slack_bot_tokens != self.slack_bot_tokens[tenant_bot_pair]
|
||||
|
||||
@@ -25,9 +25,6 @@ You can use Markdown tables to format your responses for data, lists, and other
|
||||
""".lstrip()
|
||||
|
||||
|
||||
# Section for information about the user if provided such as their name, role, memories, etc.
|
||||
USER_INFO_HEADER = "\n\n# User Information\n"
|
||||
|
||||
COMPANY_NAME_BLOCK = """
|
||||
The user is at an organization called `{company_name}`.
|
||||
"""
|
||||
|
||||
@@ -403,12 +403,13 @@ def check_drive_tokens(
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> AuthStatus:
|
||||
db_credentials = fetch_credential_by_id_for_user(credential_id, user, db_session)
|
||||
if (
|
||||
not db_credentials
|
||||
or DB_CREDENTIALS_DICT_TOKEN_KEY not in db_credentials.credential_json
|
||||
):
|
||||
if not db_credentials or not db_credentials.credential_json:
|
||||
return AuthStatus(authenticated=False)
|
||||
token_json_str = str(db_credentials.credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
|
||||
credential_json = db_credentials.credential_json.get_value(apply_mask=False)
|
||||
if DB_CREDENTIALS_DICT_TOKEN_KEY not in credential_json:
|
||||
return AuthStatus(authenticated=False)
|
||||
token_json_str = str(credential_json[DB_CREDENTIALS_DICT_TOKEN_KEY])
|
||||
google_drive_creds = get_google_oauth_creds(
|
||||
token_json_str=token_json_str,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
|
||||
@@ -346,10 +346,17 @@ def update_credential_from_model(
|
||||
detail=f"Credential {credential_id} does not exist or does not belong to user",
|
||||
)
|
||||
|
||||
# Get credential_json value - use masking for API responses
|
||||
credential_json_value = (
|
||||
updated_credential.credential_json.get_value(apply_mask=True)
|
||||
if updated_credential.credential_json
|
||||
else {}
|
||||
)
|
||||
|
||||
return CredentialSnapshot(
|
||||
source=updated_credential.source,
|
||||
id=updated_credential.id,
|
||||
credential_json=updated_credential.credential_json,
|
||||
credential_json=credential_json_value,
|
||||
user_id=updated_credential.user_id,
|
||||
name=updated_credential.name,
|
||||
admin_public=updated_credential.admin_public,
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.db.models import IndexAttempt
|
||||
from onyx.db.models import IndexingStatus
|
||||
from onyx.db.models import TaskStatus
|
||||
from onyx.server.federated.models import FederatedConnectorStatus
|
||||
from onyx.server.utils import mask_credential_dict
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
|
||||
|
||||
@@ -145,13 +144,21 @@ class CredentialSnapshot(CredentialBase):
|
||||
|
||||
@classmethod
|
||||
def from_credential_db_model(cls, credential: Credential) -> "CredentialSnapshot":
|
||||
# Get the credential_json value with appropriate masking
|
||||
if credential.credential_json is None:
|
||||
credential_json_value: dict[str, Any] = {}
|
||||
elif MASK_CREDENTIAL_PREFIX:
|
||||
credential_json_value = credential.credential_json.get_value(
|
||||
apply_mask=True
|
||||
)
|
||||
else:
|
||||
credential_json_value = credential.credential_json.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
|
||||
return CredentialSnapshot(
|
||||
id=credential.id,
|
||||
credential_json=(
|
||||
mask_credential_dict(credential.credential_json)
|
||||
if MASK_CREDENTIAL_PREFIX and credential.credential_json
|
||||
else credential.credential_json
|
||||
),
|
||||
credential_json=credential_json_value,
|
||||
user_id=credential.user_id,
|
||||
user_email=credential.user.email if credential.user else None,
|
||||
admin_public=credential.admin_public,
|
||||
|
||||
@@ -88,7 +88,7 @@ SANDBOX_NAMESPACE = os.environ.get("SANDBOX_NAMESPACE", "onyx-sandboxes")
|
||||
# Container image for sandbox pods
|
||||
# Should include Next.js template, opencode CLI, and demo_data zip
|
||||
SANDBOX_CONTAINER_IMAGE = os.environ.get(
|
||||
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.2"
|
||||
"SANDBOX_CONTAINER_IMAGE", "onyxdotapp/sandbox:v0.1.3"
|
||||
)
|
||||
|
||||
# S3 bucket for sandbox file storage (snapshots, knowledge files, uploads)
|
||||
|
||||
@@ -0,0 +1,119 @@
|
||||
# Sandbox Container Image
|
||||
|
||||
This directory contains the Dockerfile and resources for building the Onyx Craft sandbox container image.
|
||||
|
||||
## Directory Structure
|
||||
|
||||
```
|
||||
docker/
|
||||
├── Dockerfile # Main container image definition
|
||||
├── demo_data.zip # Demo data (extracted to /workspace/demo_data)
|
||||
├── templates/
|
||||
│ └── outputs/ # Web app scaffold template (Next.js)
|
||||
├── initial-requirements.txt # Python packages pre-installed in sandbox
|
||||
├── generate_agents_md.py # Script to generate AGENTS.md for sessions
|
||||
└── README.md # This file
|
||||
```
|
||||
|
||||
## Building the Image
|
||||
|
||||
The sandbox image must be built for **amd64** architecture since our Kubernetes cluster runs on x86_64 nodes.
|
||||
|
||||
### Build for amd64 only (fastest)
|
||||
|
||||
```bash
|
||||
cd backend/onyx/server/features/build/sandbox/kubernetes/docker
|
||||
docker build --platform linux/amd64 -t onyxdotapp/sandbox:v0.1.x .
|
||||
docker push onyxdotapp/sandbox:v0.1.x
|
||||
```
|
||||
|
||||
### Build multi-arch (recommended for flexibility)
|
||||
|
||||
```bash
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t onyxdotapp/sandbox:v0.1.x \
|
||||
--push .
|
||||
```
|
||||
|
||||
### Update the `latest` tag
|
||||
|
||||
After pushing a versioned tag, update `latest`:
|
||||
|
||||
```bash
|
||||
docker tag onyxdotapp/sandbox:v0.1.x onyxdotapp/sandbox:latest
|
||||
docker push onyxdotapp/sandbox:latest
|
||||
```
|
||||
|
||||
Or with buildx:
|
||||
|
||||
```bash
|
||||
docker buildx build --platform linux/amd64,linux/arm64 \
|
||||
-t onyxdotapp/sandbox:v0.1.x \
|
||||
-t onyxdotapp/sandbox:latest \
|
||||
--push .
|
||||
```
|
||||
|
||||
## Deploying a New Version
|
||||
|
||||
1. **Build and push** the new image (see above)
|
||||
|
||||
2. **Update the ConfigMap** in `cloud-deployment-yamls/danswer/configmap/env-configmap.yaml`:
|
||||
```yaml
|
||||
SANDBOX_CONTAINER_IMAGE: "onyxdotapp/sandbox:v0.1.x"
|
||||
```
|
||||
|
||||
3. **Apply the ConfigMap**:
|
||||
```bash
|
||||
kubectl apply -f configmap/env-configmap.yaml
|
||||
```
|
||||
|
||||
4. **Restart the API server** to pick up the new config:
|
||||
```bash
|
||||
kubectl rollout restart deployment/api-server -n danswer
|
||||
```
|
||||
|
||||
5. **Delete existing sandbox pods** (they will be recreated with the new image):
|
||||
```bash
|
||||
kubectl delete pods -n onyx-sandboxes -l app.kubernetes.io/component=sandbox
|
||||
```
|
||||
|
||||
## What's Baked Into the Image
|
||||
|
||||
- **Base**: `node:20-slim` (Debian-based)
|
||||
- **Demo data**: `/workspace/demo_data/` - sample files for demo sessions
|
||||
- **Templates**: `/workspace/templates/outputs/` - Next.js web app scaffold
|
||||
- **Python venv**: `/workspace/.venv/` with packages from `initial-requirements.txt`
|
||||
- **OpenCode CLI**: Installed in `/home/sandbox/.opencode/bin/`
|
||||
|
||||
## Runtime Directory Structure
|
||||
|
||||
When a session is created, the following structure is set up in the pod:
|
||||
|
||||
```
|
||||
/workspace/
|
||||
├── demo_data/ # Baked into image
|
||||
├── files/ # Mounted volume, synced from S3
|
||||
├── templates/ # Baked into image
|
||||
└── sessions/
|
||||
└── $session_id/
|
||||
├── files/ # Symlink to /workspace/demo_data or /workspace/files
|
||||
├── outputs/ # Copied from templates, contains web app
|
||||
├── attachments/ # User-uploaded files
|
||||
├── org_info/ # Demo persona info (if demo mode)
|
||||
├── AGENTS.md # Instructions for the AI agent
|
||||
└── opencode.json # OpenCode configuration
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Verify image exists on Docker Hub
|
||||
|
||||
```bash
|
||||
curl -s "https://hub.docker.com/v2/repositories/onyxdotapp/sandbox/tags" | jq '.results[].name'
|
||||
```
|
||||
|
||||
### Check what image a pod is using
|
||||
|
||||
```bash
|
||||
kubectl get pod <pod-name> -n onyx-sandboxes -o jsonpath='{.spec.containers[?(@.name=="sandbox")].image}'
|
||||
```
|
||||
Binary file not shown.
@@ -349,7 +349,11 @@ class SessionManager:
|
||||
return LLMProviderConfig(
|
||||
provider=default_model.llm_provider.provider,
|
||||
model_name=default_model.name,
|
||||
api_key=default_model.llm_provider.api_key,
|
||||
api_key=(
|
||||
default_model.llm_provider.api_key.get_value(apply_mask=False)
|
||||
if default_model.llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=default_model.llm_provider.api_base,
|
||||
)
|
||||
|
||||
|
||||
@@ -41,6 +41,7 @@ from onyx.db.mcp import delete_all_user_connection_configs_for_server_no_commit
|
||||
from onyx.db.mcp import delete_connection_config
|
||||
from onyx.db.mcp import delete_mcp_server
|
||||
from onyx.db.mcp import delete_user_connection_configs_for_server
|
||||
from onyx.db.mcp import extract_connection_data
|
||||
from onyx.db.mcp import get_all_mcp_servers
|
||||
from onyx.db.mcp import get_connection_config_by_id
|
||||
from onyx.db.mcp import get_mcp_server_by_id
|
||||
@@ -79,6 +80,7 @@ from onyx.server.features.tool.models import ToolSnapshot
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import discover_mcp_tools
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import initialize_mcp_client
|
||||
from onyx.tools.tool_implementations.mcp.mcp_client import log_exception_group
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -143,7 +145,8 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def get_tokens(self) -> OAuthToken | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
tokens_raw = config.config.get(MCPOAuthKeys.TOKENS.value)
|
||||
config_data = extract_connection_data(config)
|
||||
tokens_raw = config_data.get(MCPOAuthKeys.TOKENS.value)
|
||||
if tokens_raw:
|
||||
return OAuthToken.model_validate(tokens_raw)
|
||||
return None
|
||||
@@ -151,14 +154,14 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def set_tokens(self, tokens: OAuthToken) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
config.config[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
|
||||
cfg_headers = {
|
||||
config_data = extract_connection_data(config)
|
||||
config_data[MCPOAuthKeys.TOKENS.value] = tokens.model_dump(mode="json")
|
||||
config_data["headers"] = {
|
||||
"Authorization": f"{tokens.token_type} {tokens.access_token}"
|
||||
}
|
||||
config.config["headers"] = cfg_headers
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
if self.alt_config_id:
|
||||
update_connection_config(self.alt_config_id, db_session, config.config)
|
||||
update_connection_config(self.alt_config_id, db_session, config_data)
|
||||
|
||||
# signal the oauth callback that token exchange is complete
|
||||
r = get_redis_client()
|
||||
@@ -168,19 +171,21 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def get_client_info(self) -> OAuthClientInformationFull | None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
client_info_raw = config.config.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
config_data = extract_connection_data(config)
|
||||
client_info_raw = config_data.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if client_info_raw:
|
||||
return OAuthClientInformationFull.model_validate(client_info_raw)
|
||||
if self.alt_config_id:
|
||||
alt_config = get_connection_config_by_id(self.alt_config_id, db_session)
|
||||
if alt_config:
|
||||
alt_client_info = alt_config.config.get(
|
||||
alt_config_data = extract_connection_data(alt_config)
|
||||
alt_client_info = alt_config_data.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if alt_client_info:
|
||||
# Cache the admin client info on the user config for future calls
|
||||
config.config[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
config_data[MCPOAuthKeys.CLIENT_INFO.value] = alt_client_info
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
return OAuthClientInformationFull.model_validate(
|
||||
alt_client_info
|
||||
)
|
||||
@@ -189,10 +194,11 @@ class OnyxTokenStorage(TokenStorage):
|
||||
async def set_client_info(self, info: OAuthClientInformationFull) -> None:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
config = self._ensure_connection_config(db_session)
|
||||
config.config[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
|
||||
update_connection_config(config.id, db_session, config.config)
|
||||
config_data = extract_connection_data(config)
|
||||
config_data[MCPOAuthKeys.CLIENT_INFO.value] = info.model_dump(mode="json")
|
||||
update_connection_config(config.id, db_session, config_data)
|
||||
if self.alt_config_id:
|
||||
update_connection_config(self.alt_config_id, db_session, config.config)
|
||||
update_connection_config(self.alt_config_id, db_session, config_data)
|
||||
|
||||
|
||||
def make_oauth_provider(
|
||||
@@ -436,9 +442,12 @@ async def _connect_oauth(
|
||||
|
||||
db.commit()
|
||||
|
||||
connection_config_dict = extract_connection_data(
|
||||
connection_config, apply_mask=False
|
||||
)
|
||||
is_connected = (
|
||||
MCPOAuthKeys.CLIENT_INFO.value in connection_config.config
|
||||
and connection_config.config.get("headers")
|
||||
MCPOAuthKeys.CLIENT_INFO.value in connection_config_dict
|
||||
and connection_config_dict.get("headers")
|
||||
)
|
||||
# Step 1: make unauthenticated request and parse returned www authenticate header
|
||||
# Ensure we have a trailing slash for the MCP endpoint
|
||||
@@ -471,7 +480,7 @@ async def _connect_oauth(
|
||||
try:
|
||||
x = await initialize_mcp_client(
|
||||
probe_url,
|
||||
connection_headers=connection_config.config.get("headers", {}),
|
||||
connection_headers=connection_config_dict.get("headers", {}),
|
||||
transport=transport,
|
||||
auth=oauth_auth,
|
||||
)
|
||||
@@ -684,15 +693,18 @@ def save_user_credentials(
|
||||
# Use template to create the full connection config
|
||||
try:
|
||||
# TODO: fix and/or type correctly w/base model
|
||||
auth_template_dict = extract_connection_data(
|
||||
auth_template, apply_mask=False
|
||||
)
|
||||
config_data = MCPConnectionData(
|
||||
headers=auth_template.config.get("headers", {}),
|
||||
headers=auth_template_dict.get("headers", {}),
|
||||
header_substitutions=request.credentials,
|
||||
)
|
||||
for oauth_field_key in MCPOAuthKeys:
|
||||
field_key: Literal["client_info", "tokens", "metadata"] = (
|
||||
oauth_field_key.value
|
||||
)
|
||||
if field_val := auth_template.config.get(field_key):
|
||||
if field_val := auth_template_dict.get(field_key):
|
||||
config_data[field_key] = field_val
|
||||
|
||||
except Exception as e:
|
||||
@@ -839,18 +851,20 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
and db_server.admin_connection_config is not None
|
||||
and include_auth_config
|
||||
):
|
||||
admin_config_dict = extract_connection_data(
|
||||
db_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
if db_server.auth_type == MCPAuthenticationType.API_TOKEN:
|
||||
raw_api_key = admin_config_dict["headers"]["Authorization"].split(" ")[
|
||||
-1
|
||||
]
|
||||
admin_credentials = {
|
||||
"api_key": db_server.admin_connection_config.config["headers"][
|
||||
"Authorization"
|
||||
].split(" ")[-1]
|
||||
"api_key": mask_string(raw_api_key),
|
||||
}
|
||||
elif db_server.auth_type == MCPAuthenticationType.OAUTH:
|
||||
user_authenticated = False
|
||||
client_info = None
|
||||
client_info_raw = db_server.admin_connection_config.config.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
client_info_raw = admin_config_dict.get(MCPOAuthKeys.CLIENT_INFO.value)
|
||||
if client_info_raw:
|
||||
client_info = OAuthClientInformationFull.model_validate(
|
||||
client_info_raw
|
||||
@@ -861,8 +875,8 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
"Stored client info had empty client ID or secret"
|
||||
)
|
||||
admin_credentials = {
|
||||
"client_id": client_info.client_id,
|
||||
"client_secret": client_info.client_secret,
|
||||
"client_id": mask_string(client_info.client_id),
|
||||
"client_secret": mask_string(client_info.client_secret),
|
||||
}
|
||||
else:
|
||||
admin_credentials = {}
|
||||
@@ -879,14 +893,18 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
include_auth_config
|
||||
and db_server.auth_type != MCPAuthenticationType.OAUTH
|
||||
):
|
||||
user_credentials = user_config.config.get(HEADER_SUBSTITUTIONS, {})
|
||||
user_config_dict = extract_connection_data(user_config, apply_mask=True)
|
||||
user_credentials = user_config_dict.get(HEADER_SUBSTITUTIONS, {})
|
||||
|
||||
if (
|
||||
db_server.auth_type == MCPAuthenticationType.OAUTH
|
||||
and db_server.admin_connection_config
|
||||
):
|
||||
client_info = None
|
||||
client_info_raw = db_server.admin_connection_config.config.get(
|
||||
oauth_admin_config_dict = extract_connection_data(
|
||||
db_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
client_info_raw = oauth_admin_config_dict.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if client_info_raw:
|
||||
@@ -896,8 +914,8 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
raise ValueError("Stored client info had empty client ID or secret")
|
||||
if can_view_admin_credentials:
|
||||
admin_credentials = {
|
||||
"client_id": client_info.client_id,
|
||||
"client_secret": client_info.client_secret,
|
||||
"client_id": mask_string(client_info.client_id),
|
||||
"client_secret": mask_string(client_info.client_secret),
|
||||
}
|
||||
elif can_view_admin_credentials:
|
||||
admin_credentials = {}
|
||||
@@ -909,7 +927,10 @@ def _db_mcp_server_to_api_mcp_server(
|
||||
try:
|
||||
template_config = db_server.admin_connection_config
|
||||
if template_config:
|
||||
headers = template_config.config.get("headers", {})
|
||||
template_config_dict = extract_connection_data(
|
||||
template_config, apply_mask=False
|
||||
)
|
||||
headers = template_config_dict.get("headers", {})
|
||||
auth_template = MCPAuthTemplate(
|
||||
headers=headers,
|
||||
required_fields=[], # would need to regex, not worth it
|
||||
@@ -1232,7 +1253,10 @@ def _list_mcp_tools_by_id(
|
||||
)
|
||||
|
||||
if connection_config:
|
||||
headers.update(connection_config.config.get("headers", {}))
|
||||
connection_config_dict = extract_connection_data(
|
||||
connection_config, apply_mask=False
|
||||
)
|
||||
headers.update(connection_config_dict.get("headers", {}))
|
||||
|
||||
import time
|
||||
|
||||
@@ -1320,7 +1344,10 @@ def _upsert_mcp_server(
|
||||
_ensure_mcp_server_owner_or_admin(mcp_server, user)
|
||||
client_info = None
|
||||
if mcp_server.admin_connection_config:
|
||||
client_info_raw = mcp_server.admin_connection_config.config.get(
|
||||
existing_admin_config_dict = extract_connection_data(
|
||||
mcp_server.admin_connection_config, apply_mask=False
|
||||
)
|
||||
client_info_raw = existing_admin_config_dict.get(
|
||||
MCPOAuthKeys.CLIENT_INFO.value
|
||||
)
|
||||
if client_info_raw:
|
||||
|
||||
@@ -32,11 +32,16 @@ def get_user_oauth_token_status(
|
||||
and whether their tokens are expired.
|
||||
"""
|
||||
user_tokens = get_all_user_oauth_tokens(user.id, db_session)
|
||||
return [
|
||||
OAuthTokenStatus(
|
||||
oauth_config_id=token.oauth_config_id,
|
||||
expires_at=OAuthTokenManager.token_expiration_time(token.token_data),
|
||||
is_expired=OAuthTokenManager.is_token_expired(token.token_data),
|
||||
result = []
|
||||
for token in user_tokens:
|
||||
token_data = (
|
||||
token.token_data.get_value(apply_mask=False) if token.token_data else {}
|
||||
)
|
||||
for token in user_tokens
|
||||
]
|
||||
result.append(
|
||||
OAuthTokenStatus(
|
||||
oauth_config_id=token.oauth_config_id,
|
||||
expires_at=OAuthTokenManager.token_expiration_time(token_data),
|
||||
is_expired=OAuthTokenManager.is_token_expired(token_data),
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -75,7 +75,7 @@ def _get_active_search_provider(
|
||||
has_api_key=bool(provider_model.api_key),
|
||||
)
|
||||
|
||||
if not provider_model.api_key:
|
||||
if provider_model.api_key is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="Web search provider requires an API key.",
|
||||
@@ -84,7 +84,7 @@ def _get_active_search_provider(
|
||||
try:
|
||||
provider: WebSearchProvider = build_search_provider_from_config(
|
||||
provider_type=provider_view.provider_type,
|
||||
api_key=provider_model.api_key,
|
||||
api_key=provider_model.api_key.get_value(apply_mask=False),
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
except ValueError as exc:
|
||||
@@ -121,7 +121,7 @@ def _get_active_content_provider(
|
||||
|
||||
provider: WebContentProvider | None = build_content_provider_from_config(
|
||||
provider_type=provider_type,
|
||||
api_key=provider_model.api_key,
|
||||
api_key=provider_model.api_key.get_value(apply_mask=False),
|
||||
config=config,
|
||||
)
|
||||
except ValueError as exc:
|
||||
|
||||
@@ -114,9 +114,14 @@ def get_entities(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
entities_spec = connector_instance.configuration_schema()
|
||||
|
||||
@@ -151,9 +156,14 @@ def get_credentials_schema(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
credentials_spec = connector_instance.credentials_schema()
|
||||
|
||||
@@ -275,6 +285,8 @@ def validate_entities(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
return Response(status_code=400)
|
||||
|
||||
# For HEAD requests, we'll expect entities as query parameters
|
||||
# since HEAD requests shouldn't have request bodies
|
||||
@@ -288,7 +300,8 @@ def validate_entities(
|
||||
return Response(status_code=400)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
is_valid = connector_instance.validate_entities(entities_dict)
|
||||
|
||||
@@ -318,9 +331,15 @@ def get_authorize_url(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
# Update credentials to include the correct redirect URI with the connector ID
|
||||
updated_credentials = federated_connector.credentials.copy()
|
||||
updated_credentials = federated_connector.credentials.get_value(
|
||||
apply_mask=False
|
||||
).copy()
|
||||
if "redirect_uri" in updated_credentials and updated_credentials["redirect_uri"]:
|
||||
# Replace the {id} placeholder with the actual federated connector ID
|
||||
updated_credentials["redirect_uri"] = updated_credentials[
|
||||
@@ -391,9 +410,14 @@ def handle_oauth_callback_generic(
|
||||
)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
federated_connector.source, federated_connector.credentials
|
||||
federated_connector.source,
|
||||
federated_connector.credentials.get_value(apply_mask=False),
|
||||
)
|
||||
oauth_result = connector_instance.callback(callback_data, get_oauth_callback_uri())
|
||||
|
||||
@@ -460,9 +484,9 @@ def get_user_oauth_status(
|
||||
|
||||
# Generate authorize URL if needed
|
||||
authorize_url = None
|
||||
if not oauth_token:
|
||||
if not oauth_token and fc.credentials is not None:
|
||||
connector_instance = _get_federated_connector_instance(
|
||||
fc.source, fc.credentials
|
||||
fc.source, fc.credentials.get_value(apply_mask=False)
|
||||
)
|
||||
base_authorize_url = connector_instance.authorize(get_oauth_callback_uri())
|
||||
|
||||
@@ -496,6 +520,10 @@ def get_federated_connector_detail(
|
||||
federated_connector = fetch_federated_connector_by_id(id, db_session)
|
||||
if not federated_connector:
|
||||
raise HTTPException(status_code=404, detail="Federated connector not found")
|
||||
if federated_connector.credentials is None:
|
||||
raise HTTPException(
|
||||
status_code=400, detail="Federated connector has no credentials"
|
||||
)
|
||||
|
||||
# Get OAuth token information for the current user
|
||||
oauth_token = None
|
||||
@@ -521,7 +549,9 @@ def get_federated_connector_detail(
|
||||
id=federated_connector.id,
|
||||
source=federated_connector.source,
|
||||
name=f"{federated_connector.source.replace('_', ' ').title()}",
|
||||
credentials=FederatedConnectorCredentials(**federated_connector.credentials),
|
||||
credentials=FederatedConnectorCredentials(
|
||||
**federated_connector.credentials.get_value(apply_mask=True)
|
||||
),
|
||||
config=federated_connector.config,
|
||||
oauth_token_exists=oauth_token is not None,
|
||||
oauth_token_expires_at=oauth_token.expires_at if oauth_token else None,
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.embedding.models import TestEmbeddingRequest
|
||||
from onyx.server.utils import mask_string
|
||||
from onyx.utils.encryption import mask_string
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
|
||||
@@ -37,7 +37,11 @@ class CloudEmbeddingProvider(BaseModel):
|
||||
) -> "CloudEmbeddingProvider":
|
||||
return cls(
|
||||
provider_type=cloud_provider_model.provider_type,
|
||||
api_key=cloud_provider_model.api_key,
|
||||
api_key=(
|
||||
cloud_provider_model.api_key.get_value(apply_mask=True)
|
||||
if cloud_provider_model.api_key
|
||||
else None
|
||||
),
|
||||
api_url=cloud_provider_model.api_url,
|
||||
api_version=cloud_provider_model.api_version,
|
||||
deployment_name=cloud_provider_model.deployment_name,
|
||||
|
||||
@@ -90,7 +90,11 @@ def _build_llm_provider_request(
|
||||
return LLMProviderUpsertRequest(
|
||||
name=f"Image Gen - {image_provider_id}",
|
||||
provider=source_provider.provider,
|
||||
api_key=source_provider.api_key, # Only this from source
|
||||
api_key=(
|
||||
source_provider.api_key.get_value(apply_mask=False)
|
||||
if source_provider.api_key
|
||||
else None
|
||||
), # Only this from source
|
||||
api_base=api_base, # From request
|
||||
api_version=api_version, # From request
|
||||
default_model_name=model_name,
|
||||
@@ -227,7 +231,11 @@ def test_image_generation(
|
||||
api_key_changed=False, # Using stored key from source provider
|
||||
)
|
||||
|
||||
api_key = source_provider.api_key
|
||||
api_key = (
|
||||
source_provider.api_key.get_value(apply_mask=False)
|
||||
if source_provider.api_key
|
||||
else None
|
||||
)
|
||||
provider = source_provider.provider
|
||||
|
||||
if provider is None:
|
||||
@@ -431,7 +439,11 @@ def update_config(
|
||||
api_key_changed=False,
|
||||
)
|
||||
# Preserve existing API key when user didn't change it
|
||||
actual_api_key = old_provider.api_key
|
||||
actual_api_key = (
|
||||
old_provider.api_key.get_value(apply_mask=False)
|
||||
if old_provider.api_key
|
||||
else None
|
||||
)
|
||||
|
||||
# 3. Build and create new LLM provider
|
||||
provider_request = _build_llm_provider_request(
|
||||
|
||||
@@ -140,7 +140,11 @@ class ImageGenerationCredentials(BaseModel):
|
||||
"""
|
||||
llm_provider = config.model_configuration.llm_provider
|
||||
return cls(
|
||||
api_key=_mask_api_key(llm_provider.api_key),
|
||||
api_key=_mask_api_key(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
@@ -168,7 +172,11 @@ class DefaultImageGenerationConfig(BaseModel):
|
||||
model_configuration_id=config.model_configuration_id,
|
||||
model_name=config.model_configuration.name,
|
||||
provider=llm_provider.provider,
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -203,7 +203,11 @@ def test_llm_configuration(
|
||||
new_custom_config=test_llm_request.custom_config,
|
||||
api_key_changed=False,
|
||||
)
|
||||
test_api_key = existing_provider.api_key
|
||||
test_api_key = (
|
||||
existing_provider.api_key.get_value(apply_mask=False)
|
||||
if existing_provider.api_key
|
||||
else None
|
||||
)
|
||||
if existing_provider and not test_llm_request.custom_config_changed:
|
||||
test_custom_config = existing_provider.custom_config
|
||||
|
||||
@@ -351,7 +355,11 @@ def put_llm_provider(
|
||||
# the llm api key is sanitized when returned to clients, so the only time we
|
||||
# should get a real key is when it is explicitly changed
|
||||
if existing_provider and not llm_provider_upsert_request.api_key_changed:
|
||||
llm_provider_upsert_request.api_key = existing_provider.api_key
|
||||
llm_provider_upsert_request.api_key = (
|
||||
existing_provider.api_key.get_value(apply_mask=False)
|
||||
if existing_provider.api_key
|
||||
else None
|
||||
)
|
||||
if existing_provider and not llm_provider_upsert_request.custom_config_changed:
|
||||
llm_provider_upsert_request.custom_config = existing_provider.custom_config
|
||||
|
||||
@@ -646,7 +654,11 @@ def get_provider_contextual_cost(
|
||||
provider=provider.provider,
|
||||
model=model_configuration.name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_key=(
|
||||
provider.api_key.get_value(apply_mask=False)
|
||||
if provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
@@ -926,6 +938,11 @@ def get_ollama_available_models(
|
||||
)
|
||||
)
|
||||
|
||||
sorted_results = sorted(
|
||||
all_models_with_context_size_and_vision,
|
||||
key=lambda m: m.name.lower(),
|
||||
)
|
||||
|
||||
# Sync new models to DB if provider_name is specified
|
||||
if request.provider_name:
|
||||
try:
|
||||
@@ -936,7 +953,7 @@ def get_ollama_available_models(
|
||||
"max_input_tokens": r.max_input_tokens,
|
||||
"supports_image_input": r.supports_image_input,
|
||||
}
|
||||
for r in all_models_with_context_size_and_vision
|
||||
for r in sorted_results
|
||||
]
|
||||
new_count = sync_model_configurations(
|
||||
db_session=db_session,
|
||||
@@ -950,7 +967,7 @@ def get_ollama_available_models(
|
||||
except ValueError as e:
|
||||
logger.warning(f"Failed to sync Ollama models to DB: {e}")
|
||||
|
||||
return all_models_with_context_size_and_vision
|
||||
return sorted_results
|
||||
|
||||
|
||||
def _get_openrouter_models_response(api_base: str, api_key: str) -> dict:
|
||||
|
||||
@@ -190,7 +190,11 @@ class LLMProviderView(LLMProvider):
|
||||
id=llm_provider_model.id,
|
||||
name=llm_provider_model.name,
|
||||
provider=provider,
|
||||
api_key=llm_provider_model.api_key,
|
||||
api_key=(
|
||||
llm_provider_model.api_key.get_value(apply_mask=False)
|
||||
if llm_provider_model.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider_model.api_base,
|
||||
api_version=llm_provider_model.api_version,
|
||||
custom_config=llm_provider_model.custom_config,
|
||||
|
||||
@@ -79,6 +79,7 @@ class UserPersonalization(BaseModel):
|
||||
role: str = ""
|
||||
use_memories: bool = True
|
||||
memories: list[str] = Field(default_factory=list)
|
||||
user_preferences: str = ""
|
||||
|
||||
|
||||
class TenantSnapshot(BaseModel):
|
||||
@@ -160,6 +161,7 @@ class UserInfo(BaseModel):
|
||||
role=user.personal_role or "",
|
||||
use_memories=user.use_memories,
|
||||
memories=[memory.memory_text for memory in (user.memories or [])],
|
||||
user_preferences=user.user_preferences or "",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -213,6 +215,7 @@ class PersonalizationUpdateRequest(BaseModel):
|
||||
role: str | None = None
|
||||
use_memories: bool | None = None
|
||||
memories: list[str] | None = None
|
||||
user_preferences: str | None = Field(default=None, max_length=500)
|
||||
|
||||
|
||||
class SlackBotCreationRequest(BaseModel):
|
||||
@@ -341,9 +344,21 @@ class SlackBot(BaseModel):
|
||||
name=slack_bot_model.name,
|
||||
enabled=slack_bot_model.enabled,
|
||||
configs_count=len(slack_bot_model.slack_channel_configs),
|
||||
bot_token=slack_bot_model.bot_token,
|
||||
app_token=slack_bot_model.app_token,
|
||||
user_token=slack_bot_model.user_token,
|
||||
bot_token=(
|
||||
slack_bot_model.bot_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.bot_token
|
||||
else ""
|
||||
),
|
||||
app_token=(
|
||||
slack_bot_model.app_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.app_token
|
||||
else ""
|
||||
),
|
||||
user_token=(
|
||||
slack_bot_model.user_token.get_value(apply_mask=True)
|
||||
if slack_bot_model.user_token
|
||||
else None
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -844,6 +844,11 @@ def update_user_personalization_api(
|
||||
new_memories = (
|
||||
request.memories if request.memories is not None else existing_memories
|
||||
)
|
||||
new_user_preferences = (
|
||||
request.user_preferences
|
||||
if request.user_preferences is not None
|
||||
else user.user_preferences
|
||||
)
|
||||
|
||||
update_user_personalization(
|
||||
user.id,
|
||||
@@ -851,6 +856,7 @@ def update_user_personalization_api(
|
||||
personal_role=new_role,
|
||||
use_memories=new_use_memories,
|
||||
memories=new_memories,
|
||||
user_preferences=new_user_preferences,
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ def test_search_provider(
|
||||
status_code=400,
|
||||
detail="No stored API key found for this provider type.",
|
||||
)
|
||||
api_key = existing_provider.api_key
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if provider_requires_api_key and not api_key:
|
||||
raise HTTPException(
|
||||
@@ -391,7 +391,7 @@ def test_content_provider(
|
||||
detail="Base URL cannot differ from stored provider when using stored API key",
|
||||
)
|
||||
|
||||
api_key = existing_provider.api_key
|
||||
api_key = existing_provider.api_key.get_value(apply_mask=False)
|
||||
|
||||
if not api_key:
|
||||
raise HTTPException(
|
||||
|
||||
@@ -8,10 +8,6 @@ from uuid import UUID
|
||||
from fastapi import HTTPException
|
||||
from fastapi import status
|
||||
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
|
||||
|
||||
class BasicAuthenticationError(HTTPException):
|
||||
def __init__(self, detail: str):
|
||||
@@ -45,42 +41,6 @@ def get_json_line(
|
||||
return json.dumps(json_dict, cls=encoder) + "\n"
|
||||
|
||||
|
||||
def mask_string(sensitive_str: str) -> str:
|
||||
return "****...**" + sensitive_str[-4:]
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, str]:
|
||||
masked_creds = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
continue
|
||||
|
||||
if isinstance(val, int):
|
||||
masked_creds[key] = "*****"
|
||||
continue
|
||||
|
||||
raise ValueError(
|
||||
f"Unable to mask credentials of type other than string or int, cannot process request."
|
||||
f"Received type: {type(val)}"
|
||||
)
|
||||
|
||||
return masked_creds
|
||||
|
||||
|
||||
def make_short_id() -> str:
|
||||
"""Fast way to generate a random 8 character id ... useful for tagging data
|
||||
to trace it through a flow. This is definitely not guaranteed to be unique and is
|
||||
|
||||
@@ -446,7 +446,7 @@ def run_research_agent_call(
|
||||
tool_calls=tool_calls,
|
||||
tools=current_tools,
|
||||
message_history=msg_history,
|
||||
memories=None,
|
||||
user_memory_context=None,
|
||||
user_info=None,
|
||||
citation_mapping=citation_mapping,
|
||||
next_citation_num=citation_processor.get_next_citation_number(),
|
||||
|
||||
@@ -16,6 +16,7 @@ from onyx.configs.chat_configs import NUM_RETURNED_HITS
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDoc
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.placement import Placement
|
||||
from onyx.server.query_and_chat.streaming_models import GeneratedImage
|
||||
from onyx.tools.tool_implementations.images.models import FinalImageGenerationResponse
|
||||
@@ -165,7 +166,7 @@ class SearchToolOverrideKwargs(BaseModel):
|
||||
# without help and a specific custom prompt for this
|
||||
original_query: str | None = None
|
||||
message_history: list[ChatMinimalTextMessage] | None = None
|
||||
memories: list[str] | None = None
|
||||
user_memory_context: UserMemoryContext | None = None
|
||||
user_info: str | None = None
|
||||
|
||||
# Used for tool calls after the first one but in the same chat turn. The reason for this is that if the initial pass through
|
||||
|
||||
@@ -82,7 +82,11 @@ def _get_image_generation_config(llm: LLM, db_session: Session) -> LLMConfig:
|
||||
model_provider=llm_provider.provider,
|
||||
model_name=default_config.model_configuration.name,
|
||||
temperature=GEN_AI_TEMPERATURE,
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -94,7 +94,11 @@ class ImageGenerationTool(Tool[None]):
|
||||
|
||||
llm_provider = config.model_configuration.llm_provider
|
||||
credentials = ImageGenerationProviderCredentials(
|
||||
api_key=llm_provider.api_key,
|
||||
api_key=(
|
||||
llm_provider.api_key.get_value(apply_mask=False)
|
||||
if llm_provider.api_key
|
||||
else None
|
||||
),
|
||||
api_base=llm_provider.api_base,
|
||||
api_version=llm_provider.api_version,
|
||||
deployment_name=llm_provider.deployment_name,
|
||||
|
||||
@@ -142,8 +142,9 @@ class MCPTool(Tool[None]):
|
||||
)
|
||||
|
||||
# Priority 2: Base headers from connection config (DB) - overrides request
|
||||
if self.connection_config:
|
||||
headers.update(self.connection_config.config.get("headers", {}))
|
||||
if self.connection_config and self.connection_config.config:
|
||||
config_dict = self.connection_config.config.get_value(apply_mask=False)
|
||||
headers.update(config_dict.get("headers", {}))
|
||||
|
||||
# Priority 3: For pass-through OAuth, use the user's login OAuth token
|
||||
if self._user_oauth_token:
|
||||
|
||||
@@ -352,10 +352,17 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
|
||||
if tenant_slack_bot:
|
||||
bot_token = tenant_slack_bot.bot_token
|
||||
access_token = (
|
||||
tenant_slack_bot.user_token or tenant_slack_bot.bot_token
|
||||
bot_token = (
|
||||
tenant_slack_bot.bot_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.bot_token
|
||||
else None
|
||||
)
|
||||
user_token = (
|
||||
tenant_slack_bot.user_token.get_value(apply_mask=False)
|
||||
if tenant_slack_bot.user_token
|
||||
else None
|
||||
)
|
||||
access_token = user_token or bot_token
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch Slack bot tokens: {e}")
|
||||
|
||||
@@ -375,8 +382,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
None,
|
||||
)
|
||||
|
||||
if slack_oauth_token:
|
||||
access_token = slack_oauth_token.token
|
||||
if slack_oauth_token and slack_oauth_token.token:
|
||||
access_token = slack_oauth_token.token.get_value(
|
||||
apply_mask=False
|
||||
)
|
||||
entities = slack_oauth_token.federated_connector.config or {}
|
||||
except Exception as e:
|
||||
logger.warning(f"Could not fetch Slack OAuth token: {e}")
|
||||
@@ -550,7 +559,11 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
if override_kwargs.message_history
|
||||
else []
|
||||
)
|
||||
memories = override_kwargs.memories
|
||||
memories = (
|
||||
override_kwargs.user_memory_context.as_formatted_list()
|
||||
if override_kwargs.user_memory_context
|
||||
else []
|
||||
)
|
||||
user_info = override_kwargs.user_info
|
||||
|
||||
# Skip query expansion if this is a repeat search call
|
||||
|
||||
@@ -77,7 +77,11 @@ def build_search_provider_from_config(
|
||||
def _build_search_provider(provider_model: InternetSearchProvider) -> WebSearchProvider:
|
||||
return build_search_provider_from_config(
|
||||
provider_type=WebSearchProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key or "",
|
||||
api_key=(
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else ""
|
||||
),
|
||||
config=provider_model.config or {},
|
||||
)
|
||||
|
||||
@@ -129,7 +133,11 @@ def get_default_content_provider() -> WebContentProvider:
|
||||
if provider_model:
|
||||
provider = build_content_provider_from_config(
|
||||
provider_type=WebContentProviderType(provider_model.provider_type),
|
||||
api_key=provider_model.api_key or "",
|
||||
api_key=(
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else ""
|
||||
),
|
||||
config=provider_model.config or WebContentProviderConfig(),
|
||||
)
|
||||
if provider:
|
||||
|
||||
@@ -69,7 +69,11 @@ class WebSearchTool(Tool[WebSearchToolOverrideKwargs]):
|
||||
if provider_model is None:
|
||||
raise RuntimeError("No web search provider configured.")
|
||||
provider_type = WebSearchProviderType(provider_model.provider_type)
|
||||
api_key = provider_model.api_key
|
||||
api_key = (
|
||||
provider_model.api_key.get_value(apply_mask=False)
|
||||
if provider_model.api_key
|
||||
else None
|
||||
)
|
||||
config = provider_model.config
|
||||
|
||||
# TODO - This should just be enforced at the DB level
|
||||
|
||||
@@ -6,6 +6,7 @@ import onyx.tracing.framework._error_tracing as _error_tracing
|
||||
from onyx.chat.models import ChatMessageSimple
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.context.search.models import SearchDocsResponse
|
||||
from onyx.db.memory import UserMemoryContext
|
||||
from onyx.server.query_and_chat.streaming_models import Packet
|
||||
from onyx.server.query_and_chat.streaming_models import PacketException
|
||||
from onyx.server.query_and_chat.streaming_models import SectionEnd
|
||||
@@ -220,7 +221,7 @@ def run_tool_calls(
|
||||
tools: list[Tool],
|
||||
# The stuff below is needed for the different individual built-in tools
|
||||
message_history: list[ChatMessageSimple],
|
||||
memories: list[str] | None,
|
||||
user_memory_context: UserMemoryContext | None,
|
||||
user_info: str | None,
|
||||
citation_mapping: dict[int, str],
|
||||
next_citation_num: int,
|
||||
@@ -252,7 +253,7 @@ def run_tool_calls(
|
||||
tools: List of available tool instances.
|
||||
message_history: Chat message history (used to find the most recent user query
|
||||
for `SearchTool` override kwargs).
|
||||
memories: User memories, if available (passed through to `SearchTool`).
|
||||
user_memory_context: User memory context, if available (passed through to `SearchTool`).
|
||||
user_info: User information string, if available (passed through to `SearchTool`).
|
||||
citation_mapping: Current citation number to URL mapping. May be updated with
|
||||
new citations produced by search tools.
|
||||
@@ -342,7 +343,7 @@ def run_tool_calls(
|
||||
starting_citation_num=starting_citation_num,
|
||||
original_query=last_user_message,
|
||||
message_history=minimal_history,
|
||||
memories=memories,
|
||||
user_memory_context=user_memory_context,
|
||||
user_info=user_info,
|
||||
skip_query_expansion=skip_search_query_expansion,
|
||||
)
|
||||
|
||||
@@ -1,22 +1,91 @@
|
||||
from typing import Any
|
||||
|
||||
from onyx.configs.app_configs import ENCRYPTION_KEY_SECRET
|
||||
from onyx.connectors.google_utils.shared_constants import (
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _encrypt_string(input_str: str) -> bytes:
|
||||
if ENCRYPTION_KEY_SECRET:
|
||||
logger.warning("MIT version of Onyx does not support encryption of secrets.")
|
||||
return input_str.encode()
|
||||
|
||||
|
||||
# IMPORTANT DO NOT DELETE, THIS IS USED BY fetch_versioned_implementation
|
||||
def _decrypt_bytes(input_bytes: bytes) -> str:
|
||||
# No need to double warn. If you wish to learn more about encryption features
|
||||
# refer to the Onyx EE code
|
||||
return input_bytes.decode()
|
||||
|
||||
|
||||
def mask_string(sensitive_str: str) -> str:
|
||||
"""Masks a sensitive string, showing first and last few characters.
|
||||
If the string is too short to safely mask, returns a fully masked placeholder.
|
||||
"""
|
||||
visible_start = 4
|
||||
visible_end = 4
|
||||
min_masked_chars = 6
|
||||
|
||||
if len(sensitive_str) < visible_start + visible_end + min_masked_chars:
|
||||
return "••••••••••••"
|
||||
|
||||
return f"{sensitive_str[:visible_start]}...{sensitive_str[-visible_end:]}"
|
||||
|
||||
|
||||
MASK_CREDENTIALS_WHITELIST = {
|
||||
DB_CREDENTIALS_AUTHENTICATION_METHOD,
|
||||
"wiki_base",
|
||||
"cloud_name",
|
||||
"cloud_id",
|
||||
}
|
||||
|
||||
|
||||
def mask_credential_dict(credential_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
masked_creds: dict[str, Any] = {}
|
||||
for key, val in credential_dict.items():
|
||||
if isinstance(val, str):
|
||||
# we want to pass the authentication_method field through so the frontend
|
||||
# can disambiguate credentials created by different methods
|
||||
if key in MASK_CREDENTIALS_WHITELIST:
|
||||
masked_creds[key] = val
|
||||
else:
|
||||
masked_creds[key] = mask_string(val)
|
||||
elif isinstance(val, dict):
|
||||
masked_creds[key] = mask_credential_dict(val)
|
||||
elif isinstance(val, list):
|
||||
masked_creds[key] = _mask_list(val)
|
||||
elif isinstance(val, (bool, type(None))):
|
||||
masked_creds[key] = val
|
||||
elif isinstance(val, (int, float)):
|
||||
masked_creds[key] = "*****"
|
||||
else:
|
||||
masked_creds[key] = "*****"
|
||||
|
||||
return masked_creds
|
||||
|
||||
|
||||
def _mask_list(items: list[Any]) -> list[Any]:
|
||||
masked: list[Any] = []
|
||||
for item in items:
|
||||
if isinstance(item, dict):
|
||||
masked.append(mask_credential_dict(item))
|
||||
elif isinstance(item, str):
|
||||
masked.append(mask_string(item))
|
||||
elif isinstance(item, list):
|
||||
masked.append(_mask_list(item))
|
||||
elif isinstance(item, (bool, type(None))):
|
||||
masked.append(item)
|
||||
else:
|
||||
masked.append("*****")
|
||||
return masked
|
||||
|
||||
|
||||
def encrypt_string_to_bytes(intput_str: str) -> bytes:
|
||||
versioned_encryption_fn = fetch_versioned_implementation(
|
||||
"onyx.utils.encryption", "_encrypt_string"
|
||||
|
||||
205
backend/onyx/utils/sensitive.py
Normal file
205
backend/onyx/utils/sensitive.py
Normal file
@@ -0,0 +1,205 @@
|
||||
"""
|
||||
Wrapper class for sensitive values that require explicit masking decisions.
|
||||
|
||||
This module provides a wrapper for encrypted values that forces developers to
|
||||
make an explicit decision about whether to mask the value when accessing it.
|
||||
This prevents accidental exposure of sensitive data in API responses.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import NoReturn
|
||||
from typing import TypeVar
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from onyx.utils.encryption import mask_credential_dict
|
||||
from onyx.utils.encryption import mask_string
|
||||
|
||||
|
||||
T = TypeVar("T", str, dict[str, Any])
|
||||
|
||||
|
||||
def make_mock_sensitive_value(value: dict[str, Any] | str | None) -> MagicMock:
|
||||
"""
|
||||
Create a mock SensitiveValue for use in tests.
|
||||
|
||||
This helper makes it easy to create mock objects that behave like
|
||||
SensitiveValue for testing code that uses credentials.
|
||||
|
||||
Args:
|
||||
value: The value to return from get_value(). Can be a dict, string, or None.
|
||||
|
||||
Returns:
|
||||
A MagicMock configured to behave like a SensitiveValue.
|
||||
|
||||
Example:
|
||||
>>> mock_credential = MagicMock()
|
||||
>>> mock_credential.credential_json = make_mock_sensitive_value({"api_key": "secret"})
|
||||
>>> # Now mock_credential.credential_json.get_value(apply_mask=False) returns {"api_key": "secret"}
|
||||
"""
|
||||
if value is None:
|
||||
return None # type: ignore[return-value]
|
||||
|
||||
mock = MagicMock(spec=SensitiveValue)
|
||||
mock.get_value.return_value = value
|
||||
mock.__bool__ = lambda self: True # noqa: ARG005
|
||||
return mock
|
||||
|
||||
|
||||
class SensitiveAccessError(Exception):
|
||||
"""Raised when attempting to access a SensitiveValue without explicit masking decision."""
|
||||
|
||||
|
||||
class SensitiveValue(Generic[T]):
|
||||
"""
|
||||
Wrapper requiring explicit masking decisions for sensitive data.
|
||||
|
||||
This class wraps encrypted data and forces callers to make an explicit
|
||||
decision about whether to mask the value when accessing it. This prevents
|
||||
accidental exposure of sensitive data.
|
||||
|
||||
Usage:
|
||||
# Get raw value (for internal use like connectors)
|
||||
raw_value = sensitive.get_value(apply_mask=False)
|
||||
|
||||
# Get masked value (for API responses)
|
||||
masked_value = sensitive.get_value(apply_mask=True)
|
||||
|
||||
Raises SensitiveAccessError when:
|
||||
- Attempting to convert to string via str() or repr()
|
||||
- Attempting to iterate over the value
|
||||
- Attempting to subscript the value (e.g., value["key"])
|
||||
- Attempting to serialize to JSON without explicit get_value()
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
encrypted_bytes: bytes,
|
||||
decrypt_fn: Callable[[bytes], str],
|
||||
is_json: bool = False,
|
||||
) -> None:
|
||||
"""
|
||||
Initialize a SensitiveValue wrapper.
|
||||
|
||||
Args:
|
||||
encrypted_bytes: The encrypted bytes to wrap
|
||||
decrypt_fn: Function to decrypt bytes to string
|
||||
is_json: If True, the decrypted value is JSON and will be parsed to dict
|
||||
"""
|
||||
self._encrypted_bytes = encrypted_bytes
|
||||
self._decrypt_fn = decrypt_fn
|
||||
self._is_json = is_json
|
||||
# Cache for decrypted value to avoid repeated decryption
|
||||
self._decrypted_value: T | None = None
|
||||
|
||||
def _decrypt(self) -> T:
|
||||
"""Lazily decrypt and cache the value."""
|
||||
if self._decrypted_value is None:
|
||||
decrypted_str = self._decrypt_fn(self._encrypted_bytes)
|
||||
if self._is_json:
|
||||
self._decrypted_value = json.loads(decrypted_str)
|
||||
else:
|
||||
self._decrypted_value = decrypted_str # type: ignore[assignment]
|
||||
# The return type should always match T based on is_json flag
|
||||
return self._decrypted_value # type: ignore[return-value]
|
||||
|
||||
def get_value(
|
||||
self,
|
||||
*,
|
||||
apply_mask: bool,
|
||||
mask_fn: Callable[[T], T] | None = None,
|
||||
) -> T:
|
||||
"""
|
||||
Get the value with explicit masking decision.
|
||||
|
||||
Args:
|
||||
apply_mask: Required. True = return masked value, False = return raw value
|
||||
mask_fn: Optional custom masking function. Defaults to mask_string for
|
||||
strings and mask_credential_dict for dicts.
|
||||
|
||||
Returns:
|
||||
The value, either masked or raw depending on apply_mask.
|
||||
"""
|
||||
value = self._decrypt()
|
||||
|
||||
if not apply_mask:
|
||||
return value
|
||||
|
||||
# Apply masking
|
||||
if mask_fn is not None:
|
||||
return mask_fn(value)
|
||||
|
||||
# Use default masking based on type
|
||||
# Type narrowing doesn't work well here due to the generic T,
|
||||
# but at runtime the types will match
|
||||
if isinstance(value, dict):
|
||||
return mask_credential_dict(value)
|
||||
elif isinstance(value, str):
|
||||
return mask_string(value)
|
||||
else:
|
||||
raise ValueError(f"Cannot mask value of type {type(value)}")
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
"""Allow truthiness checks without exposing the value."""
|
||||
return True
|
||||
|
||||
def __str__(self) -> NoReturn:
|
||||
"""Prevent accidental string conversion."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot convert SensitiveValue to string. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
"""Prevent accidental repr exposure."""
|
||||
return "<SensitiveValue: use .get_value(apply_mask=True/False) to access>"
|
||||
|
||||
def __iter__(self) -> NoReturn:
|
||||
"""Prevent iteration over the value."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot iterate over SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __getitem__(self, key: Any) -> NoReturn:
|
||||
"""Prevent subscript access."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot subscript SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
def __eq__(self, other: Any) -> bool:
|
||||
"""Prevent direct comparison which might expose value."""
|
||||
if isinstance(other, SensitiveValue):
|
||||
# Compare encrypted bytes for equality check
|
||||
return self._encrypted_bytes == other._encrypted_bytes
|
||||
raise SensitiveAccessError(
|
||||
"Cannot compare SensitiveValue with non-SensitiveValue. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value for comparison."
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
"""Allow hashing based on encrypted bytes."""
|
||||
return hash(self._encrypted_bytes)
|
||||
|
||||
# Prevent JSON serialization
|
||||
def __json__(self) -> Any:
|
||||
"""Prevent JSON serialization."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot serialize SensitiveValue to JSON. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value."
|
||||
)
|
||||
|
||||
# For Pydantic compatibility
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(cls, source_type: Any, handler: Any) -> Any:
|
||||
"""Prevent Pydantic from serializing without explicit get_value()."""
|
||||
raise SensitiveAccessError(
|
||||
"Cannot serialize SensitiveValue in Pydantic model. "
|
||||
"Use .get_value(apply_mask=True/False) to access the value before serialization."
|
||||
)
|
||||
@@ -36,7 +36,7 @@ global_version = OnyxVersion()
|
||||
# Eventually, ENABLE_PAID_ENTERPRISE_EDITION_FEATURES will be removed
|
||||
# and license enforcement will be the only mechanism for EE features.
|
||||
_LICENSE_ENFORCEMENT_ENABLED = (
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "").lower() == "true"
|
||||
os.environ.get("LICENSE_ENFORCEMENT_ENABLED", "true").lower() == "true"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -112,7 +112,10 @@ def test_gdrive_perm_sync_with_real_data(
|
||||
mock_cc_pair.connector = MagicMock()
|
||||
mock_cc_pair.connector.connector_specific_config = {}
|
||||
mock_cc_pair.credential_id = 1
|
||||
mock_cc_pair.credential.credential_json = {}
|
||||
# Import and use the mock helper
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
mock_cc_pair.credential.credential_json = make_mock_sensitive_value({})
|
||||
mock_cc_pair.last_time_perm_sync = None
|
||||
mock_cc_pair.last_time_external_group_sync = None
|
||||
|
||||
|
||||
@@ -43,6 +43,8 @@ def ensure_default_llm_provider(db_session: Session) -> None:
|
||||
)
|
||||
update_default_provider(provider.id, db_session)
|
||||
except Exception as exc: # pragma: no cover - only hits on duplicate setup issues
|
||||
# Rollback to clear the pending transaction state
|
||||
db_session.rollback()
|
||||
print(f"Note: Could not create LLM provider: {exc}")
|
||||
|
||||
|
||||
|
||||
@@ -129,6 +129,8 @@ def test_confluence_group_sync(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -53,6 +53,8 @@ def _create_test_connector_credential_pair(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush() # To get the credential ID
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -80,6 +80,8 @@ def test_jira_doc_sync(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
@@ -176,6 +178,8 @@ def test_jira_doc_sync_with_specific_permissions(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -114,6 +114,8 @@ def test_jira_group_sync(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -57,6 +57,8 @@ def _create_test_connector_credential_pair(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -67,6 +67,8 @@ def _create_test_connector_credential_pair(
|
||||
)
|
||||
db_session.add(credential)
|
||||
db_session.flush()
|
||||
# Expire the credential so it reloads from DB with SensitiveValue wrapper
|
||||
db_session.expire(credential)
|
||||
|
||||
cc_pair = ConnectorCredentialPair(
|
||||
connector_id=connector.id,
|
||||
|
||||
@@ -254,6 +254,8 @@ class TestSlackBotFederatedSearch:
|
||||
)
|
||||
db_session.add(federated_connector)
|
||||
db_session.flush()
|
||||
# Expire to ensure credentials is reloaded as SensitiveValue from DB
|
||||
db_session.expire(federated_connector)
|
||||
|
||||
# Associate the federated connector with the persona's document sets
|
||||
# This is required for Slack federated search to be enabled
|
||||
@@ -276,6 +278,8 @@ class TestSlackBotFederatedSearch:
|
||||
)
|
||||
db_session.add(slack_bot)
|
||||
db_session.flush()
|
||||
# Expire to ensure tokens are reloaded as SensitiveValue from DB
|
||||
db_session.expire(slack_bot)
|
||||
|
||||
slack_channel_config = SlackChannelConfig(
|
||||
slack_bot_id=slack_bot.id,
|
||||
|
||||
@@ -211,9 +211,11 @@ class TestOAuthTokenManagerRefresh:
|
||||
|
||||
# Verify token was updated in DB
|
||||
db_session.refresh(user_token)
|
||||
assert user_token.token_data["access_token"] == "new_token"
|
||||
assert user_token.token_data["refresh_token"] == "new_refresh"
|
||||
assert "expires_at" in user_token.token_data
|
||||
assert user_token.token_data is not None
|
||||
token_data = user_token.token_data.get_value(apply_mask=False)
|
||||
assert token_data["access_token"] == "new_token"
|
||||
assert token_data["refresh_token"] == "new_refresh"
|
||||
assert "expires_at" in token_data
|
||||
|
||||
@patch("onyx.auth.oauth_token_manager.requests.post")
|
||||
def test_refresh_token_preserves_refresh_token(
|
||||
@@ -249,7 +251,9 @@ class TestOAuthTokenManagerRefresh:
|
||||
|
||||
# Verify old refresh_token was preserved
|
||||
db_session.refresh(user_token)
|
||||
assert user_token.token_data["refresh_token"] == "old_refresh"
|
||||
assert user_token.token_data is not None
|
||||
token_data = user_token.token_data.get_value(apply_mask=False)
|
||||
assert token_data["refresh_token"] == "old_refresh"
|
||||
|
||||
@patch("onyx.auth.oauth_token_manager.requests.post")
|
||||
def test_refresh_token_http_error(
|
||||
|
||||
@@ -17,7 +17,8 @@ COPY ./tests/* /app/tests/
|
||||
|
||||
FROM base AS openapi-schema
|
||||
COPY ./scripts/onyx_openapi_schema.py /app/scripts/onyx_openapi_schema.py
|
||||
RUN python scripts/onyx_openapi_schema.py --filename openapi.json
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
RUN LICENSE_ENFORCEMENT_ENABLED=false python scripts/onyx_openapi_schema.py --filename openapi.json
|
||||
|
||||
FROM openapitools/openapi-generator-cli:latest AS openapi-client
|
||||
WORKDIR /local
|
||||
|
||||
@@ -40,6 +40,8 @@ def test_github_private_repo_permission_sync(
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
|
||||
# so credential_json is already a plain dict
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
@@ -158,6 +160,8 @@ def test_github_public_repo_permission_sync(
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
|
||||
# so credential_json is already a plain dict
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
@@ -262,6 +266,8 @@ def test_github_internal_repo_permission_sync(
|
||||
) = github_test_env_setup
|
||||
|
||||
# Create GitHub client from credential
|
||||
# Note: github_credential is a DATestCredential (Pydantic model), not a SQLAlchemy model
|
||||
# so credential_json is already a plain dict
|
||||
github_access_token = github_credential.credential_json["github_access_token"]
|
||||
github_client = Github(github_access_token)
|
||||
github_manager = GitHubManager(github_client)
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from tests.integration.common_utils.managers.user import UserManager
|
||||
from tests.integration.common_utils.reset import reset_all
|
||||
from tests.integration.common_utils.test_models import DATestUser
|
||||
|
||||
|
||||
@pytest.mark.skipif(
|
||||
os.getenv("PYTEST_IGNORE_SKIP") is None,
|
||||
reason="Skipped by default unless env var exists",
|
||||
)
|
||||
def test_playwright_setup() -> None:
|
||||
"""Not really a test, just using this to automate setup for playwright tests."""
|
||||
if not os.getenv("PYTEST_PLAYWRIGHT_SKIP_INITIAL_RESET", "").lower() == "true":
|
||||
reset_all()
|
||||
|
||||
# Creating an admin user (first user created is automatically an admin)
|
||||
admin_user: DATestUser = UserManager.create(name="admin_user")
|
||||
assert admin_user
|
||||
@@ -7,6 +7,7 @@ from ee.onyx.external_permissions.jira.doc_sync import jira_doc_sync
|
||||
from onyx.connectors.jira.connector import JiraConnector
|
||||
from onyx.connectors.jira.utils import JIRA_SERVER_API_VERSION
|
||||
from onyx.db.models import ConnectorCredentialPair
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -18,10 +19,12 @@ def mock_jira_cc_pair(
|
||||
) -> MagicMock:
|
||||
mock_cc_pair = MagicMock(spec=ConnectorCredentialPair)
|
||||
mock_cc_pair.connector = MagicMock()
|
||||
mock_cc_pair.credential.credential_json = {
|
||||
"jira_user_email": user_email,
|
||||
"jira_api_token": mock_jira_api_token,
|
||||
}
|
||||
mock_cc_pair.credential.credential_json = make_mock_sensitive_value(
|
||||
{
|
||||
"jira_user_email": user_email,
|
||||
"jira_api_token": mock_jira_api_token,
|
||||
}
|
||||
)
|
||||
mock_cc_pair.connector.connector_specific_config = {
|
||||
"jira_base_url": jira_base_url,
|
||||
"project_key": project_key,
|
||||
|
||||
@@ -247,11 +247,13 @@ class TestInstantiateConnectorIntegration:
|
||||
|
||||
def test_instantiate_connector_loads_class_lazily(self) -> None:
|
||||
"""Test that instantiate_connector triggers lazy loading."""
|
||||
from onyx.utils.sensitive import make_mock_sensitive_value
|
||||
|
||||
# Mock the database session and credential
|
||||
mock_session = MagicMock()
|
||||
mock_credential = MagicMock()
|
||||
mock_credential.id = 123
|
||||
mock_credential.credential_json = {"test": "data"}
|
||||
mock_credential.credential_json = make_mock_sensitive_value({"test": "data"})
|
||||
|
||||
# This should trigger lazy loading but will fail on actual instantiation
|
||||
# due to missing real configuration - that's expected
|
||||
|
||||
@@ -1,3 +1,7 @@
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
from unittest.mock import ANY
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -828,3 +832,274 @@ def test_azure_openai_model_uses_httphandler_client() -> None:
|
||||
mock_completion.assert_called_once()
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert isinstance(kwargs["client"], HTTPHandler)
|
||||
|
||||
|
||||
def test_temporary_env_cleanup(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
# Assign some environment variables
|
||||
EXPECTED_ENV_VARS = {
|
||||
"TEST_ENV_VAR": "test_value",
|
||||
"ANOTHER_ONE": "1",
|
||||
"THIRD_ONE": "2",
|
||||
}
|
||||
|
||||
CUSTOM_CONFIG = {
|
||||
"TEST_ENV_VAR": "fdsfsdf",
|
||||
"ANOTHER_ONE": "3",
|
||||
"THIS_IS_RANDOM": "123213",
|
||||
}
|
||||
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
monkeypatch.setenv(env_var, value)
|
||||
|
||||
model_provider = LlmProviderNames.OPENAI
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
),
|
||||
model_kwargs={"metadata": {"foo": "bar"}},
|
||||
custom_config=CUSTOM_CONFIG,
|
||||
)
|
||||
|
||||
# When custom_config is set, invoke() internally uses stream=True and
|
||||
# reassembles via stream_chunk_builder, so the mock must return stream chunks.
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hello"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model="gpt-3.5-turbo",
|
||||
),
|
||||
]
|
||||
|
||||
def on_litellm_completion(
|
||||
**kwargs: dict[str, Any], # noqa: ARG001
|
||||
) -> list[litellm.ModelResponse]:
|
||||
# Validate that the environment variables are those in custom config
|
||||
for env_var, value in CUSTOM_CONFIG.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
return mock_stream_chunks
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_completion.side_effect = on_litellm_completion
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
|
||||
llm.invoke(messages, user_identity=identity)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
kwargs = mock_completion.call_args.kwargs
|
||||
assert kwargs["stream"] is True
|
||||
assert "user" not in kwargs
|
||||
assert kwargs["metadata"]["foo"] == "bar"
|
||||
|
||||
# Check that the environment variables are back to the original values
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
# Check that temporary env var from CUSTOM_CONFIG is no longer set
|
||||
assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
def test_temporary_env_cleanup_on_exception(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Verify env vars are restored even when an exception occurs during LLM invocation."""
|
||||
# Assign some environment variables
|
||||
EXPECTED_ENV_VARS = {
|
||||
"TEST_ENV_VAR": "test_value",
|
||||
"ANOTHER_ONE": "1",
|
||||
"THIRD_ONE": "2",
|
||||
}
|
||||
|
||||
CUSTOM_CONFIG = {
|
||||
"TEST_ENV_VAR": "fdsfsdf",
|
||||
"ANOTHER_ONE": "3",
|
||||
"THIS_IS_RANDOM": "123213",
|
||||
}
|
||||
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
monkeypatch.setenv(env_var, value)
|
||||
|
||||
model_provider = LlmProviderNames.OPENAI
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
llm = LitellmLLM(
|
||||
api_key="test_key",
|
||||
timeout=30,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
),
|
||||
model_kwargs={"metadata": {"foo": "bar"}},
|
||||
custom_config=CUSTOM_CONFIG,
|
||||
)
|
||||
|
||||
def on_litellm_completion_raises(**kwargs: dict[str, Any]) -> None: # noqa: ARG001
|
||||
# Validate that the environment variables are those in custom config
|
||||
for env_var, value in CUSTOM_CONFIG.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
# Simulate an error during LLM call
|
||||
raise RuntimeError("Simulated LLM API failure")
|
||||
|
||||
with (
|
||||
patch("litellm.completion") as mock_completion,
|
||||
patch("onyx.llm.utils.SEND_USER_METADATA_TO_LLM_PROVIDER", False),
|
||||
):
|
||||
mock_completion.side_effect = on_litellm_completion_raises
|
||||
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
identity = LLMUserIdentity(user_id="user_123", session_id="session_abc")
|
||||
|
||||
with pytest.raises(RuntimeError, match="Simulated LLM API failure"):
|
||||
llm.invoke(messages, user_identity=identity)
|
||||
|
||||
mock_completion.assert_called_once()
|
||||
|
||||
# Check that the environment variables are back to the original values
|
||||
for env_var, value in EXPECTED_ENV_VARS.items():
|
||||
assert env_var in os.environ
|
||||
assert os.environ[env_var] == value
|
||||
|
||||
# Check that temporary env var from CUSTOM_CONFIG is no longer set
|
||||
assert "THIS_IS_RANDOM" not in os.environ
|
||||
|
||||
|
||||
def test_multithreaded_custom_config_isolation(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Verify the env lock prevents concurrent LLM calls from seeing each other's custom_config.
|
||||
|
||||
Two LitellmLLM instances with different custom_config dicts invoke concurrently.
|
||||
The _env_lock in temporary_env_and_lock serializes their access so each call only
|
||||
ever sees its own env vars—never the other's.
|
||||
"""
|
||||
# Ensure these keys start unset
|
||||
monkeypatch.delenv("SHARED_KEY", raising=False)
|
||||
monkeypatch.delenv("LLM_A_ONLY", raising=False)
|
||||
monkeypatch.delenv("LLM_B_ONLY", raising=False)
|
||||
|
||||
CONFIG_A = {
|
||||
"SHARED_KEY": "value_from_A",
|
||||
"LLM_A_ONLY": "a_secret",
|
||||
}
|
||||
CONFIG_B = {
|
||||
"SHARED_KEY": "value_from_B",
|
||||
"LLM_B_ONLY": "b_secret",
|
||||
}
|
||||
|
||||
all_env_keys = list(set(list(CONFIG_A.keys()) + list(CONFIG_B.keys())))
|
||||
|
||||
model_provider = LlmProviderNames.OPENAI
|
||||
model_name = "gpt-3.5-turbo"
|
||||
|
||||
llm_a = LitellmLLM(
|
||||
api_key="key_a",
|
||||
timeout=30,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
),
|
||||
custom_config=CONFIG_A,
|
||||
)
|
||||
llm_b = LitellmLLM(
|
||||
api_key="key_b",
|
||||
timeout=30,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
max_input_tokens=get_max_input_tokens(
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
),
|
||||
custom_config=CONFIG_B,
|
||||
)
|
||||
|
||||
# invoke() uses stream=True internally when custom_config is set,
|
||||
# so the mock must return stream chunks.
|
||||
mock_stream_chunks = [
|
||||
litellm.ModelResponse(
|
||||
id="chatcmpl-123",
|
||||
choices=[
|
||||
litellm.Choices(
|
||||
delta=_create_delta(role="assistant", content="Hi"),
|
||||
finish_reason="stop",
|
||||
index=0,
|
||||
)
|
||||
],
|
||||
model=model_name,
|
||||
),
|
||||
]
|
||||
|
||||
# Track what each call observed inside litellm.completion.
|
||||
# Keyed by api_key so we can identify which LLM instance made the call.
|
||||
observed_envs: dict[str, dict[str, str | None]] = {}
|
||||
|
||||
def fake_completion(**kwargs: Any) -> list[litellm.ModelResponse]:
|
||||
time.sleep(0.1) # We expect someone to get caught on the lock
|
||||
api_key = kwargs.get("api_key", "")
|
||||
label = "A" if api_key == "key_a" else "B"
|
||||
|
||||
snapshot: dict[str, str | None] = {}
|
||||
for key in all_env_keys:
|
||||
snapshot[key] = os.environ.get(key)
|
||||
observed_envs[label] = snapshot
|
||||
|
||||
return mock_stream_chunks
|
||||
|
||||
errors: list[Exception] = []
|
||||
|
||||
def run_llm(llm: LitellmLLM) -> None:
|
||||
try:
|
||||
messages: LanguageModelInput = [UserMessage(content="Hi")]
|
||||
llm.invoke(messages)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
with patch("litellm.completion", side_effect=fake_completion):
|
||||
t_a = threading.Thread(target=run_llm, args=(llm_a,))
|
||||
t_b = threading.Thread(target=run_llm, args=(llm_b,))
|
||||
|
||||
t_a.start()
|
||||
t_b.start()
|
||||
t_a.join(timeout=10)
|
||||
t_b.join(timeout=10)
|
||||
|
||||
assert not errors, f"Thread errors: {errors}"
|
||||
assert "A" in observed_envs and "B" in observed_envs
|
||||
|
||||
# Thread A must have seen its own config for SHARED_KEY, not B's
|
||||
assert observed_envs["A"]["SHARED_KEY"] == "value_from_A"
|
||||
assert observed_envs["A"]["LLM_A_ONLY"] == "a_secret"
|
||||
# A must NOT see B's exclusive key
|
||||
assert observed_envs["A"]["LLM_B_ONLY"] is None
|
||||
|
||||
# Thread B must have seen its own config for SHARED_KEY, not A's
|
||||
assert observed_envs["B"]["SHARED_KEY"] == "value_from_B"
|
||||
assert observed_envs["B"]["LLM_B_ONLY"] == "b_secret"
|
||||
# B must NOT see A's exclusive key
|
||||
assert observed_envs["B"]["LLM_A_ONLY"] is None
|
||||
|
||||
# After both calls, env should be clean
|
||||
assert os.environ.get("SHARED_KEY") is None
|
||||
assert os.environ.get("LLM_A_ONLY") is None
|
||||
assert os.environ.get("LLM_B_ONLY") is None
|
||||
|
||||
@@ -73,6 +73,10 @@ class TestGetOllamaAvailableModels:
|
||||
# Check display names are generated
|
||||
assert any("Llama" in r.display_name for r in results)
|
||||
assert any("Mistral" in r.display_name for r in results)
|
||||
# Results should be alphabetically sorted by model name
|
||||
assert [r.name for r in results] == sorted(
|
||||
[r.name for r in results], key=str.lower
|
||||
)
|
||||
|
||||
def test_syncs_to_db_when_provider_name_specified(
|
||||
self, mock_ollama_tags_response: dict, mock_ollama_show_response: dict
|
||||
|
||||
239
backend/tests/unit/onyx/utils/test_sensitive.py
Normal file
239
backend/tests/unit/onyx/utils/test_sensitive.py
Normal file
@@ -0,0 +1,239 @@
|
||||
"""Tests for SensitiveValue wrapper class."""
|
||||
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.sensitive import SensitiveAccessError
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
|
||||
|
||||
def _encrypt_string(value: str) -> bytes:
|
||||
"""Simple mock encryption (just encoding for tests)."""
|
||||
return value.encode("utf-8")
|
||||
|
||||
|
||||
def _decrypt_string(value: bytes) -> str:
|
||||
"""Simple mock decryption (just decoding for tests)."""
|
||||
return value.decode("utf-8")
|
||||
|
||||
|
||||
class TestSensitiveValueString:
|
||||
"""Tests for SensitiveValue with string values."""
|
||||
|
||||
def test_get_value_raw(self) -> None:
|
||||
"""Test getting raw unmasked value."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("my-secret-token"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert sensitive.get_value(apply_mask=False) == "my-secret-token"
|
||||
|
||||
def test_get_value_masked(self) -> None:
|
||||
"""Test getting masked value with default masking."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("my-very-long-secret-token-here"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
result = sensitive.get_value(apply_mask=True)
|
||||
# Default mask_string shows first 4 and last 4 chars
|
||||
assert result == "my-v...here"
|
||||
|
||||
def test_get_value_masked_short_string(self) -> None:
|
||||
"""Test that short strings are fully masked."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("short"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
result = sensitive.get_value(apply_mask=True)
|
||||
# Short strings get fully masked
|
||||
assert result == "••••••••••••"
|
||||
|
||||
def test_get_value_custom_mask_fn(self) -> None:
|
||||
"""Test using a custom masking function."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
result = sensitive.get_value(
|
||||
apply_mask=True,
|
||||
mask_fn=lambda x: "REDACTED", # noqa: ARG005
|
||||
)
|
||||
assert result == "REDACTED"
|
||||
|
||||
def test_str_raises_error(self) -> None:
|
||||
"""Test that str() raises SensitiveAccessError."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
str(sensitive)
|
||||
|
||||
def test_repr_is_safe(self) -> None:
|
||||
"""Test that repr() doesn't expose the value."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
result = repr(sensitive)
|
||||
assert "secret" not in result
|
||||
assert "SensitiveValue" in result
|
||||
assert "get_value" in result
|
||||
|
||||
def test_iter_raises_error(self) -> None:
|
||||
"""Test that iteration raises SensitiveAccessError."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
for _ in sensitive: # type: ignore[attr-defined]
|
||||
pass
|
||||
|
||||
def test_getitem_raises_error(self) -> None:
|
||||
"""Test that subscript access raises SensitiveAccessError."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
_ = sensitive[0]
|
||||
|
||||
def test_bool_returns_true(self) -> None:
|
||||
"""Test that bool() works for truthiness checks."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert bool(sensitive) is True
|
||||
|
||||
def test_equality_with_same_value(self) -> None:
|
||||
"""Test equality comparison between SensitiveValues with same encrypted bytes."""
|
||||
encrypted = _encrypt_string("secret")
|
||||
sensitive1 = SensitiveValue(
|
||||
encrypted_bytes=encrypted,
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
sensitive2 = SensitiveValue(
|
||||
encrypted_bytes=encrypted,
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert sensitive1 == sensitive2
|
||||
|
||||
def test_equality_with_different_value(self) -> None:
|
||||
"""Test equality comparison between SensitiveValues with different encrypted bytes."""
|
||||
sensitive1 = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret1"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
sensitive2 = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret2"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
assert sensitive1 != sensitive2
|
||||
|
||||
def test_equality_with_non_sensitive_raises(self) -> None:
|
||||
"""Test that comparing with non-SensitiveValue raises error."""
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=False,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
_ = sensitive == "secret"
|
||||
|
||||
|
||||
class TestSensitiveValueJson:
|
||||
"""Tests for SensitiveValue with JSON/dict values."""
|
||||
|
||||
def test_get_value_raw_dict(self) -> None:
|
||||
"""Test getting raw unmasked dict value."""
|
||||
data: dict[str, Any] = {"api_key": "secret-key", "username": "user123"}
|
||||
sensitive: SensitiveValue[dict[str, Any]] = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string(json.dumps(data)),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=True,
|
||||
)
|
||||
result = sensitive.get_value(apply_mask=False)
|
||||
assert result == data
|
||||
|
||||
def test_get_value_masked_dict(self) -> None:
|
||||
"""Test getting masked dict value with default masking."""
|
||||
data = {"api_key": "my-very-long-api-key-value", "username": "user123456789"}
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string(json.dumps(data)),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=True,
|
||||
)
|
||||
result = sensitive.get_value(apply_mask=True)
|
||||
# Values should be masked
|
||||
assert "my-very-long-api-key-value" not in str(result)
|
||||
assert "user123456789" not in str(result)
|
||||
|
||||
def test_getitem_raises_error_for_dict(self) -> None:
|
||||
"""Test that subscript access raises SensitiveAccessError for dict."""
|
||||
data = {"api_key": "secret"}
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string(json.dumps(data)),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=True,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
_ = sensitive["api_key"]
|
||||
|
||||
def test_iter_raises_error_for_dict(self) -> None:
|
||||
"""Test that iteration raises SensitiveAccessError for dict."""
|
||||
data = {"api_key": "secret"}
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string(json.dumps(data)),
|
||||
decrypt_fn=_decrypt_string,
|
||||
is_json=True,
|
||||
)
|
||||
with pytest.raises(SensitiveAccessError):
|
||||
for _ in sensitive: # type: ignore[attr-defined]
|
||||
pass
|
||||
|
||||
|
||||
class TestSensitiveValueCaching:
|
||||
"""Tests for lazy decryption caching."""
|
||||
|
||||
def test_decryption_is_cached(self) -> None:
|
||||
"""Test that decryption result is cached."""
|
||||
decrypt_count = [0]
|
||||
|
||||
def counting_decrypt(value: bytes) -> str:
|
||||
decrypt_count[0] += 1
|
||||
return value.decode("utf-8")
|
||||
|
||||
sensitive = SensitiveValue(
|
||||
encrypted_bytes=_encrypt_string("secret"),
|
||||
decrypt_fn=counting_decrypt,
|
||||
is_json=False,
|
||||
)
|
||||
|
||||
# First access
|
||||
sensitive.get_value(apply_mask=False)
|
||||
assert decrypt_count[0] == 1
|
||||
|
||||
# Second access should use cached value
|
||||
sensitive.get_value(apply_mask=False)
|
||||
assert decrypt_count[0] == 1
|
||||
|
||||
# Masked access should also use cached value
|
||||
sensitive.get_value(apply_mask=True)
|
||||
assert decrypt_count[0] == 1
|
||||
78
backend/tests/unit/onyx/utils/test_sensitive_typing.py
Normal file
78
backend/tests/unit/onyx/utils/test_sensitive_typing.py
Normal file
@@ -0,0 +1,78 @@
|
||||
"""
|
||||
Tests demonstrating static type checking for SensitiveValue.
|
||||
|
||||
Run with: mypy tests/unit/onyx/utils/test_sensitive_typing.py --ignore-missing-imports
|
||||
|
||||
These tests show what mypy will catch when SensitiveValue is misused.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
# This file demonstrates what mypy will catch.
|
||||
# The commented-out code below would produce type errors.
|
||||
|
||||
|
||||
def demonstrate_correct_usage() -> None:
|
||||
"""Shows correct patterns that pass type checking."""
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes, decrypt_bytes_to_string
|
||||
|
||||
# Create a SensitiveValue
|
||||
encrypted = encrypt_string_to_bytes('{"api_key": "secret"}')
|
||||
sensitive: SensitiveValue[dict[str, Any]] = SensitiveValue(
|
||||
encrypted_bytes=encrypted,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=True,
|
||||
)
|
||||
|
||||
# CORRECT: Using get_value() to access the value
|
||||
raw_dict: dict[str, Any] = sensitive.get_value(apply_mask=False)
|
||||
assert raw_dict["api_key"] == "secret"
|
||||
|
||||
masked_dict: dict[str, Any] = sensitive.get_value(apply_mask=True)
|
||||
assert "secret" not in str(masked_dict)
|
||||
|
||||
# CORRECT: Using bool for truthiness
|
||||
if sensitive:
|
||||
print("Value exists")
|
||||
|
||||
|
||||
# The code below demonstrates what mypy would catch.
|
||||
# Uncomment to see the type errors.
|
||||
"""
|
||||
def demonstrate_incorrect_usage() -> None:
|
||||
'''Shows patterns that mypy will flag as errors.'''
|
||||
from onyx.utils.sensitive import SensitiveValue
|
||||
from onyx.utils.encryption import encrypt_string_to_bytes, decrypt_bytes_to_string
|
||||
|
||||
encrypted = encrypt_string_to_bytes('{"api_key": "secret"}')
|
||||
sensitive: SensitiveValue[dict[str, Any]] = SensitiveValue(
|
||||
encrypted_bytes=encrypted,
|
||||
decrypt_fn=decrypt_bytes_to_string,
|
||||
is_json=True,
|
||||
)
|
||||
|
||||
# ERROR: SensitiveValue doesn't support subscript access
|
||||
# mypy error: Value of type "SensitiveValue[dict[str, Any]]" is not indexable
|
||||
api_key = sensitive["api_key"]
|
||||
|
||||
# ERROR: SensitiveValue doesn't support iteration
|
||||
# mypy error: "SensitiveValue[dict[str, Any]]" has no attribute "__iter__"
|
||||
for key in sensitive:
|
||||
print(key)
|
||||
|
||||
# ERROR: Can't pass SensitiveValue where dict is expected
|
||||
# mypy error: Argument 1 has incompatible type "SensitiveValue[dict[str, Any]]"; expected "dict[str, Any]"
|
||||
def process_dict(d: dict[str, Any]) -> None:
|
||||
pass
|
||||
process_dict(sensitive)
|
||||
|
||||
# ERROR: Can't use .get() on SensitiveValue
|
||||
# mypy error: "SensitiveValue[dict[str, Any]]" has no attribute "get"
|
||||
value = sensitive.get("api_key")
|
||||
"""
|
||||
|
||||
|
||||
def test_correct_usage_passes() -> None:
|
||||
"""This test runs the correct usage demonstration."""
|
||||
demonstrate_correct_usage()
|
||||
@@ -29,6 +29,8 @@ services:
|
||||
- MODEL_SERVER_PORT=${MODEL_SERVER_PORT:-}
|
||||
- ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
- LICENSE_ENFORCEMENT_ENABLED=false
|
||||
# MinIO configuration
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
@@ -66,6 +68,8 @@ services:
|
||||
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
|
||||
- ENV_SEED_CONFIGURATION=${ENV_SEED_CONFIGURATION:-}
|
||||
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
|
||||
# TODO(Nik): https://linear.app/onyx-app/issue/ENG-1/update-test-infra-to-use-test-license
|
||||
- LICENSE_ENFORCEMENT_ENABLED=false
|
||||
# MinIO configuration
|
||||
- S3_ENDPOINT_URL=${S3_ENDPOINT_URL:-http://minio:9000}
|
||||
- S3_AWS_ACCESS_KEY_ID=${S3_AWS_ACCESS_KEY_ID:-minioadmin}
|
||||
|
||||
@@ -65,22 +65,7 @@ Bring up the entire application.
|
||||
npx playwright install
|
||||
```
|
||||
|
||||
1. Reset the instance
|
||||
|
||||
```cd backend
|
||||
export PYTEST_IGNORE_SKIP=true
|
||||
pytest -s tests/integration/tests/playwright/test_playwright.py
|
||||
```
|
||||
|
||||
If you don't want to reset your local instance, you can still run playwright tests
|
||||
with SKIP_AUTH=true. This is convenient but slightly different from what happens
|
||||
in CI so tests might pass locally and fail in CI.
|
||||
|
||||
```cd web
|
||||
SKIP_AUTH=true npx playwright test create_and_edit_assistant.spec.ts --project=admin
|
||||
```
|
||||
|
||||
2. Run playwright
|
||||
1. Run playwright
|
||||
|
||||
```
|
||||
cd web
|
||||
@@ -101,7 +86,7 @@ npx playwright test --ui
|
||||
npx playwright test --headed
|
||||
```
|
||||
|
||||
3. Inspect results
|
||||
2. Inspect results
|
||||
|
||||
By default, playwright.config.ts is configured to output the results to:
|
||||
|
||||
@@ -109,7 +94,7 @@ By default, playwright.config.ts is configured to output the results to:
|
||||
web/test-results
|
||||
```
|
||||
|
||||
4. Upload results to Chromatic (Optional)
|
||||
3. Upload results to Chromatic (Optional)
|
||||
|
||||
This step would normally not be run by third party developers, but first party devs
|
||||
may use this for local troubleshooting and testing.
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# Opal Components
|
||||
|
||||
High-level UI components built on the [`@opal/core`](../core/) primitives. Every component in this directory delegates state styling (hover, active, disabled, selected) to `Interactive.Base` via CSS data-attributes and the `--interactive-foreground` custom property — no duplicated Tailwind class maps.
|
||||
High-level UI components built on the [`@opal/core`](../core/) primitives. Every component in this directory delegates state styling (hover, active, disabled, transient) to `Interactive.Base` via CSS data-attributes and the `--interactive-foreground` custom property — no duplicated Tailwind class maps.
|
||||
|
||||
## Package export
|
||||
|
||||
|
||||
@@ -7,34 +7,36 @@ A single component that handles both labeled buttons and icon-only buttons. It r
|
||||
## Architecture
|
||||
|
||||
```
|
||||
Interactive.Base <- variant/subvariant, selected, disabled, href, onClick
|
||||
└─ Interactive.Container <- height, rounding, padding (derived from `size`)
|
||||
Interactive.Base <- variant/subvariant, transient, disabled, href, onClick
|
||||
└─ Interactive.Container <- height, rounding, padding (derived from `size`), border (auto for secondary)
|
||||
└─ div.opal-button.interactive-foreground <- flexbox row layout
|
||||
├─ Icon? .opal-button-icon (1rem x 1rem, shrink-0)
|
||||
├─ <span>? .opal-button-label (whitespace-nowrap, font)
|
||||
└─ RightIcon? .opal-button-icon
|
||||
├─ div.p-0.5 > Icon? (compact: 12px, default: 16px, shrink-0)
|
||||
├─ <span>? .opal-button-label (whitespace-nowrap, font)
|
||||
└─ div.p-0.5 > RightIcon? (compact: 12px, default: 16px, shrink-0)
|
||||
```
|
||||
|
||||
- **Colors are not in the Button.** `Interactive.Base` sets `background-color` and `--interactive-foreground` per variant/subvariant/state. The `.interactive-foreground` utility class on the content div sets `color: var(--interactive-foreground)`, which both the `<span>` text and `stroke="currentColor"` SVG icons inherit automatically.
|
||||
- **Layout is in `styles.css`.** The CSS classes (`.opal-button`, `.opal-button-icon`, `.opal-button-label`) handle flexbox alignment, gap, icon sizing, and text styling. A `[data-size="compact"]` selector tightens the gap and reduces font size.
|
||||
- **Layout is in `styles.css`.** The CSS classes (`.opal-button`, `.opal-button-label`) handle flexbox alignment, gap, and text styling. Default labels use `font-main-ui-action` (14px/600); compact labels use `font-secondary-action` (12px/600) via a `[data-size="compact"]` selector.
|
||||
- **Sizing is delegated to `Interactive.Container` presets.** The `size` prop maps to Container height/rounding/padding presets:
|
||||
- `"default"` -> height 2.25rem, rounding 12px, padding 8px
|
||||
- `"compact"` -> height 1.75rem, rounding 8px, padding 4px
|
||||
- **Icon-only buttons render as squares** because `Interactive.Container` enforces `min-width >= height` for every height preset.
|
||||
- **Border is automatic for `subvariant="secondary"`.** The Container receives `border={subvariant === "secondary"}` internally — there is no external `border` prop.
|
||||
|
||||
## Props
|
||||
|
||||
| Prop | Type | Default | Description |
|
||||
|------|------|---------|-------------|
|
||||
| `variant` | `"default" \| "action" \| "danger" \| "none" \| "select"` | `"default"` | Top-level color variant (maps to `Interactive.Base`) |
|
||||
| `subvariant` | Depends on `variant` | `"primary"` | Color subvariant -- e.g. `"primary"`, `"secondary"`, `"ghost"` for default/action/danger |
|
||||
| `subvariant` | Depends on `variant` | `"primary"` | Color subvariant -- e.g. `"primary"`, `"secondary"`, `"ghost"` for default/action/danger. `"secondary"` automatically renders a border. |
|
||||
| `icon` | `IconFunctionComponent` | -- | Left icon component |
|
||||
| `children` | `string` | -- | Button label text. Omit for icon-only buttons |
|
||||
| `rightIcon` | `IconFunctionComponent` | -- | Right icon component |
|
||||
| `size` | `SizeVariant` | `"default"` | Size preset controlling height, rounding, padding, gap, and font size |
|
||||
| `size` | `SizeVariant` | `"default"` | Size preset controlling height, rounding, padding, icon size, and font style |
|
||||
| `tooltip` | `string` | -- | Tooltip text shown on hover |
|
||||
| `tooltipSide` | `TooltipSide` | `"top"` | Which side the tooltip appears on |
|
||||
| `selected` | `boolean` | `false` | Forces the selected visual state (data-selected) |
|
||||
| `selected` | `boolean` | `false` | Switches foreground to action-link colours (only available with `variant="select"`) |
|
||||
| `transient` | `boolean` | `false` | Forces the transient (hover) visual state (data-transient) |
|
||||
| `disabled` | `boolean` | `false` | Disables the button (data-disabled, aria-disabled) |
|
||||
| `href` | `string` | -- | URL; renders an `<a>` wrapper instead of Radix Slot |
|
||||
| `onClick` | `MouseEventHandler<HTMLElement>` | -- | Click handler |
|
||||
@@ -59,7 +61,7 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
|
||||
Add item
|
||||
</Button>
|
||||
|
||||
// Labeled button with right icon
|
||||
// Secondary button (automatically renders a border)
|
||||
<Button rightIcon={SvgArrowRight} variant="default" subvariant="secondary">
|
||||
Continue
|
||||
</Button>
|
||||
@@ -74,8 +76,8 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
|
||||
Settings
|
||||
</Button>
|
||||
|
||||
// Selected state (e.g. inside a popover trigger)
|
||||
<Button icon={SvgFilter} subvariant="ghost" selected={isOpen} />
|
||||
// Transient state (e.g. inside a popover trigger)
|
||||
<Button icon={SvgFilter} subvariant="ghost" transient={isOpen} />
|
||||
|
||||
// With tooltip
|
||||
<Button icon={SvgPlus} subvariant="ghost" tooltip="Add item" />
|
||||
@@ -91,7 +93,7 @@ import { SvgPlus, SvgArrowRight } from "@opal/icons";
|
||||
| `primary` | `subvariant="primary"` (default, can be omitted) |
|
||||
| `secondary` | `subvariant="secondary"` |
|
||||
| `tertiary` | `subvariant="ghost"` |
|
||||
| `transient={x}` | `selected={x}` |
|
||||
| `transient={x}` | `transient={x}` |
|
||||
| `size="md"` | `size="compact"` |
|
||||
| `size="lg"` | `size="default"` (default, can be omitted) |
|
||||
| `leftIcon={X}` | `icon={X}` |
|
||||
|
||||
@@ -4,6 +4,7 @@ import { Interactive, type InteractiveBaseProps } from "@opal/core";
|
||||
import type { SizeVariant, TooltipSide } from "@opal/components";
|
||||
import type { IconFunctionComponent } from "@opal/types";
|
||||
import * as TooltipPrimitive from "@radix-ui/react-tooltip";
|
||||
import { cn } from "@opal/utils";
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Types
|
||||
@@ -29,6 +30,25 @@ type ButtonProps = InteractiveBaseProps & {
|
||||
tooltipSide?: TooltipSide;
|
||||
};
|
||||
|
||||
function iconWrapper(
|
||||
Icon: IconFunctionComponent | undefined,
|
||||
isCompact: boolean
|
||||
) {
|
||||
return Icon ? (
|
||||
<div className="p-0.5">
|
||||
<Icon
|
||||
className={cn(
|
||||
"shrink-0",
|
||||
isCompact ? "h-[0.75rem] w-[0.75rem]" : "h-[1rem] w-[1rem]"
|
||||
)}
|
||||
size={isCompact ? 12 : 16}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<div className="w-[0.125rem]" />
|
||||
);
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Button
|
||||
// ---------------------------------------------------------------------------
|
||||
@@ -40,29 +60,31 @@ function Button({
|
||||
size = "default",
|
||||
tooltip,
|
||||
tooltipSide = "top",
|
||||
variant,
|
||||
subvariant,
|
||||
...baseProps
|
||||
...interactiveBaseProps
|
||||
}: ButtonProps) {
|
||||
const isCompact = size === "compact";
|
||||
|
||||
const button = (
|
||||
<Interactive.Base
|
||||
{...({ variant, subvariant } as InteractiveBaseProps)}
|
||||
{...baseProps}
|
||||
>
|
||||
<Interactive.Base {...interactiveBaseProps}>
|
||||
<Interactive.Container
|
||||
border={interactiveBaseProps.subvariant === "secondary"}
|
||||
heightVariant={isCompact ? "compact" : "default"}
|
||||
roundingVariant={isCompact ? "compact" : "default"}
|
||||
paddingVariant={isCompact ? "thin" : "default"}
|
||||
>
|
||||
<div
|
||||
className="opal-button interactive-foreground"
|
||||
data-size={isCompact ? "compact" : undefined}
|
||||
>
|
||||
{Icon && <Icon className="opal-button-icon" />}
|
||||
{children && <span className="opal-button-label">{children}</span>}
|
||||
{RightIcon && <RightIcon className="opal-button-icon" />}
|
||||
<div className="opal-button interactive-foreground">
|
||||
{iconWrapper(Icon, isCompact)}
|
||||
{children && (
|
||||
<span
|
||||
className={cn(
|
||||
"opal-button-label",
|
||||
isCompact ? "font-secondary-action" : "font-main-ui-action"
|
||||
)}
|
||||
>
|
||||
{children}
|
||||
</span>
|
||||
)}
|
||||
{iconWrapper(RightIcon, isCompact)}
|
||||
</div>
|
||||
</Interactive.Container>
|
||||
</Interactive.Base>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user