Compare commits

..

2 Commits

Author SHA1 Message Date
pablonyx
a821b833ac k 2025-03-05 15:36:29 -08:00
pablonyx
751af6824a update 2025-03-05 12:28:23 -08:00
109 changed files with 1213 additions and 1642 deletions

View File

@@ -12,40 +12,29 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: "true"
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -1,7 +1,6 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -52,7 +51,7 @@ env:
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -77,7 +76,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -1,125 +0,0 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -134,9 +134,7 @@ def fetch_chat_sessions_eagerly_by_time(
limit: int | None = 500,
initial_time: datetime | None = None,
) -> list[ChatSession]:
"""Sorted by oldest to newest, then by message id"""
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
@@ -149,7 +147,8 @@ def fetch_chat_sessions_eagerly_by_time(
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(asc_time_order)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.limit(limit)
.subquery()
)
@@ -165,7 +164,7 @@ def fetch_chat_sessions_eagerly_by_time(
ChatMessage.chat_message_feedbacks
),
)
.order_by(asc_time_order, message_order)
.order_by(time_order, message_order)
)
chat_sessions = query.all()

View File

@@ -16,20 +16,13 @@ from onyx.db.models import UsageReport
from onyx.file_store.file_store import get_default_file_store
# Gets skeletons of all messages in the given range
# Gets skeletons of all message
def get_empty_chat_messages_entries__paginated(
db_session: Session,
period: tuple[datetime, datetime],
limit: int | None = 500,
initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
"""Returns a tuple where:
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],
end=period[1],
@@ -59,17 +52,18 @@ def get_empty_chat_messages_entries__paginated(
if len(chat_sessions) == 0:
return None, []
return chat_sessions[-1].time_created, message_skeletons
return chat_sessions[0].time_created, message_skeletons
def get_all_empty_chat_message_entries(
db_session: Session,
period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
"""period is the range of time over which to fetch messages."""
initial_time: Optional[datetime] = period[0]
ind = 0
while True:
# iterate from oldest to newest
ind += 1
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -26,7 +26,7 @@ from ee.onyx.server.query_history.api import router as query_history_router
from ee.onyx.server.reporting.usage_export_api import router as usage_export_router
from ee.onyx.server.saml import router as saml_router
from ee.onyx.server.seeding import seed_db
from ee.onyx.server.tenants.api import router as tenants_router
from ee.onyx.server.tenants.router import router as tenants_router
from ee.onyx.server.token_rate_limits.api import (
router as token_rate_limit_settings_router,
)
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -80,7 +80,6 @@ class ConfluenceCloudOAuth:
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)

View File

@@ -48,15 +48,10 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -251,8 +246,6 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -0,0 +1,143 @@
import asyncio
import logging
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
async def complete_tenant_setup(tenant_id: str, email: str) -> None:
"""
Complete the tenant setup process asynchronously after the essential migrations
have been applied. This includes:
1. Running the remaining Alembic migrations
2. Setting up Onyx
3. Creating milestone records
"""
logger.info(f"Starting asynchronous tenant setup for tenant {tenant_id}")
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run the remaining Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure default API keys
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
# Setup Onyx
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Create milestone record
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
logger.info(f"Asynchronous tenant setup completed for tenant {tenant_id}")
except Exception as e:
logger.exception(
f"Failed to complete asynchronous tenant setup for tenant {tenant_id}: {e}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(cloud_embedding_provider, db_session)
except Exception as e:
logger.error(f"Failed to configure Cohere provider: {e}")
else:
logger.error(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere provider configuration"
)

View File

@@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,3 +67,19 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
status: str
is_complete: bool
class ApproveUserRequest(BaseModel):
email: str
tenant_id: str
class RequestInviteRequest(BaseModel):
email: str
tenant_id: str

View File

@@ -48,5 +48,4 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -6,47 +6,28 @@ import aiohttp # Async HTTP client
import httpx
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import ANTHROPIC_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.async_setup import complete_tenant_setup
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import drop_schema
from ee.onyx.server.tenants.schema_management import run_alembic_migrations
from ee.onyx.server.tenants.schema_management import run_essential_alembic_migrations
from ee.onyx.server.tenants.user_mapping import add_users_to_tenant
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import user_owns_a_tenant
from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
from onyx.llm.llm_provider_options import ANTHROPIC_MODEL_NAMES
from onyx.llm.llm_provider_options import ANTHROPIC_PROVIDER_NAME
from onyx.llm.llm_provider_options import OPEN_AI_MODEL_NAMES
from onyx.llm.llm_provider_options import OPENAI_PROVIDER_NAME
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.setup import setup_onyx
from onyx.utils.telemetry import create_milestone_and_report
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.configs import TENANT_ID_PREFIX
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.enums import EmbeddingProvider
logger = logging.getLogger(__name__)
@@ -55,11 +36,7 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -119,35 +96,19 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
# Run only the essential Alembic migrations needed for auth
await asyncio.to_thread(run_essential_alembic_migrations, tenant_id)
# Add user to tenant immediately so they can log in
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Start the rest of the setup process asynchronously
asyncio.create_task(complete_tenant_setup(tenant_id, email))
logger.info(f"Essential tenant provisioning completed for tenant {tenant_id}")
logger.info(
f"Remaining setup will continue asynchronously for tenant {tenant_id}"
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
@@ -203,136 +164,43 @@ async def rollback_tenant_provisioning(tenant_id: str) -> None:
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
if ANTHROPIC_DEFAULT_API_KEY:
anthropic_provider = LLMProviderUpsertRequest(
name="Anthropic",
provider=ANTHROPIC_PROVIDER_NAME,
api_key=ANTHROPIC_DEFAULT_API_KEY,
default_model_name="claude-3-7-sonnet-20250219",
fast_default_model_name="claude-3-5-sonnet-20241022",
model_names=ANTHROPIC_MODEL_NAMES,
display_model_names=["claude-3-5-sonnet-20241022"],
)
try:
full_provider = upsert_llm_provider(anthropic_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure Anthropic provider: {e}")
else:
logger.error(
"ANTHROPIC_DEFAULT_API_KEY not set, skipping Anthropic provider configuration"
)
if OPENAI_DEFAULT_API_KEY:
open_provider = LLMProviderUpsertRequest(
name="OpenAI",
provider=OPENAI_PROVIDER_NAME,
api_key=OPENAI_DEFAULT_API_KEY,
default_model_name="gpt-4o",
fast_default_model_name="gpt-4o-mini",
model_names=OPEN_AI_MODEL_NAMES,
display_model_names=["o1", "o3-mini", "gpt-4o", "gpt-4o-mini"],
)
try:
full_provider = upsert_llm_provider(open_provider, db_session)
update_default_provider(full_provider.id, db_session)
except Exception as e:
logger.error(f"Failed to configure OpenAI provider: {e}")
else:
logger.error(
"OPENAI_DEFAULT_API_KEY not set, skipping OpenAI provider configuration"
)
if COHERE_DEFAULT_API_KEY:
cloud_embedding_provider = CloudEmbeddingProviderCreationRequest(
provider_type=EmbeddingProvider.COHERE,
api_key=COHERE_DEFAULT_API_KEY,
)
try:
logger.info("Attempting to upsert Cohere cloud embedding provider")
upsert_cloud_embedding_provider(db_session, cloud_embedding_provider)
logger.info("Successfully upserted Cohere cloud embedding provider")
logger.info("Updating search settings with Cohere embedding model details")
query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.FUTURE)
.order_by(SearchSettings.id.desc())
)
result = db_session.execute(query)
current_search_settings = result.scalars().first()
if current_search_settings:
current_search_settings.model_name = (
"embed-english-v3.0" # Cohere's latest model as of now
)
current_search_settings.model_dim = (
1024 # Cohere's embed-english-v3.0 dimension
)
current_search_settings.provider_type = EmbeddingProvider.COHERE
current_search_settings.index_name = (
"danswer_chunk_cohere_embed_english_v3_0"
)
current_search_settings.query_prefix = ""
current_search_settings.passage_prefix = ""
db_session.commit()
else:
raise RuntimeError(
"No search settings specified, DB is not in a valid state"
)
logger.info("Fetching updated search settings to verify changes")
updated_query = (
select(SearchSettings)
.where(SearchSettings.status == IndexModelStatus.PRESENT)
.order_by(SearchSettings.id.desc())
)
updated_result = db_session.execute(updated_query)
updated_result.scalars().first()
except Exception:
logger.exception("Failed to configure Cohere embedding provider")
else:
logger.info(
"COHERE_DEFAULT_API_KEY not set, skipping Cohere embedding provider configuration"
)
async def submit_to_hubspot(
email: str, referral_source: str | None, request: Request
) -> None:
if not HUBSPOT_TRACKING_URL:
logger.info("HUBSPOT_TRACKING_URL not set, skipping HubSpot submission")
return
# HubSpot tracking cookie
hubspot_cookie = request.cookies.get("hubspotutk")
try:
user_agent = request.headers.get("user-agent", "")
referer = request.headers.get("referer", "")
ip_address = request.client.host if request.client else ""
# IP address
ip_address = request.client.host if request.client else None
payload = {
"email": email,
"referral_source": referral_source or "",
"user_agent": user_agent,
"referer": referer,
"ip_address": ip_address,
}
data = {
"fields": [
{"name": "email", "value": email},
{"name": "referral_source", "value": referral_source or ""},
],
"context": {
"hutk": hubspot_cookie,
"ipAddress": ip_address,
"pageUri": str(request.url),
"pageName": "User Registration",
},
}
async with httpx.AsyncClient() as client:
response = await client.post(HUBSPOT_TRACKING_URL, json=data)
if response.status_code != 200:
logger.error(f"Failed to submit to HubSpot: {response.text}")
async with httpx.AsyncClient() as client:
response = await client.post(
HUBSPOT_TRACKING_URL,
json=payload,
timeout=5.0,
)
if response.status_code != 200:
logger.error(
f"Failed to submit to HubSpot: {response.status_code} {response.text}"
)
except Exception as e:
logger.error(f"Error submitting to HubSpot: {e}")
async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
if DEV_MODE:
return
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
@@ -341,15 +209,14 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
payload = TenantDeletionPayload(tenant_id=tenant_id, email=email)
async with aiohttp.ClientSession() as session:
async with session.delete(
async with session.post(
f"{CONTROL_PLANE_API_BASE_URL}/tenants/delete",
headers=headers,
json=payload.model_dump(),
) as response:
print(response)
if response.status != 200:
error_text = await response.text()
logger.error(f"Control plane tenant creation failed: {error_text}")
logger.error(f"Control plane tenant deletion failed: {error_text}")
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)

View File

@@ -0,0 +1,62 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
# from ee.onyx.server.tenants.provisioning import get_tenant_setup_status
logger = setup_logger()
# Create a main router to include all sub-routers
router = APIRouter()
# Include all the sub-routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
class TenantSetupStatusResponse(BaseModel):
"""Response model for tenant setup status."""
tenant_id: str
status: str
is_complete: bool
# Add the setup status endpoint directly to the main router
@router.get("/tenants/setup-status", response_model=TenantSetupStatusResponse)
async def get_setup_status(
current_user: User = Depends(current_user),
) -> TenantSetupStatusResponse:
"""
Get the current setup status for the tenant.
This is used by the frontend to determine if the tenant setup is complete.
"""
tenant_id = get_current_tenant_id()
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
# status = get_tenant_setup_status(tenant_id)
return TenantSetupStatusResponse(
tenant_id=tenant_id, status="completed", is_complete=True
)

View File

@@ -49,6 +49,47 @@ def run_alembic_migrations(schema_name: str) -> None:
raise
def run_essential_alembic_migrations(schema_name: str) -> None:
"""
Run only the essential Alembic migrations up to the 465f78d9b7f9 revision.
This is used for the auth flow to complete quickly, with the rest of the migrations
and setup being deferred to run asynchronously.
"""
logger.info(f"Starting essential Alembic migrations for schema: {schema_name}")
try:
current_dir = os.path.dirname(os.path.abspath(__file__))
root_dir = os.path.abspath(os.path.join(current_dir, "..", "..", "..", ".."))
alembic_ini_path = os.path.join(root_dir, "alembic.ini")
# Configure Alembic
alembic_cfg = Config(alembic_ini_path)
alembic_cfg.set_main_option("sqlalchemy.url", build_connection_string())
alembic_cfg.set_main_option(
"script_location", os.path.join(root_dir, "alembic")
)
# Ensure that logging isn't broken
alembic_cfg.attributes["configure_logger"] = False
# Mimic command-line options by adding 'cmd_opts' to the config
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema_name}"] # type: ignore
# Run migrations programmatically up to the specified revision
command.upgrade(alembic_cfg, "465f78d9b7f9")
logger.info(
f"Essential Alembic migrations completed successfully for schema: {schema_name}"
)
except Exception as e:
logger.exception(
f"Essential Alembic migration failed for schema {schema_name}: {str(e)}"
)
raise
def create_schema_if_not_exists(tenant_id: str) -> bool:
with Session(get_sqlalchemy_engine()) as db_session:
with db_session.begin():

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -0,0 +1,62 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
# from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_DOMAINS = [
"gmail.com",
"yahoo.com",
"hotmail.com",
"outlook.com",
"icloud.com",
"msn.com",
"live.com",
"msn.com",
"hotmail.com",
"hotmail.co.uk",
"hotmail.fr",
"hotmail.de",
"hotmail.it",
"hotmail.es",
"hotmail.nl",
"hotmail.pl",
"hotmail.pt",
"hotmail.ro",
"hotmail.ru",
"hotmail.sa",
"hotmail.se",
"hotmail.tr",
"hotmail.tw",
"hotmail.ua",
"hotmail.us",
"hotmail.vn",
"hotmail.za",
"hotmail.zw",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_admin_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if domain in FORBIDDEN_COMMON_EMAIL_DOMAINS:
return None
tenant_id = get_current_tenant_id()
return TenantByDomainResponse(
tenant_id=tenant_id, status="completed", is_complete=True
)
# return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -0,0 +1,91 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/request-invite")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/approve-invite")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/accept-invite")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/deny-invite")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -31,7 +31,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,7 +92,6 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,7 +46,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -120,7 +119,6 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -63,7 +62,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -155,9 +153,8 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -281,9 +278,6 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,7 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -142,7 +141,6 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -113,7 +112,6 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -145,7 +144,6 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,7 +50,13 @@ def decide_refinement_need(
)
]
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -21,7 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -97,7 +96,6 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,7 +46,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -69,8 +68,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -182,9 +179,8 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -306,11 +302,7 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -417,7 +409,6 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,6 +13,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -143,6 +144,8 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
assert query_info is not None, "must have query info"
return query_info
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -97,7 +96,6 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,9 +56,8 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0] if callback_container else []
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,7 +25,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -94,7 +93,6 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,9 +44,7 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,17 +15,8 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -34,7 +25,6 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -47,31 +37,6 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -82,6 +47,7 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -105,22 +71,11 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -198,22 +153,10 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -9,23 +9,18 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -55,13 +50,11 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,7 +2,6 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -36,7 +35,6 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,11 +13,6 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,7 +42,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -62,7 +61,6 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -404,7 +402,6 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -508,9 +505,3 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -0,0 +1,52 @@
from typing import Optional
from fastapi import Depends
from fastapi import Request
from fastapi_users import BaseUserManager
from fastapi_users import UUIDIDMixin
from fastapi_users.db import SQLAlchemyUserDatabase
from onyx.auth.essential_user import EssentialUser
from onyx.auth.essential_user import get_essential_user_db
from onyx.configs.app_configs import USER_MANAGER_SECRET
class EssentialUserManager(UUIDIDMixin, BaseUserManager[EssentialUser, str]):
"""
A simplified user manager that only handles essential authentication operations.
This is used during the initial tenant setup phase to avoid errors with missing columns.
"""
reset_password_token_secret = USER_MANAGER_SECRET
verification_token_secret = USER_MANAGER_SECRET
async def on_after_register(
self, user: EssentialUser, request: Optional[Request] = None
) -> None:
"""
Simplified post-registration hook.
"""
async def on_after_forgot_password(
self, user: EssentialUser, token: str, request: Optional[Request] = None
) -> None:
"""
Simplified post-forgot-password hook.
"""
async def on_after_request_verify(
self, user: EssentialUser, token: str, request: Optional[Request] = None
) -> None:
"""
Simplified post-verification-request hook.
"""
async def get_essential_user_manager(
user_db: SQLAlchemyUserDatabase = Depends(get_essential_user_db),
) -> EssentialUserManager:
"""
Get a user manager that uses the essential user model.
This avoids errors with missing columns during the initial tenant setup.
"""
yield EssentialUserManager(user_db)

View File

@@ -0,0 +1,47 @@
from collections.abc import AsyncGenerator
from typing import Optional
from fastapi import Depends
from fastapi_users.db import SQLAlchemyBaseUserTableUUID
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy import Boolean
from sqlalchemy import Column
from sqlalchemy import String
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.ext.declarative import DeclarativeMeta
from sqlalchemy.orm import relationship
from onyx.db.engine import get_async_session
Base: DeclarativeMeta = declarative_base()
class EssentialUser(SQLAlchemyBaseUserTableUUID, Base):
"""
A simplified user model that only includes essential columns needed for authentication.
This is used during the initial tenant setup phase to avoid errors with missing columns
that would be added in later migrations.
"""
__tablename__ = "user"
email: str = Column(String(length=320), unique=True, index=True, nullable=False)
hashed_password: Optional[str] = Column(String(length=1024), nullable=True)
is_active: bool = Column(Boolean, default=True, nullable=False)
is_superuser: bool = Column(Boolean, default=False, nullable=False)
is_verified: bool = Column(Boolean, default=False, nullable=False)
# Relationships are defined but not used in the essential auth flow
oauth_accounts = relationship("OAuthAccount", lazy="joined")
credentials = relationship("Credential", lazy="joined")
async def get_essential_user_db(
session: AsyncSession = Depends(get_async_session),
) -> AsyncGenerator[SQLAlchemyUserDatabase, None]:
"""
Get a user database that uses the essential user model.
This avoids errors with missing columns during the initial tenant setup.
"""
yield SQLAlchemyUserDatabase(session, EssentialUser)

View File

@@ -587,20 +587,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)

View File

@@ -111,6 +111,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -1,60 +0,0 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -23,7 +23,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
@@ -985,9 +984,6 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1028,23 +1024,6 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1191,7 +1170,6 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1239,7 +1217,6 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -15,8 +15,6 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
"Manual LLM citation wasn't able to close brackets"
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
continue
link = context_llm_doc.link
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
link = context_llm_doc.link
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
last_citation_end = end + length_to_add
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,45 +333,4 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
AGENT_MAX_TOKENS_VALIDATION = int(
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
)
GRAPH_VERSION_NAME: str = "a"

View File

@@ -642,4 +642,14 @@ MOCK_LLM_RESPONSE = (
)
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Image processing configurations
ENABLE_IMAGE_EXTRACTION = (
os.environ.get("ENABLE_IMAGE_EXTRACTION", "true").lower() == "true"
)
ENABLE_INDEXING_TIME_IMAGE_ANALYSIS = not (
os.environ.get("DISABLE_INDEXING_TIME_IMAGE_ANALYSIS", "false").lower() == "true"
)
ENABLE_SEARCH_TIME_IMAGE_ANALYSIS = not (
os.environ.get("DISABLE_SEARCH_TIME_IMAGE_ANALYSIS", "false").lower() == "true"
)
IMAGE_ANALYSIS_MAX_SIZE_MB = int(os.environ.get("IMAGE_ANALYSIS_MAX_SIZE_MB", "20"))

View File

@@ -1,38 +0,0 @@
from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
from onyx.server.settings.store import load_settings
def get_image_extraction_and_analysis_enabled() -> bool:
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.image_extraction_and_analysis_enabled is not None:
return settings.image_extraction_and_analysis_enabled
except Exception:
pass
return False
def get_search_time_image_analysis_enabled() -> bool:
"""Get search time image analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.search_time_image_analysis_enabled is not None:
return settings.search_time_image_analysis_enabled
except Exception:
pass
return False
def get_image_analysis_max_size_mb() -> int:
"""Get image analysis max size MB setting from workspace settings or fallback to environment variable"""
try:
settings = load_settings()
if settings.image_analysis_max_size_mb is not None:
return settings.image_analysis_max_size_mb
except Exception:
pass
return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB

View File

@@ -66,6 +66,9 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
@@ -237,7 +240,7 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
# Get the page content
page_content = extract_text_from_confluence_html(
@@ -302,9 +305,7 @@ class ConfluenceConnector(
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -375,7 +376,7 @@ class ConfluenceConnector(
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
if content_text:

View File

@@ -144,12 +144,6 @@ class OnyxConfluence:
self.static_credentials = credential_json
return credential_json, False
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)

View File

@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
repo_name: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.repo_name = repo_name
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
@@ -157,42 +157,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logger.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
@@ -220,17 +189,11 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
for repo in repos:
if self.include_prs:
@@ -305,48 +268,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
@@ -372,15 +298,10 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
@@ -389,7 +310,6 @@ class GithubConnector(LoadConnector, PollConnector):
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
@@ -401,7 +321,7 @@ if __name__ == "__main__":
connector = GithubConnector(
repo_owner=os.environ["REPO_OWNER"],
repositories=os.environ["REPOSITORIES"],
repo_name=os.environ["REPO_NAME"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}

View File

@@ -316,9 +316,7 @@ class GoogleDriveConnector(
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
# default is ~17mins of retries, don't do that here for cases so we don't
# waste 17mins everytime we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue

View File

@@ -1,7 +1,7 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.app_configs import ENABLE_INDEXING_TIME_IMAGE_ANALYSIS
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
@@ -30,7 +30,7 @@ class VisionEnabledConnector:
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
if ENABLE_INDEXING_TIME_IMAGE_ANALYSIS:
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:

View File

@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@@ -151,10 +151,6 @@ class SearchRequest(ChunkContext):
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
model_config = ConfigDict(arbitrary_types_allowed=True)
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
@@ -179,8 +175,6 @@ class SearchQuery(ChunkContext):
offset: int = 0
model_config = ConfigDict(frozen=True)
precomputed_query_embedding: Embedding | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@@ -331,14 +331,6 @@ class SearchPipeline:
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def retrieved_sections(self) -> list[InferenceSection]:
if self._retrieved_sections is not None:
return self._retrieved_sections
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -351,7 +343,7 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self.retrieved_sections
retrieved_sections = self._get_sections()
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)

View File

@@ -10,8 +10,8 @@ from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import ENABLE_SEARCH_TIME_IMAGE_ANALYSIS
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from onyx.context.search.enums import LLMEvaluationType
@@ -413,7 +413,7 @@ def search_postprocessing(
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order.
# This way the user experience isn't delayed by the LLM step
if get_search_time_image_analysis_enabled():
if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS:
update_image_sections_with_query(
retrieved_sections, search_query.query, llm
)
@@ -456,7 +456,7 @@ def search_postprocessing(
_log_top_section_links(search_query.search_type.value, reranked_sections)
# Add the image processing step here
if get_search_time_image_analysis_enabled():
if ENABLE_SEARCH_TIME_IMAGE_ANALYSIS:
update_image_sections_with_query(
reranked_sections, search_query.query, llm
)

View File

@@ -117,12 +117,8 @@ def retrieval_preprocessing(
else None
)
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
# latency, and in that case we don't need to run the model again
run_query_analysis = (
None
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
else FunctionCall(query_analysis, (query,), {})
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
)
functions_to_run = [
@@ -147,12 +143,11 @@ def retrieval_preprocessing(
# The extracted keywords right now are not very reliable, not using for now
# Can maybe use for highlighting
is_keyword, _extracted_keywords = False, None
if search_request.precomputed_is_keyword is not None:
is_keyword = search_request.precomputed_is_keyword
_extracted_keywords = search_request.precomputed_keywords
elif run_query_analysis:
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
)
all_query_terms = query.split()
processed_keywords = (
@@ -252,5 +247,4 @@ def retrieval_preprocessing(
chunks_above=chunks_above,
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
)

View File

@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -109,20 +109,6 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -135,10 +121,17 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
query_embedding = query.precomputed_query_embedding or get_query_embedding(
query.query, db_session
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
@@ -256,16 +249,7 @@ def retrieve_chunks(
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.model_copy(
update={
"query": rephrase,
# need to recompute for each rephrase
# note that `SearchQuery` is a frozen model, so we can't update
# it below
"precomputed_query_embedding": None,
},
deep=True,
)
q_copy = query.copy(update={"query": rephrase}, deep=True)
run_queries.append(
(
doc_index_retrieval,

View File

@@ -1,79 +0,0 @@
import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")
rows = db_session.query(ChatSession).all()
for x in range(0, len(rows)):
if x % 1024 == 0:
logger.info(f"Seeded messages for {x} sessions so far.")
row = rows[x]
row.time_created = datetime.utcnow() - timedelta(
days=random.randint(0, days)
)
row.time_updated = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
root_message = get_or_create_root_message(row.id, db_session)
current_message_type = MessageType.USER
parent_message = root_message
for x in range(0, num_messages):
if current_message_type == MessageType.USER:
msg = f"pytest_message_user_{x}"
else:
msg = f"pytest_message_assistant_{x}"
chat_message = create_new_chat_message(
row.id,
parent_message,
msg,
None,
0,
current_message_type,
db_session,
)
chat_message.time_sent = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
db_session.commit()
current_message_type = (
MessageType.ASSISTANT
if current_message_type == MessageType.USER
else MessageType.USER
)
parent_message = chat_message
db_session.commit()
logger.info(f"Seeded messages for {len(rows)} sessions. Finished.")

View File

@@ -464,29 +464,12 @@ def index_doc_batch(
),
)
all_returned_doc_ids = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
f"Successful IDs: {successful_doc_ids}"
)
last_modified_ids = []

View File

@@ -167,7 +167,7 @@ def _convert_delta_to_message_chunk(
stop_reason: str | None = None,
) -> BaseMessageChunk:
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
content = _dict.get("content") or ""
additional_kwargs = {}
if _dict.get("function_call"):
@@ -402,7 +402,6 @@ class DefaultMultiLLM(LLM):
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -430,7 +429,6 @@ class DefaultMultiLLM(LLM):
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -486,7 +484,6 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -500,7 +497,6 @@ class DefaultMultiLLM(LLM):
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
choice = response.choices[0]
@@ -519,7 +515,6 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -544,7 +539,6 @@ class DefaultMultiLLM(LLM):
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
try:

View File

@@ -82,7 +82,6 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -93,6 +92,5 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -91,18 +91,12 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
prompt, tools, tool_choice, structured_response_format, timeout_override
)
@abc.abstractmethod
@@ -113,7 +107,6 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -124,18 +117,12 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
prompt, tools, tool_choice, structured_response_format, timeout_override
)
tokens = []
@@ -155,6 +142,5 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError

View File

@@ -51,7 +51,6 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@@ -323,7 +322,6 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -53,11 +53,6 @@ class Settings(BaseModel):
auto_scroll: bool | None = False
query_history_type: QueryHistoryType | None = None
# Image processing settings
image_extraction_and_analysis_enabled: bool | None = False
search_time_image_analysis_enabled: bool | None = False
image_analysis_max_size_mb: int | None = 20
class UserSettings(Settings):
notifications: list[Notification]

View File

@@ -47,7 +47,6 @@ def load_settings() -> Settings:
settings.anonymous_user_enabled = anonymous_user_enabled
settings.query_history_type = ONYX_QUERY_HISTORY_TYPE
return settings

View File

@@ -9,7 +9,6 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from shared_configs.model_server_models import Embedding
class ToolResponse(BaseModel):
@@ -61,15 +60,11 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
# None indicates that the default value should be used
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool | None = None
alternate_db_session: Session | None = None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
skip_query_analysis: bool | None = None
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
class Config:
arbitrary_types_allowed = True

View File

@@ -3,7 +3,6 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
from sqlalchemy.orm import Session
@@ -12,6 +11,7 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
@@ -42,9 +42,6 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -284,23 +281,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs[QUERY_FIELD])
precomputed_query_embedding = None
precomputed_is_keyword = None
precomputed_keywords = None
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
force_no_rerank = override_kwargs.force_no_rerank
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = use_alt_not_None(
override_kwargs.skip_query_analysis, False
)
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
skip_query_analysis = override_kwargs.skip_query_analysis
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
return
@@ -337,9 +327,6 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if self.retrieval_options
else None
),
precomputed_query_embedding=precomputed_query_embedding,
precomputed_is_keyword=precomputed_is_keyword,
precomputed_keywords=precomputed_keywords,
),
user=self.user,
llm=self.llm,
@@ -358,9 +345,8 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
yield from yield_search_responses(
query,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_pipeline.reranked_sections,
search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
@@ -397,16 +383,10 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
#
# The various inference sections are passed in as functions to allow for lazy
# evaluation. The SearchPipeline object properties that they correspond to are
# actually functions defined with @property decorators, and passing them into
# this function causes them to get evaluated immediately which is undesirable.
def yield_search_responses(
query: str,
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
@@ -415,7 +395,7 @@ def yield_search_responses(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=get_retrieved_sections(),
top_sections=final_context_sections,
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
@@ -427,8 +407,13 @@ def yield_search_responses(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
context_from_inference_section(section)
for section in get_reranked_sections()
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
]
),
)
@@ -439,7 +424,6 @@ def yield_search_responses(
response=section_relevance,
)
final_context_sections = get_final_context_sections()
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
@@ -454,10 +438,3 @@ def yield_search_responses(
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
T = TypeVar("T")
def use_alt_not_None(value: T | None, alt: T) -> T:
return value if value is not None else alt

View File

@@ -1,5 +1,4 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -30,12 +29,3 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
"%B %d, %Y %H:%M"
)
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)

View File

@@ -1,8 +1,6 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -13,16 +11,10 @@ from onyx.tools.tool import Tool
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
R = TypeVar("R")
class ToolRunner(Generic[R]):
def __init__(
self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
):
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
self.tool = tool
self.args = args
self.override_kwargs = override_kwargs
self._tool_responses: list[ToolResponse] | None = None
@@ -35,9 +27,7 @@ class ToolRunner(Generic[R]):
return
tool_responses: list[ToolResponse] = []
for tool_response in self.tool.run(
override_kwargs=self.override_kwargs, **self.args
):
for tool_response in self.tool.run(**self.args):
yield tool_response
tool_responses.append(tool_response)

View File

@@ -118,7 +118,7 @@ def run_functions_in_parallel(
return results
class TimeoutThread(threading.Thread, Generic[R]):
class TimeoutThread(threading.Thread):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
@@ -159,34 +159,3 @@ def run_with_timeout(
task.end()
return task.result
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
# difficult to use. It's up to the programmer to call wait_on_background on the thread after
# the code you want to run in parallel is finished. As with all python thread parallelism,
# this is only useful for I/O bound tasks.
def run_in_background(
func: Callable[..., R], *args: Any, **kwargs: Any
) -> TimeoutThread[R]:
"""
Runs a function in a background thread. Returns a TimeoutThread object that can be used
to wait for the function to finish with wait_on_background.
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task.start()
return task
def wait_on_background(task: TimeoutThread[R]) -> R:
"""
Used in conjunction with run_in_background. blocks until the task is finished,
then returns the result of the task.
"""
task.join()
if task.exception is not None:
raise task.exception
return task.result

View File

@@ -1,45 +0,0 @@
import argparse
import logging
from logging import getLogger
from onyx.db.seeding.chat_history_seeding import seed_chat_history
# Configure the logger
logging.basicConfig(
level=logging.INFO, # Set the log level (DEBUG, INFO, WARNING, ERROR, CRITICAL)
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", # Log format
handlers=[logging.StreamHandler()], # Output logs to console
)
logger = getLogger(__name__)
def go_main(num_sessions: int, num_messages: int, num_days: int) -> None:
seed_chat_history(num_sessions, num_messages, num_days)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Seed chat history")
parser.add_argument(
"--sessions",
type=int,
default=2048,
help="Number of chat sessions to seed",
)
parser.add_argument(
"--messages",
type=int,
default=4,
help="Number of chat messages to seed per session",
)
parser.add_argument(
"--days",
type=int,
default=90,
help="Number of days looking backwards over which to seed the timestamps with",
)
args = parser.parse_args()
go_main(args.sessions, args.messages, args.days)

View File

@@ -108,7 +108,6 @@ command=tail -qF
/var/log/celery_worker_light.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_indexing.log
/var/log/celery_worker_monitoring.log
/var/log/slack_bot.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout

View File

@@ -45,7 +45,7 @@ def test_confluence_connector_basic(
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 2
assert len(doc_batch) == 3
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None

View File

@@ -41,10 +41,5 @@ def test_confluence_connector_permissions(
for slim_doc_batch in confluence_connector.retrieve_all_slim_documents():
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
# Find IDs that are in full but not in slim
difference = all_full_doc_ids - all_slim_doc_ids
# The set of full doc IDs should be always be a subset of the slim doc IDs
assert all_full_doc_ids.issubset(
all_slim_doc_ids
), f"Full doc IDs are not a subset of slim doc IDs. Found {len(difference)} IDs in full docs but not in slim docs."
assert all_full_doc_ids.issubset(all_slim_doc_ids)

View File

@@ -25,7 +25,7 @@ from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout_multiproc
from tests.integration.common_utils.timeout import run_with_timeout
logger = setup_logger()
@@ -161,7 +161,7 @@ def reset_postgres(
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout_multiproc(
run_with_timeout(
downgrade_postgres,
TIMEOUT,
kwargs={

View File

@@ -6,9 +6,7 @@ from typing import TypeVar
T = TypeVar("T")
def run_with_timeout_multiproc(
task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T:
def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread
with multiprocessing.Pool(processes=1) as pool:
async_result = pool.apply_async(task, kwds=kwargs)

View File

@@ -1,48 +0,0 @@
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from ee.onyx.db.usage_export import get_all_empty_chat_message_entries
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.seeding.chat_history_seeding import seed_chat_history
def test_usage_reports(reset: None) -> None:
EXPECTED_SESSIONS = 2048
MESSAGES_PER_SESSION = 4
# divide by 2 because only messages of type USER are returned
EXPECTED_MESSAGES = EXPECTED_SESSIONS * MESSAGES_PER_SESSION / 2
seed_chat_history(EXPECTED_SESSIONS, MESSAGES_PER_SESSION, 90)
with get_session_with_current_tenant() as db_session:
# count of all entries should be exact
period = (
datetime.fromtimestamp(0, tz=timezone.utc),
datetime.now(tz=timezone.utc),
)
count = 0
for entry_batch in get_all_empty_chat_message_entries(db_session, period):
for entry in entry_batch:
count += 1
assert count == EXPECTED_MESSAGES
# count in a one month time range should be within a certain range statistically
# this can be improved if we seed the chat history data deterministically
period = (
datetime.now(tz=timezone.utc) - timedelta(days=30),
datetime.now(tz=timezone.utc),
)
count = 0
for entry_batch in get_all_empty_chat_message_entries(db_session, period):
for entry in entry_batch:
count += 1
lower = EXPECTED_MESSAGES // 3 - (EXPECTED_MESSAGES // (3 * 3))
upper = EXPECTED_MESSAGES // 3 + (EXPECTED_MESSAGES // (3 * 3))
assert count > lower
assert count < upper

View File

@@ -145,7 +145,6 @@ def test_multiple_tool_calls(default_multi_llm: DefaultMultiLLM) -> None:
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)
@@ -291,5 +290,4 @@ def test_multiple_tool_calls_streaming(default_multi_llm: DefaultMultiLLM) -> No
timeout=30,
parallel_tool_calls=False,
mock_response=MOCK_LLM_RESPONSE,
max_tokens=None,
)

View File

@@ -1,14 +1,8 @@
import contextvars
import time
import pytest
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a context variable for testing
test_context_var = contextvars.ContextVar("test_var", default="default")
def test_run_with_timeout_completes() -> None:
@@ -65,86 +59,3 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
# Test with positional and keyword args
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
assert result2 == 15
def test_run_in_background_and_wait_success() -> None:
"""Test that run_in_background and wait_on_background work correctly for successful execution"""
def background_function(x: int) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
return x * 2
# Start the background task
task = run_in_background(background_function, 21)
# Verify we can do other work while task is running
start_time = time.time()
result = wait_on_background(task)
elapsed = time.time() - start_time
assert result == 42
assert elapsed >= 0.1 # Verify we actually waited for the sleep
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_run_in_background_propagates_exceptions() -> None:
"""Test that exceptions in background tasks are properly propagated"""
def error_function() -> None:
time.sleep(0.1) # Small delay to ensure it's actually running in background
raise ValueError("Test background error")
task = run_in_background(error_function)
with pytest.raises(ValueError) as exc_info:
wait_on_background(task)
assert "Test background error" in str(exc_info.value)
def test_run_in_background_with_args_and_kwargs() -> None:
"""Test that args and kwargs are properly passed to the background function"""
def complex_function(x: int, y: int, multiply: bool = False) -> int:
time.sleep(0.1) # Small delay to ensure it's actually running in background
if multiply:
return x * y
return x + y
# Test with args
task1 = run_in_background(complex_function, 5, 3)
result1 = wait_on_background(task1)
assert result1 == 8
# Test with args and kwargs
task2 = run_in_background(complex_function, 5, 3, multiply=True)
result2 = wait_on_background(task2)
assert result2 == 15
def test_multiple_background_tasks() -> None:
"""Test running multiple background tasks concurrently"""
def slow_add(x: int, y: int) -> int:
time.sleep(0.2) # Make each task take some time
return x + y
# Start multiple tasks
start_time = time.time()
task1 = run_in_background(slow_add, 1, 2)
task2 = run_in_background(slow_add, 3, 4)
task3 = run_in_background(slow_add, 5, 6)
# Wait for all results
result1 = wait_on_background(task1)
result2 = wait_on_background(task2)
result3 = wait_on_background(task3)
elapsed = time.time() - start_time
# Verify results
assert result1 == 3
assert result2 == 7
assert result3 == 11
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations

View File

@@ -4,9 +4,7 @@ import time
from onyx.utils.threadpool_concurrency import FunctionCall
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.threadpool_concurrency import wait_on_background
# Create a test contextvar
test_var = contextvars.ContextVar("test_var", default="default")
@@ -131,39 +129,3 @@ def test_contextvar_isolation_between_runs() -> None:
# Verify second run results
assert all(result in ["thread3", "thread4"] for result in second_results)
def test_run_in_background_preserves_contextvar() -> None:
"""Test that run_in_background preserves contextvar values and modifications are isolated"""
def modify_and_sleep() -> tuple[str, str]:
"""Modifies contextvar, sleeps, and returns original, modified, and final values"""
original = test_var.get()
test_var.set("modified_in_background")
time.sleep(0.1) # Ensure we can check main thread during execution
final = test_var.get()
return original, final
# Set initial value in main thread
token = test_var.set("initial_value")
try:
# Start background task
task = run_in_background(modify_and_sleep)
# Verify main thread value remains unchanged while task runs
assert test_var.get() == "initial_value"
# Get results from background thread
original, modified = wait_on_background(task)
# Verify the background thread:
# 1. Saw the initial value
assert original == "initial_value"
# 2. Successfully modified its own copy
assert modified == "modified_in_background"
# Verify main thread value is still unchanged after task completion
assert test_var.get() == "initial_value"
finally:
# Clean up
test_var.reset(token)

View File

@@ -80,13 +80,3 @@ prod cluster**
- `kubectl delete -f .`
- To not delete the persistent volumes (Document indexes and Users), specify the specific `.yaml` files instead of
`.` without specifying delete on persistent-volumes.yaml.
### Using Helm to deploy to an existing cluster
Onyx has a helm chart that is convenient to install all services to an existing Kubernetes cluster. To install:
* Currently the helm chart is not published so to install, clone the repo.
* Configure access to the cluster via kubectl. Ensure the kubectl context is set to the cluster that you want to use
* The default secrets, environment variables and other service level configuration are stored in `deployment/helm/charts/onyx/values.yml`. You may create another `override.yml`
* `cd deployment/helm/charts/onyx` and run `helm install onyx -n onyx -f override.yaml .`. This will install onyx on the cluster under the `onyx` namespace.
* Check the status of the deploy using `kubectl get pods -n onyx`

View File

@@ -254,9 +254,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -434,4 +431,3 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -209,9 +209,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -387,4 +384,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -244,8 +244,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -423,4 +421,3 @@ volumes:
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -54,9 +54,6 @@ services:
- INDEXING_MODEL_SERVER_HOST=${INDEXING_MODEL_SERVER_HOST:-indexing_model_server}
extra_hosts:
- "host.docker.internal:host-gateway"
# optional, only for debugging purposes
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -236,4 +233,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -68,8 +68,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -231,4 +229,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -32,8 +32,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -75,8 +73,6 @@ services:
# - ./bundle.pem:/app/bundle.pem:ro
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -264,4 +260,3 @@ volumes:
# Created by the container itself
model_cache_huggingface:
indexing_huggingface_model_cache:
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -62,8 +62,6 @@ services:
- ENABLE_PAID_ENTERPRISE_EDITION_FEATURES=True
extra_hosts:
- "host.docker.internal:host-gateway"
volumes:
- log_store:/var/log/persisted-logs
logging:
driver: json-file
options:
@@ -223,4 +221,3 @@ volumes:
type: none
o: bind
device: ${DANSWER_VESPA_DATA_DIR:-./vespa_data}
log_store: # for logs that we don't want to lose on container restarts

View File

@@ -1,27 +0,0 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "onyx-stack.fullname" . }}-ingress-api
annotations:
kubernetes.io/ingress.class: nginx
nginx.ingress.kubernetes.io/rewrite-target: /$2
nginx.ingress.kubernetes.io/use-regex: "true"
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
rules:
- host: {{ .Values.ingress.api.host }}
http:
paths:
- path: /api(/|$)(.*)
pathType: Prefix
backend:
service:
name: {{ include "onyx-stack.fullname" . }}-api-service
port:
number: {{ .Values.api.service.servicePort }}
tls:
- hosts:
- {{ .Values.ingress.api.host }}
secretName: {{ include "onyx-stack.fullname" . }}-ingress-api-tls
{{- end }}

View File

@@ -1,26 +0,0 @@
{{- if .Values.ingress.enabled -}}
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
name: {{ include "onyx-stack.fullname" . }}-ingress-webserver
annotations:
kubernetes.io/ingress.class: nginx
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
kubernetes.io/tls-acme: "true"
spec:
rules:
- host: {{ .Values.ingress.webserver.host }}
http:
paths:
- path: /
pathType: Prefix
backend:
service:
name: {{ include "onyx-stack.fullname" . }}-webserver
port:
number: {{ .Values.webserver.service.servicePort }}
tls:
- hosts:
- {{ .Values.ingress.webserver.host }}
secretName: {{ include "onyx-stack.fullname" . }}-ingress-webserver-tls
{{- end }}

View File

@@ -1,20 +0,0 @@
{{- if .Values.letsencrypt.enabled -}}
apiVersion: cert-manager.io/v1
kind: ClusterIssuer
metadata:
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
spec:
acme:
# The ACME server URL
server: https://acme-v02.api.letsencrypt.org/directory
# Email address used for ACME registration
email: {{ .Values.letsencrypt.email }}
# Name of a secret used to store the ACME account private key
privateKeySecretRef:
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
# Enable the HTTP-01 challenge provider
solvers:
- http01:
ingress:
class: nginx
{{- end }}

View File

@@ -376,17 +376,22 @@ redis:
existingSecret: onyx-secrets
existingSecretPasswordKey: redis_password
ingress:
enabled: false
className: ""
api:
host: onyx.local
webserver:
host: onyx.local
# ingress:
# enabled: false
# className: ""
# annotations: {}
# # kubernetes.io/ingress.class: nginx
# # kubernetes.io/tls-acme: "true"
# hosts:
# - host: chart-example.local
# paths:
# - path: /
# pathType: ImplementationSpecific
# tls: []
# # - secretName: chart-example-tls
# # hosts:
# # - chart-example.local
letsencrypt:
enabled: false
email: "abc@abc.com"
auth:
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets

View File

@@ -1,14 +1,17 @@
import {
AnthropicIcon,
AmazonIcon,
AWSIcon,
AzureIcon,
CPUIcon,
MicrosoftIconSVG,
MistralIcon,
MetaIcon,
GeminiIcon,
AnthropicSVG,
IconProps,
OpenAIISVG,
DeepseekIcon,
OpenAISVG,
} from "@/components/icons/icons";
export interface CustomConfigKey {
@@ -71,7 +74,7 @@ export interface LLMProviderDescriptor {
}
export const getProviderIcon = (providerName: string, modelName?: string) => {
const iconMap: Record<
const modelIconMap: Record<
string,
({ size, className }: IconProps) => JSX.Element
> = {
@@ -83,30 +86,34 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
gemini: GeminiIcon,
deepseek: DeepseekIcon,
claude: AnthropicIcon,
anthropic: AnthropicIcon,
openai: OpenAISVG,
microsoft: MicrosoftIconSVG,
meta: MetaIcon,
google: GeminiIcon,
};
// First check if provider name directly matches an icon
if (providerName.toLowerCase() in iconMap) {
return iconMap[providerName.toLowerCase()];
}
// Then check if model name contains any of the keys
if (modelName) {
const lowerModelName = modelName.toLowerCase();
for (const [key, icon] of Object.entries(iconMap)) {
if (lowerModelName.includes(key)) {
const modelNameToIcon = (
modelName: string,
fallbackIcon: ({ size, className }: IconProps) => JSX.Element
): (({ size, className }: IconProps) => JSX.Element) => {
const lowerModelName = modelName?.toLowerCase();
for (const [key, icon] of Object.entries(modelIconMap)) {
if (lowerModelName?.includes(key)) {
return icon;
}
}
}
return fallbackIcon;
};
// Fallback to CPU icon if no matches
return CPUIcon;
switch (providerName) {
case "openai":
// Special cases for openai based on modelName
return modelNameToIcon(modelName || "", OpenAIISVG);
case "anthropic":
return AnthropicSVG;
case "bedrock":
return AWSIcon;
case "azure":
return AzureIcon;
default:
return modelNameToIcon(modelName || "", CPUIcon);
}
};
export const isAnthropic = (provider: string, modelName: string) =>

View File

@@ -185,10 +185,7 @@ export const FilterComponent = forwardRef<
hasActiveFilters ? "border-primary bg-primary/5" : ""
}`}
>
<SortIcon
size={20}
className="text-neutral-800 dark:text-neutral-200"
/>
<SortIcon size={20} className="text-neutral-800" />
</Button>
</DropdownMenuTrigger>
<DropdownMenuContent
@@ -368,7 +365,7 @@ export const FilterComponent = forwardRef<
{hasActiveFilters && (
<div className="absolute -top-1 -right-1">
<Badge className="h-2 !bg-red-400 !border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
<Badge className="h-2 bg-red-400 border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
</div>
)}
</div>

View File

@@ -26,7 +26,7 @@ export function Checkbox({
onChange: (e: React.ChangeEvent<HTMLInputElement>) => void;
}) {
return (
<label className="flex text-xs cursor-pointer">
<label className="flex text-sm cursor-pointer">
<input
checked={checked}
onChange={onChange}
@@ -34,7 +34,7 @@ export function Checkbox({
className="mr-2 w-3.5 h-3.5 my-auto"
/>
<div>
<Label small>{label}</Label>
<Label>{label}</Label>
{sublabel && <SubLabel>{sublabel}</SubLabel>}
</div>
</label>
@@ -208,7 +208,7 @@ export function SettingsForm() {
}
return (
<div className="flex flex-col pb-8">
<div>
{popup}
<Title className="mb-4">Workspace Settings</Title>
<Checkbox
@@ -290,71 +290,23 @@ export function SettingsForm() {
id="chatRetentionInput"
placeholder="Infinite Retention"
/>
<div className="mr-auto flex gap-2">
<Button
onClick={handleSetChatRetention}
variant="submit"
size="sm"
className="mr-auto"
>
Set Retention Limit
</Button>
<Button
onClick={handleClearChatRetention}
variant="default"
size="sm"
className="mr-auto"
>
Retain All
</Button>
</div>
<Button
onClick={handleSetChatRetention}
variant="submit"
size="sm"
className="mr-3"
>
Set Retention Limit
</Button>
<Button
onClick={handleClearChatRetention}
variant="default"
size="sm"
>
Retain All
</Button>
</>
)}
{/* Image Processing Settings */}
<Title className="mt-8 mb-4">Image Processing</Title>
<div className="flex flex-col gap-2">
<Checkbox
label="Enable Image Extraction and Analysis"
sublabel="Extract and analyze images from documents during indexing. This allows the system to process images and create searchable descriptions of them."
checked={settings.image_extraction_and_analysis_enabled ?? false}
onChange={(e) =>
handleToggleSettingsField(
"image_extraction_and_analysis_enabled",
e.target.checked
)
}
/>
<Checkbox
label="Enable Search-time Image Analysis"
sublabel="Analyze images at search time when a user asks about images. This provides more detailed and query-specific image analysis but may increase search-time latency."
checked={settings.search_time_image_analysis_enabled ?? false}
onChange={(e) =>
handleToggleSettingsField(
"search_time_image_analysis_enabled",
e.target.checked
)
}
/>
<IntegerInput
label="Maximum Image Size for Analysis (MB)"
sublabel="Images larger than this size will not be analyzed to prevent excessive resource usage."
value={settings.image_analysis_max_size_mb ?? null}
onChange={(e) => {
const value = e.target.value ? parseInt(e.target.value) : null;
if (value !== null && !isNaN(value) && value > 0) {
updateSettingField([
{ fieldName: "image_analysis_max_size_mb", newValue: value },
]);
}
}}
id="image-analysis-max-size"
placeholder="Enter maximum size in MB"
/>
</div>
</div>
);
}

View File

@@ -21,11 +21,6 @@ export interface Settings {
auto_scroll: boolean;
temperature_override_enabled: boolean;
query_history_type: QueryHistoryType;
// Image processing settings
image_extraction_and_analysis_enabled?: boolean;
search_time_image_analysis_enabled?: boolean;
image_analysis_max_size_mb?: number;
}
export enum NotificationType {

Some files were not shown because too many files have changed in this diff Show More