Compare commits

..

36 Commits

Author SHA1 Message Date
joachim-danswer
68aa00e330 manual import reordering 2025-03-13 09:01:01 -07:00
joachim-danswer
3a78421e38 Update migration 2025-03-13 09:01:01 -07:00
joachim-danswer
0aa9e8968a RK comments 2025-03-13 09:01:01 -07:00
joachim-danswer
7d9e133e35 create columns for individual boost factors 2025-03-13 09:01:01 -07:00
joachim-danswer
838160e660 improved dict score update 2025-03-13 09:01:01 -07:00
joachim-danswer
6b84332f1b saving and updating chunk stats 2025-03-13 09:01:01 -07:00
joachim-danswer
4fe5561f44 saving and deleting chunk stats in new table 2025-03-13 09:01:01 -07:00
joachim-danswer
2d81d6082a renaming to chunk_boost & chunk table def 2025-03-13 09:01:01 -07:00
joachim-danswer
ef291fcf0c YS comments - pt 1 2025-03-13 09:01:01 -07:00
joachim-danswer
b8f64d10a2 small updates 2025-03-13 09:01:01 -07:00
joachim-danswer
9f37ca23e8 default off 2025-03-13 09:01:01 -07:00
joachim-danswer
34b2e5d9d3 new home for model_server configs 2025-03-13 09:01:01 -07:00
joachim-danswer
ff6e4cd231 fixed docker file 2025-03-13 09:01:01 -07:00
joachim-danswer
9d7137b3bb requirements version update 2025-03-13 09:01:01 -07:00
joachim-danswer
91cdcd820f nit 2025-03-13 09:01:01 -07:00
joachim-danswer
4ff0c18822 EL comments and improvements
Improvements:
  - proper import of information content model from cache or HF
  - warm up for information content model

Other:
  - EL PR review comments
2025-03-13 09:01:01 -07:00
joachim-danswer
66976529d2 nit 2025-03-13 09:01:01 -07:00
joachim-danswer
1954105cc6 avoid boost_score > 1.0 2025-03-13 09:01:01 -07:00
joachim-danswer
625ea6f24b name change to information_content_model 2025-03-13 09:01:01 -07:00
joachim-danswer
60d2c8c86c improvements 2025-03-13 09:01:01 -07:00
joachim-danswer
324b8e42a5 simplification 2025-03-13 09:01:01 -07:00
joachim-danswer
125877ec65 initial working code 2025-03-13 09:01:01 -07:00
joachim-danswer
ca803859cc remove title for slack 2025-03-13 09:01:01 -07:00
rkuo-danswer
b4ecc870b9 safe handling for mediaType in confluence connector in all places (#4269)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-13 06:09:19 +00:00
rkuo-danswer
a2ac9f02fb unique constraint here doesn't work (#4271)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-12 16:25:27 -07:00
pablonyx
f87e559cc4 Separate out indexing-time image analysis into new phase (#4228)
* Separate out indexing-time image analysis into new phase

* looking good

* k

* k
2025-03-12 22:26:05 +00:00
pablonyx
5883336d5e Support image indexing customization (#4261)
* working well

* k

* ready to go

* k

* minor nits

* k

* quick fix

* k

* k
2025-03-12 20:03:45 +00:00
pablonyx
0153ff6b51 Improved logout flow (#4258)
* improved app provider modals

* improved logout flow

* k

* updates

* add docstring
2025-03-12 19:19:39 +00:00
pablonyx
2f8f0f01be Tenants on standby (#4218)
* add tenants on standby feature

* k

* fix alembic

* k

* k
2025-03-12 18:25:30 +00:00
pablonyx
a9e5ae2f11 Fix slash mystery (#4263) 2025-03-12 10:03:21 -07:00
Chris Weaver
997f40500d Add support for sandboxed salesforce (#4252) 2025-03-12 00:21:24 +00:00
rkuo-danswer
a918a84e7b fix oauth downloading and size limits in confluence (#4249)
* fix oauth downloading and size limits in confluence

* bump black to get past corrupt hash

* try working around another corrupt package

* fix raw_bytes

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-11 23:57:47 +00:00
rkuo-danswer
090f3fe817 handle conflicts on lowercasing emails (#4255)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-11 21:25:50 +00:00
pablonyx
4e70f99214 Fix slack links (#4254)
* fix slack links

* updates

* k

* nit improvements
2025-03-11 19:58:15 +00:00
pablonyx
ecbd4eb1ad add basic user invite flow (#4253) 2025-03-11 19:02:51 +00:00
pablonyx
f94d335d12 Do not show modals to non-multitenant users (#4256) 2025-03-11 11:53:13 -07:00
142 changed files with 3214 additions and 1248 deletions

View File

@@ -31,7 +31,8 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \

View File

@@ -0,0 +1,51 @@
"""add chunk stats table

Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Per-chunk statistics table: one row per (document, chunk) pair,
    # enforced by the unique constraint below. The `id` is a synthetic
    # string primary key.
    op.create_table(
        "chunk_stats",
        sa.Column("id", sa.String(), primary_key=True, index=True),
        sa.Column(
            "document_id",
            sa.String(),
            sa.ForeignKey("document.id"),
            nullable=False,
            index=True,
        ),
        # Position of the chunk within its document.
        sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
        # Nullable: the boost only exists once the content model has scored the chunk.
        sa.Column("information_content_boost", sa.Float(), nullable=True),
        sa.Column(
            "last_modified",
            sa.DateTime(timezone=True),
            nullable=False,
            index=True,
            server_default=sa.func.now(),
        ),
        sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
        sa.UniqueConstraint(
            "document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
        ),
    )

    # Composite index supporting "modified since last sync" style queries
    # over (last_modified, last_synced).
    op.create_index(
        "ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
    )


def downgrade() -> None:
    # Reverse of upgrade(): drop the composite index first, then the table.
    op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
    op.drop_table("chunk_stats")

View File

@@ -5,7 +5,10 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -15,21 +18,45 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get database connection
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
connection = op.get_bind()
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
def downgrade() -> None:

View File

@@ -0,0 +1,36 @@
"""add_default_vision_provider_to_llm_provider

Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Flag marking which LLM provider (if any) is the default for vision
    # tasks. server_default=false so existing rows get an explicit value.
    op.add_column(
        "llm_provider",
        sa.Column(
            "is_default_vision_provider",
            sa.Boolean(),
            nullable=True,
            server_default=sa.false(),
        ),
    )
    # Optional model name to use for vision on this provider.
    op.add_column(
        "llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
    )


def downgrade() -> None:
    # Drop columns in reverse order of creation.
    op.drop_column("llm_provider", "default_vision_model")
    op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -0,0 +1,33 @@
"""add new available tenant table

Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op

# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create new_available_tenant table
    # Pool of pre-provisioned tenants waiting to be assigned to users.
    # `alembic_version` records which migration revision the tenant's
    # schema was provisioned at.
    op.create_table(
        "available_tenant",
        sa.Column("tenant_id", sa.String(), nullable=False),
        sa.Column("alembic_version", sa.String(), nullable=False),
        sa.Column("date_created", sa.DateTime(), nullable=False),
        sa.PrimaryKeyConstraint("tenant_id"),
    )


def downgrade() -> None:
    # Drop new_available_tenant table
    op.drop_table("available_tenant")

View File

@@ -28,11 +28,12 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -62,42 +63,72 @@ async def get_or_provision_tenant(
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
tenant_id = await create_tenant(email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
if not tenant_id:
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
raise HTTPException(
status_code=401, detail="User does not belong to an organization"
status_code=500,
detail="Failed to provision tenant. Please try again later.",
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane
if not DEV_MODE:
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -111,54 +142,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@@ -189,20 +191,74 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
# Logic to rollback tenant provisioning on data plane
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
except Exception as e:
logger.error(f"Failed to rollback tenant provisioning: {e}")
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -399,3 +455,111 @@ def get_tenant_by_domain_from_control_plane(
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
    """
    Get an available pre-provisioned tenant from the NewAvailableTenant table.
    Returns the tenant_id if one is available, None otherwise.
    Uses row-level locking to prevent race conditions when multiple processes
    try to get an available tenant simultaneously.
    """
    # Pre-provisioned tenants only exist in multi-tenant deployments.
    if not MULTI_TENANT:
        return None

    with get_session_with_shared_schema() as db_session:
        try:
            db_session.begin()

            # Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
            available_tenant = (
                db_session.query(AvailableTenant)
                .order_by(AvailableTenant.date_created)
                .with_for_update(skip_locked=True)  # Skip locked rows to avoid blocking
                .first()
            )

            if available_tenant:
                tenant_id = available_tenant.tenant_id
                # Remove the tenant from the available tenants table
                # (claimed atomically in the same transaction that holds the row lock)
                db_session.delete(available_tenant)
                db_session.commit()
                logger.info(f"Using pre-provisioned tenant {tenant_id}")
                return tenant_id
            else:
                # Nothing to claim; release the transaction.
                db_session.rollback()
                return None
        except Exception:
            # Best-effort: any failure is treated as "no tenant available",
            # so the caller falls back to on-demand provisioning.
            logger.exception("Error getting available tenant")
            db_session.rollback()
            return None
async def setup_tenant(tenant_id: str) -> None:
    """
    Set up a tenant with all necessary configurations.
    This is a centralized function that handles all tenant setup logic.

    Runs the tenant's Alembic migrations, configures default API keys, and
    runs the Onyx setup, all with the tenant-id context variable pointed at
    ``tenant_id`` for the duration of the call.

    Raises:
        Exception: re-raises whatever the underlying setup step raised.
    """
    token = None
    try:
        # Point all downstream tenant-aware calls at this tenant's schema.
        token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)

        # Run Alembic migrations in a worker thread so the (synchronous)
        # migration run does not block the event loop.
        await asyncio.to_thread(run_alembic_migrations, tenant_id)

        # Configure the tenant with default settings
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            # Configure default API keys
            configure_default_api_keys(db_session)

            # Cohere is enabled iff the upcoming (FUTURE) search settings
            # use the Cohere embedding provider.
            current_search_settings = (
                db_session.query(SearchSettings)
                .filter_by(status=IndexModelStatus.FUTURE)
                .first()
            )
            cohere_enabled = (
                current_search_settings is not None
                and current_search_settings.provider_type == EmbeddingProvider.COHERE
            )
            setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
    except Exception:
        logger.exception(f"Failed to set up tenant {tenant_id}")
        # Bare `raise` re-raises the active exception without adding this
        # frame as a new raise site (fix: was `raise e`).
        raise
    finally:
        # Always restore the previous tenant context, even on failure.
        if token is not None:
            CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
    tenant_id: str, email: str, referral_source: str | None = None
) -> None:
    """
    Assign a tenant to a user and perform necessary operations.
    Uses transaction handling to ensure atomicity and includes retry logic
    for control plane notifications.

    Raises:
        Exception: if the tenant assignment or milestone creation fails; the
            original error is chained as ``__cause__``.
    """
    # First, add the user to the tenant in a transaction
    try:
        add_users_to_tenant([email], tenant_id)

        # Create milestone record in the same transaction context as the tenant assignment
        with get_session_with_tenant(tenant_id=tenant_id) as db_session:
            create_milestone_and_report(
                user=None,
                distinct_id=tenant_id,
                event_type=MilestoneRecordType.TENANT_CREATED,
                properties={
                    "email": email,
                },
                db_session=db_session,
            )
    except Exception as e:
        logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
        # Chain the original exception so callers and logs keep the root
        # cause (fix: previously raised a bare Exception with no __cause__).
        raise Exception("Failed to assign tenant to user") from e

    # Notify control plane with retry logic
    if not DEV_MODE:
        await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -74,3 +74,21 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
    """Return the Alembic revision currently applied in *tenant_id*'s schema.

    Falls back to the literal string "head" when the schema has no recorded
    revision yet.
    """
    from alembic.runtime.migration import MigrationContext
    from sqlalchemy import text

    engine = get_sqlalchemy_engine()

    with engine.connect() as conn:
        # Point the connection at the tenant's schema so the
        # alembic_version lookup reads from the right table.
        conn.execute(text(f'SET search_path TO "{tenant_id}"'))

        migration_ctx = MigrationContext.configure(conn)
        revision = migration_ctx.get_current_revision()
        return revision or "head"

View File

@@ -67,15 +67,39 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
db_session.add(
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
db_session.rollback()
raise
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:

View File

@@ -3,6 +3,7 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@@ -1,11 +1,14 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@@ -13,11 +16,22 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@@ -31,6 +45,10 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -85,7 +103,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str = INTENT_MODEL_TAG,
tag: str | None = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
@@ -102,7 +120,9 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@@ -112,6 +132,44 @@ def get_local_intent_model(
return _INTENT_MODEL
def get_local_information_content_model(
    model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
    tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
    """Return the (module-level cached) SetFit information-content model.

    Tries a local-cache-only load first; on any failure, falls back to
    downloading a fresh snapshot from the hub. The loaded model is memoized
    in the module global `_INFORMATION_CONTENT_MODEL`, so subsequent calls
    return the same instance.
    """
    global _INFORMATION_CONTENT_MODEL
    if _INFORMATION_CONTENT_MODEL is None:
        try:
            # Calculate where the cache should be, then load from local if available
            logger.notice(
                f"Loading content information model from local cache: {model_name_or_path}"
            )
            # local_files_only=True: fail fast instead of hitting the network.
            local_path = snapshot_download(
                repo_id=model_name_or_path, revision=tag, local_files_only=True
            )
            _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
            logger.notice(
                f"Loaded content information model from local cache: {local_path}"
            )
        except Exception as e:
            logger.warning(f"Failed to load content information model directly: {e}")
            try:
                # Attempt to download the model snapshot
                logger.notice(
                    f"Downloading content information model snapshot for {model_name_or_path}"
                )
                local_path = snapshot_download(
                    repo_id=model_name_or_path, revision=tag, local_files_only=False
                )
                _INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
            except Exception as e:
                # Both local load and download failed; nothing more to try.
                logger.error(
                    f"Failed to load content information model even after attempted snapshot download: {e}"
                )
                raise
    return _INFORMATION_CONTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
@@ -195,6 +253,13 @@ def warm_up_intent_model() -> None:
)
def warm_up_information_content_model() -> None:
    """Run one throwaway inference so the content model is loaded and hot."""
    # TODO: add version if needed
    logger.notice("Warming up Content Model")

    model = get_local_information_content_model()
    model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
@@ -218,6 +283,117 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
    text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
    """
    Assign a score to the segments in question. The model stored in get_local_information_content_model()
    creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
    In the code outside of the model/inference model servers that score will be converted into the actual
    boost factor.

    Returns one ContentClassificationPrediction per input, in input order.
    Special cases: empty strings are scored as non-informative without
    calling the model; inputs longer than the cutoff are scored as
    informative (default class 1 / prob 1.0) without calling the model.
    """

    def _prob_to_score(prob: float) -> float:
        """
        Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
        """
        # Model-specific calibration bounds: probabilities below/above these
        # are clamped to raw 0.0/1.0 before rescaling to the configured range.
        _MIN_BASE_SCORE = 0.25
        _MAX_BASE_SCORE = 0.75
        if prob < _MIN_BASE_SCORE:
            raw_score = 0.0
        elif prob < _MAX_BASE_SCORE:
            raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
        else:
            raw_score = 1.0
        # Linear map of raw 0..1 onto [CLASSIFICATION_MIN, CLASSIFICATION_MAX].
        return (
            INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
            + (
                INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
                - INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
            )
            * raw_score
        )

    _BATCH_SIZE = 32
    content_model = get_local_information_content_model()

    # Process inputs in batches
    all_output_classes: list[int] = []
    all_base_output_probabilities: list[float] = []

    for i in range(0, len(text_inputs), _BATCH_SIZE):
        batch = text_inputs[i : i + _BATCH_SIZE]
        batch_with_prefix = []
        batch_indices = []

        # Pre-allocate results for this batch
        # Defaults (class 1, prob 1.0) stand in for inputs the model is not
        # run on (too-long inputs keep these defaults).
        batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
        batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)

        # Pre-process batch to handle long input exceptions
        for j, text in enumerate(batch):
            if len(text) == 0:
                # if no input, treat as non-informative from the model's perspective
                batch_output_classes[j] = np.array(0)
                batch_probabilities[j] = np.array(0.0)
                logger.warning("Input for Content Information Model is empty")
            elif (
                len(text.split())
                <= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
            ):
                # if input is short, use the model
                batch_with_prefix.append(
                    _INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
                )
                # Remember the original position so results can be scattered back.
                batch_indices.append(j)
            else:
                # if longer than cutoff, treat as informative (stay with default), but issue warning
                logger.warning("Input for Content Information Model too long")

        if batch_with_prefix:  # Only run model if we have valid inputs
            # Get predictions for the batch
            model_output_classes = content_model(batch_with_prefix)
            model_output_probabilities = content_model.predict_proba(batch_with_prefix)

            # Place results in the correct positions
            for idx, batch_idx in enumerate(batch_indices):
                batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
                batch_probabilities[batch_idx] = model_output_probabilities[idx][
                    1
                ].numpy()  # x[1] is prob of the positive class

        all_output_classes.extend([int(x) for x in batch_output_classes])
        all_base_output_probabilities.extend([float(x) for x in batch_probabilities])

    # Convert probabilities to logits, guarding p == 0.0 / 1.0 (which would
    # produce -inf/+inf) with large finite sentinels.
    logits = [
        np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
        for p in all_base_output_probabilities
    ]

    # Temperature-scale the logits, then map back through the sigmoid.
    scaled_logits = [
        logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
        for logit in logits
    ]

    output_probabilities_with_temp = [
        np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
        for scaled_logit in scaled_logits
    ]

    prediction_scores = [
        _prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
    ]

    content_classification_predictions = [
        ContentClassificationPrediction(
            predicted_label=predicted_label, content_boost_factor=output_score
        )
        for predicted_label, output_score in zip(all_output_classes, prediction_scores)
    ]

    return content_classification_predictions
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@@ -362,3 +538,10 @@ async def process_analysis_request(
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
    content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
    """FastAPI endpoint: score each input text for information content.

    Thin wrapper around run_content_classification_inference; returns one
    prediction (label + boost factor) per input string, in input order.
    """
    return run_content_classification_inference(content_classification_requests)

View File

@@ -13,6 +13,7 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
@@ -74,9 +75,15 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
else:
logger.notice("This model server should only run document indexing.")
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
yield

View File

@@ -112,5 +112,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,5 +92,6 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -167,6 +167,16 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@@ -0,0 +1,199 @@
"""
Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import uuid
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Fallback target when the TARGET_AVAILABLE_TENANTS config is not usable
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds);
# also used as the TTL of the Redis locks guarding these tasks
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5  # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10  # 10 minutes
@shared_task(
    name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
    queue=OnyxCeleryQueues.MONITORING,
    ignore_result=True,
    soft_time_limit=JOB_TIMEOUT,
    trail=False,
    bind=True,
)
def check_available_tenants(self: Task) -> None:
    """
    Check if we have enough pre-provisioned tenants available.
    If not, trigger the pre-provisioning of new tenants.

    Guarded by a Redis lock so overlapping beat schedules are skipped;
    a no-op when multi-tenancy is disabled.
    """
    task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
    if not MULTI_TENANT:
        task_logger.info(
            "Multi-tenancy is not enabled, skipping tenant pre-provisioning"
        )
        return
    r = get_redis_client()
    lock_check: RedisLock = r.lock(
        OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
        timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
    )
    # These tasks should never overlap
    if not lock_check.acquire(blocking=False):
        task_logger.info(
            "Skipping check_available_tenants task because it is already running"
        )
        return
    try:
        # Get the current count of available tenants
        with get_session_with_shared_schema() as db_session:
            available_tenants_count = db_session.query(AvailableTenant).count()
        # Get the target number of available tenants.
        # TARGET_AVAILABLE_TENANTS is a plain int loaded from the environment;
        # the previous getattr(TARGET_AVAILABLE_TENANTS, "value", ...) always
        # fell through to the default because ints have no "value" attribute.
        target_available_tenants = (
            TARGET_AVAILABLE_TENANTS
            if TARGET_AVAILABLE_TENANTS is not None
            else DEFAULT_TARGET_AVAILABLE_TENANTS
        )
        # Calculate how many new tenants we need to provision
        tenants_to_provision = max(
            0, target_available_tenants - available_tenants_count
        )
        task_logger.info(
            f"Available tenants: {available_tenants_count}, "
            f"Target: {target_available_tenants}, "
            f"To provision: {tenants_to_provision}"
        )
        # Trigger one low-priority pre-provisioning task per missing tenant.
        # Import hoisted out of the loop: it is loop-invariant.
        from celery import current_app
        for _ in range(tenants_to_provision):
            current_app.send_task(
                OnyxCeleryTask.PRE_PROVISION_TENANT,
                priority=OnyxCeleryPriority.LOW,
            )
    except Exception:
        task_logger.exception("Error in check_available_tenants task")
    finally:
        lock_check.release()
@shared_task(
    name=OnyxCeleryTask.PRE_PROVISION_TENANT,
    ignore_result=True,
    soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
    time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
    queue=OnyxCeleryQueues.MONITORING,
    bind=True,
)
def pre_provision_tenant(self: Task) -> None:
    """
    Pre-provision a new tenant and store it in the AvailableTenant table.
    This function fully sets up the tenant with all necessary configurations,
    so it's ready to be assigned to a user immediately.

    On any failure a compensating rollback of the partially provisioned
    tenant is attempted before the task finishes.
    """
    # The MULTI_TENANT check is now done at the caller level (check_available_tenants)
    # rather than inside this function
    r = get_redis_client()
    # Lock TTL matches the soft time limit so a killed worker cannot hold it forever
    lock_provision: RedisLock = r.lock(
        OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
        timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
    )
    # Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
    if not lock_provision.acquire(blocking=False):
        task_logger.debug(
            "Skipping pre_provision_tenant task because it is already running"
        )
        return
    # Tracked outside the try so the except block can roll back a partial setup
    tenant_id: str | None = None
    try:
        # Generate a new tenant ID
        tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
        task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
        # Create the schema for the new tenant (idempotent: returns False if present)
        schema_created = create_schema_if_not_exists(tenant_id)
        if schema_created:
            task_logger.debug(f"Created schema for tenant: {tenant_id}")
        else:
            task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
        # Set up the tenant with all necessary configurations
        # (setup_tenant is async; run it to completion on a fresh event loop)
        task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
        asyncio.run(setup_tenant(tenant_id))
        task_logger.debug(f"Tenant configuration completed: {tenant_id}")
        # Get the current Alembic version so the tenant's migration state is recorded
        alembic_version = get_current_alembic_version(tenant_id)
        task_logger.debug(
            f"Tenant {tenant_id} using Alembic version: {alembic_version}"
        )
        # Store the pre-provisioned tenant in the database
        task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
        with get_session_with_shared_schema() as db_session:
            # Use a transaction to ensure atomicity
            db_session.begin()
            try:
                new_tenant = AvailableTenant(
                    tenant_id=tenant_id,
                    alembic_version=alembic_version,
                    date_created=datetime.datetime.now(),
                )
                db_session.add(new_tenant)
                db_session.commit()
                task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
            except Exception:
                db_session.rollback()
                task_logger.error(
                    f"Failed to store pre-provisioned tenant: {tenant_id}",
                    exc_info=True,
                )
                # Re-raise so the outer handler triggers tenant rollback
                raise
    except Exception:
        task_logger.error("Error in pre_provision_tenant task", exc_info=True)
        # If we have a tenant_id, attempt to rollback any partially completed provisioning
        if tenant_id:
            task_logger.info(
                f"Rolling back failed tenant provisioning for: {tenant_id}"
            )
            try:
                # Local import keeps the rollback dependency off the hot path
                from ee.onyx.server.tenants.provisioning import (
                    rollback_tenant_provisioning,
                )
                asyncio.run(rollback_tenant_provisioning(tenant_id))
            except Exception:
                task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
    finally:
        lock_provision.release()

View File

@@ -563,6 +563,7 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.

View File

@@ -28,6 +28,7 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -52,6 +53,9 @@ from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@@ -154,14 +158,12 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
for section in cleaned_doc.sections:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
if section.link is not None:
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
section.text = section.text.replace("\x00", "")
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
@@ -349,6 +351,8 @@ def _run_indexing(
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
@@ -357,6 +361,7 @@ def _run_indexing(
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
@@ -479,7 +484,11 @@ def _run_indexing(
doc_size = 0
for section in doc.sections:
doc_size += len(section.text)
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(

View File

@@ -8,6 +8,9 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
#####
# App Configs
@@ -643,3 +646,24 @@ MOCK_LLM_RESPONSE = (
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
)
# The user prompt for image summarization - the image filename will be automatically prepended
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_USER_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
)
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)

View File

@@ -322,6 +322,8 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -384,6 +386,7 @@ class OnyxCeleryTask:
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -400,6 +403,9 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -132,3 +132,10 @@ if _LITELLM_EXTRA_BODY_RAW:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass
# Whether and how to lower scores for short chunks w/o relevant context
# Evaluated via custom ML model
USE_INFORMATION_CONTENT_CLASSIFICATION = (
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
)

View File

@@ -4,6 +4,7 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
from typing import cast
import requests
from pyairtable import Api as AirtableApi
@@ -16,7 +17,8 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
@@ -267,7 +269,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[Section], dict[str, str | list[str]]]:
) -> tuple[list[TextSection], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -305,7 +307,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
Section(
TextSection(
link=link,
text=(
f"{field_name}:\n"
@@ -340,7 +342,7 @@ class AirtableConnector(LoadConnector):
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[Section] = []
sections: list[TextSection] = []
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
@@ -384,7 +386,7 @@ class AirtableConnector(LoadConnector):
return Document(
id=f"airtable__{record_id}",
sections=sections,
sections=(cast(list[TextSection | ImageSection], sections)),
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,

View File

@@ -10,7 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,7 +82,7 @@ class AsanaConnector(LoadConnector, PollConnector):
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[Section(link=task.link, text=task.text)],
sections=[TextSection(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -221,7 +221,7 @@ def _get_forums(
def _translate_forum_to_doc(af: AxeroForum) -> Document:
doc = Document(
id=af.doc_id,
sections=[Section(link=af.link, text=reply) for reply in af.responses],
sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
source=DocumentSource.AXERO,
semantic_identifier=af.title,
doc_updated_at=af.last_update,
@@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document:
doc = Document(
id="AXERO_" + str(content["ContentID"]),
sections=[Section(link=content["ContentURL"], text=page_text)],
sections=[TextSection(link=content["ContentURL"], text=page_text)],
source=DocumentSource.AXERO,
semantic_identifier=content["ContentTitle"],
doc_updated_at=time_str_to_utc(content["DateUpdated"]),

View File

@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -208,7 +208,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[Section(link=link, text=text)],
sections=[TextSection(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
@@ -341,7 +341,14 @@ if __name__ == "__main__":
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
print(f" - Text: {section.text[:100]}...")
if isinstance(section, TextSection) and section.text is not None:
print(f" - Text: {section.text[:100]}...")
elif (
hasattr(section, "image_file_name") and section.image_file_name
):
print(f" - Image: {section.image_file_name}")
else:
print("Error: Unknown section type")
print("---")
break

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -81,7 +81,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="book__" + str(book.get("id")),
sections=[Section(link=url, text=text)],
sections=[TextSection(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Book: " + title,
title=title,
@@ -110,7 +110,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="chapter__" + str(chapter.get("id")),
sections=[Section(link=url, text=text)],
sections=[TextSection(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Chapter: " + title,
title=title,
@@ -134,7 +134,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="shelf:" + str(shelf.get("id")),
sections=[Section(link=url, text=text)],
sections=[TextSection(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Shelf: " + title,
title=title,
@@ -167,7 +167,7 @@ class BookstackConnector(LoadConnector, PollConnector):
time.sleep(0.1)
return Document(
id="page:" + page_id,
sections=[Section(link=url, text=text)],
sections=[TextSection(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Page: " + str(title),
title=str(title),

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.retry_wrapper import retry_builder
@@ -62,11 +62,11 @@ class ClickupConnector(LoadConnector, PollConnector):
return response.json()
def _get_task_comments(self, task_id: str) -> list[Section]:
def _get_task_comments(self, task_id: str) -> list[TextSection]:
url_endpoint = f"/task/{task_id}/comment"
response = self._make_request(url_endpoint)
comments = [
Section(
TextSection(
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
text=comment_dict["comment_text"],
)
@@ -133,7 +133,7 @@ class ClickupConnector(LoadConnector, PollConnector):
],
title=task["name"],
sections=[
Section(
TextSection(
link=task["url"],
text=(
task["markdown_description"]

View File

@@ -33,9 +33,9 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -85,7 +85,6 @@ class ConfluenceConnector(
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
):
def __init__(
self,
@@ -116,9 +115,6 @@ class ConfluenceConnector(
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
@@ -245,12 +241,16 @@ class ConfluenceConnector(
)
# Create the main section for the page content
sections = [Section(text=page_content, link=page_url)]
sections: list[TextSection | ImageSection] = [
TextSection(text=page_content, link=page_url)
]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
sections.append(
TextSection(text=comment_text, link=f"{page_url}#comments")
)
# Process attachments
if "children" in page and "attachment" in page["children"]:
@@ -263,21 +263,27 @@ class ConfluenceConnector(
result = process_attachment(
self.confluence_client,
attachment,
page_title,
self.image_analysis_llm,
page_id,
)
if result.text:
if result and result.text:
# Create a section for the attachment text
attachment_section = Section(
attachment_section = TextSection(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
)
sections.append(attachment_section)
elif result and result.file_name:
# Create an ImageSection for image attachments
image_section = ImageSection(
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(attachment_section)
elif result.error:
sections.append(image_section)
else:
logger.warning(
f"Error processing attachment '{attachment.get('title')}': {result.error}"
f"Error processing attachment '{attachment.get('title')}':",
f"{result.error if result else 'Unknown error'}",
)
# Extract metadata
@@ -348,7 +354,7 @@ class ConfluenceConnector(
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
page.get("body", {}).get("storage", {}).get("value", "")
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
@@ -356,7 +362,7 @@ class ConfluenceConnector(
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
attachment,
):
continue
@@ -366,23 +372,26 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_context=confluence_xml,
llm=self.image_analysis_llm,
page_id=page["id"],
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
Section(
TextSection(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
@@ -462,7 +471,7 @@ class ConfluenceConnector(
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
attachment,
):
continue

View File

@@ -1,4 +1,3 @@
import io
import json
import time
from collections.abc import Callable
@@ -19,17 +18,11 @@ from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -808,65 +801,6 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],

View File

@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
@@ -35,7 +36,6 @@ from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -53,17 +53,16 @@ class TokenResponse(BaseModel):
def validate_attachment_filetype(
attachment: dict[str, Any], llm: LLM | None = None
attachment: dict[str, Any],
) -> bool:
"""
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return llm is not None and is_valid_image_type(media_type)
return is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
@@ -84,55 +83,103 @@ class AttachmentProcessingResult(BaseModel):
error: str | None = None
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
def _make_attachment_link(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
download_link = ""
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
return None
return resp.content
else:
download_link = confluence_client.url + attachment["_links"]["download"]
return download_link
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM | None,
parent_content_id: str | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
or if it's an image, stores it for later analysis. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment, llm):
if not validate_attachment_filetype(attachment):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
# Download the attachment
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
attachment_link = _make_attachment_link(
confluence_client, attachment, parent_content_id
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to download attachment"
text=None, file_name=None, error="Failed to make attachment link"
)
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
logger.info(
f"Downloading attachment: "
f"title={attachment['title']} "
f"length={attachment_size} "
f"link={attachment_link}"
)
# Download the attachment
resp: requests.Response = confluence_client._session.get(attachment_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
)
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
if media_type.startswith("image/"):
return _process_image_attachment(
confluence_client, attachment, page_context, llm, raw_bytes, media_type
confluence_client, attachment, raw_bytes, media_type
)
# Process document attachments
@@ -165,12 +212,10 @@ def process_attachment(
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it and generating a summary."""
"""Process an image attachment by saving it without generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
@@ -180,15 +225,14 @@ def _process_image_attachment(
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
except Exception as e:
msg = f"Image summarization failed for {attachment['title']}: {e}"
msg = f"Image storage failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
@@ -249,16 +293,15 @@ def _process_text_attachment(
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM | None,
page_id: str,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts or summarizes content
2. Extracts content or stores image for later processing
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment["metadata"]["mediaType"]
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Quick check for unsupported types:
if media_type.startswith("video/") or media_type == "application/gliffy+json":
logger.warning(
@@ -266,7 +309,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_context, llm)
result = process_attachment(confluence_client, attachment, page_id)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"
@@ -479,6 +522,10 @@ def attachment_to_file_record(
download_link, absolute=True, not_json_response=True
)
file_type = attachment.get("metadata", {}).get(
"mediaType", "application/octet-stream"
)
# Save image to file store
file_name = f"confluence_attachment_{attachment['id']}"
lobj_oid = create_populate_lobj(BytesIO(image_data), db_session)
@@ -486,7 +533,7 @@ def attachment_to_file_record(
file_name=file_name,
display_name=attachment["title"],
file_origin=FileOrigin.OTHER,
file_type=attachment["metadata"]["mediaType"],
file_type=file_type,
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,

View File

@@ -4,6 +4,7 @@ from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from discord import Client
from discord.channel import TextChannel
@@ -20,7 +21,8 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -32,7 +34,7 @@ _SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[Section],
sections: list[TextSection],
) -> Document:
"""
Convert a discord message to a document
@@ -78,7 +80,7 @@ def _convert_message_to_document(
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
title=title,
sections=sections,
sections=(cast(list[TextSection | ImageSection], sections)),
metadata=metadata,
)
@@ -123,8 +125,8 @@ async def _fetch_documents_from_channel(
if channel_message.type != MessageType.default:
continue
sections: list[Section] = [
Section(
sections: list[TextSection] = [
TextSection(
text=channel_message.content,
link=channel_message.jump_url,
)
@@ -142,7 +144,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
Section(
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)
@@ -160,7 +162,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
Section(
TextSection(
text=thread_message.content,
link=thread_message.jump_url,
)

View File

@@ -3,6 +3,7 @@ import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
@@ -20,7 +21,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -112,7 +114,7 @@ class DiscourseConnector(PollConnector):
responders.append(BasicExpertInfo(display_name=responder_name))
sections.append(
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
@@ -129,7 +131,7 @@ class DiscourseConnector(PollConnector):
doc = Document(
id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]),
sections=sections,
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.DISCOURSE,
semantic_identifier=topic["title"],
doc_updated_at=time_str_to_utc(topic["last_posted_at"]),

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
@@ -158,7 +158,7 @@ class Document360Connector(LoadConnector, PollConnector):
document = Document(
id=article_details["id"],
sections=[Section(link=doc_link, text=doc_text)],
sections=[TextSection(link=doc_link, text=doc_text)],
source=DocumentSource.DOCUMENT360,
semantic_identifier=article_details["title"],
doc_updated_at=updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -108,7 +108,7 @@ class DropboxConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"doc:{entry.id}",
sections=[Section(link=link, text=text)],
sections=[TextSection(link=link, text=text)],
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
doc_updated_at=modified_time,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -111,7 +111,7 @@ def _process_egnyte_file(
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[Section(text=file_content_raw.strip(), link=web_url)],
sections=[TextSection(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,

View File

@@ -16,8 +16,8 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -26,7 +26,6 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -59,32 +58,44 @@ def _read_files_and_metadata(
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
idx: int = 0,
) -> tuple[Section, str | None]:
) -> tuple[ImageSection, str | None]:
"""
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Creates an ImageSection for an image file or embedded image.
Stores the image in PGFileStore but does not generate a summary.
Args:
image_data: Raw image bytes
db_session: Database session
parent_file_name: Name of the parent file (for embedded images)
display_name: Display name for the image
idx: Index for embedded images
Returns:
tuple: (Section object, file_name in PGFileStore or None if storage failed)
Tuple of (ImageSection, stored_file_name or None)
"""
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Create a unique identifier for the image
file_name = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
return section, stored_file_name
except Exception as e:
logger.error(f"Failed to store image {display_name}: {e}")
raise e
def _process_file(
@@ -93,12 +104,16 @@ def _process_file(
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
) -> list[Document]:
"""
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
Process a file and return a list of Documents.
For images, creates ImageSection objects without summarization.
For documents with embedded images, extracts and stores the images.
"""
if metadata is None:
metadata = {}
# Get file extension and determine file type
extension = get_file_ext(file_name)
# Fetch the DB record so we know the ID for internal URL
@@ -114,8 +129,6 @@ def _process_file(
return []
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
# Timestamps
@@ -158,6 +171,7 @@ def _process_file(
"title",
"connector_type",
"pdf_password",
"mime_type",
]
}
@@ -170,33 +184,45 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
if extension in LoadConnector.IMAGE_EXTENSIONS:
# Read the image data
image_data = file.read()
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
if not image_data:
logger.warning(f"Empty image file: {file_name}")
return []
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
# Create an ImageSection for the image
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=title,
)
return [
Document(
id=doc_id,
sections=[section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
except Exception as e:
logger.error(f"Failed to process image file {file_name}: {e}")
return []
# 2) Otherwise: text-based approach. Possibly with embedded images.
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
@@ -204,24 +230,29 @@ def _process_file(
)
# Build sections: first the text as a single Section
sections = []
sections: list[TextSection | ImageSection] = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(Section(link=link_in_meta, text=text_content.strip()))
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
# and create a section with the image reference
try:
image_section, _ = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=f"{title} - image {idx}",
idx=idx,
)
sections.append(image_section)
except Exception as e:
logger.warning(
f"Failed to process embedded image {idx} in {file_name}: {e}"
)
return [
Document(
id=doc_id,
@@ -237,10 +268,10 @@ def _process_file(
]
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
class LocalFileConnector(LoadConnector):
"""
Connector that reads files from Postgres and yields Documents, including
optional embedded image extraction.
embedded image extraction without summarization.
"""
def __init__(
@@ -252,9 +283,6 @@ class LocalFileConnector(LoadConnector, VisionEnabledConnector):
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -286,7 +314,6 @@ class LocalFileConnector(LoadConnector, VisionEnabledConnector):
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
)
documents.extend(new_docs)

View File

@@ -1,6 +1,7 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
import requests
@@ -14,7 +15,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -45,7 +47,7 @@ _FIREFLIES_API_QUERY = """
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[Section] = []
sections: List[TextSection] = []
current_speaker_name = None
current_link = ""
current_text = ""
@@ -57,7 +59,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:
sections.append(
Section(
TextSection(
link=current_link,
text=current_text.strip(),
)
@@ -71,7 +73,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
sections.append(
Section(
TextSection(
link=current_link,
text=current_text.strip(),
)
@@ -94,7 +96,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
return Document(
id=fireflies_id,
sections=sections,
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},

View File

@@ -14,7 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -133,7 +133,7 @@ def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
Section(
TextSection(
link=link,
text=text,
)

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -183,7 +183,7 @@ def _convert_page_to_document(
return Document(
id=f"gitbook-{space_id}-{page_id}",
sections=[
Section(
TextSection(
link=page.get("urls", {}).get("app", ""),
text=_extract_text_from_document(page_content),
)

View File

@@ -27,7 +27,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -87,7 +87,9 @@ def _batch_github_objects(
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -109,7 +111,7 @@ def _fetch_issue_comments(issue: Issue) -> str:
def _convert_issue_to_document(issue: Issue) -> Document:
return Document(
id=issue.html_url,
sections=[Section(link=issue.html_url, text=issue.body or "")],
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware

View File

@@ -21,7 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -56,7 +56,7 @@ def get_author(author: Any) -> BasicExpertInfo:
def _convert_merge_request_to_document(mr: Any) -> Document:
doc = Document(
id=mr.web_url,
sections=[Section(link=mr.web_url, text=mr.description or "")],
sections=[TextSection(link=mr.web_url, text=mr.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=mr.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -72,7 +72,7 @@ def _convert_merge_request_to_document(mr: Any) -> Document:
def _convert_issue_to_document(issue: Any) -> Document:
doc = Document(
id=issue.web_url,
sections=[Section(link=issue.web_url, text=issue.description or "")],
sections=[TextSection(link=issue.web_url, text=issue.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -99,7 +99,7 @@ def _convert_code_to_document(
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
doc = Document(
id=file["id"],
sections=[Section(link=file_url, text=file_content)],
sections=[TextSection(link=file_url, text=file_content)],
source=DocumentSource.GITLAB,
semantic_identifier=file["name"],
doc_updated_at=datetime.now().replace(

View File

@@ -1,5 +1,6 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
@@ -28,8 +29,9 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -115,7 +117,7 @@ def _get_message_body(payload: dict[str, Any]) -> str:
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
@@ -142,7 +144,7 @@ def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]
message_body_text: str = _get_message_body(payload)
return Section(link=link, text=message_body_text + message_data), metadata
return TextSection(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
@@ -192,7 +194,7 @@ def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=sections,
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -243,7 +243,7 @@ class GongConnector(LoadConnector, PollConnector):
Document(
id=call_id,
sections=[
Section(link=call_metadata["url"], text=transcript_text)
TextSection(link=call_metadata["url"], text=transcript_text)
],
source=DocumentSource.GONG,
# Should not ever be Untitled as a call cannot be made without a Title

View File

@@ -43,9 +43,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -68,7 +66,6 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -77,7 +74,6 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
@@ -116,9 +112,7 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = False,
@@ -151,9 +145,6 @@ class GoogleDriveConnector(
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
and not include_my_drives
@@ -539,7 +530,6 @@ class GoogleDriveConnector(
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
)
# Fetch files in batches

View File

@@ -1,40 +1,50 @@
import io
from datetime import datetime
from datetime import timezone
from tempfile import NamedTemporaryFile
from typing import cast
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from googleapiclient.http import MediaIoBaseDownload # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from onyx.connectors.google_drive.models import GDriveMimeType
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Mapping of Google Drive mime types to export formats
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
# Define Google MIME types mapping
GOOGLE_MIME_TYPES = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
@@ -66,259 +76,137 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
image_analysis_llm: LLM | None = None,
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
mime_type = file["mimeType"]
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
link = file.get("webViewLink", "")
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
# ---------------------------
# Word, PowerPoint, PDF files
elif mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
if get_unstructured_api_key():
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
if mime_type == GDriveMimeType.WORD_DOC.value:
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
section, _ = store_image_and_create_section(
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
]
# Process embedded images in the PDF
try:
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=embedded_id,
file_name=f"{file_id}_img_{idx}",
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
sections.append(section)
return sections
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
elif mime_type == GDriveMimeType.PDF.value:
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
except Exception as e:
logger.exception(f"Error extracting sections from file: {e}")
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
logger.error(f"Error processing file {file_name}: {e}")
return []
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# skip shortcuts or folders
@@ -327,44 +215,50 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[Section] = []
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
except Exception as e:
logger.warning(
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
f"Error in advanced parsing: {e}. Falling back to basic extraction."
)
# If not a doc, or if we failed above, do our 'basic' approach
# If we don't have sections yet, use the basic extraction method
if not sections:
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
sections = _extract_sections_basic(file, drive_service)
# If we still don't have any sections, skip this file
if not sections:
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
# Create the document
return Document(
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
additional_info=file.get("id"),
semantic_identifier=file.get("name", ""),
metadata={
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
},
doc_updated_at=datetime.fromisoformat(
file.get("modifiedTime", "").replace("Z", "+00:00")
),
)
except Exception as e:
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return None
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
class CurrentHeading(BaseModel):
@@ -37,7 +37,7 @@ def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[Section]:
) -> list[TextSection]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
@@ -45,7 +45,7 @@ def get_document_sections(
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[Section] = []
sections: list[TextSection] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
@@ -70,7 +70,7 @@ def get_document_sections(
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
Section(
TextSection(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
@@ -96,7 +96,7 @@ def get_document_sections(
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
Section(
TextSection(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.engine import get_sqlalchemy_engine
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
@@ -118,7 +118,7 @@ class GoogleSitesConnector(LoadConnector):
source=DocumentSource.GOOGLE_SITES,
semantic_identifier=title,
sections=[
Section(
TextSection(
link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
if path
else "",

View File

@@ -15,7 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -120,7 +120,7 @@ class GuruConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=card["id"],
sections=[Section(link=link, text=content_text)],
sections=[TextSection(link=link, text=content_text)],
source=DocumentSource.GURU,
semantic_identifier=title,
doc_updated_at=latest_time,

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
HUBSPOT_BASE_URL = "https://app.hubspot.com/contacts/"
@@ -108,7 +108,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=ticket.id,
sections=[Section(link=link, text=content_text)],
sections=[TextSection(link=link, text=content_text)],
source=DocumentSource.HUBSPOT,
semantic_identifier=title,
# Is already in tzutc, just replacing the timezone format

View File

@@ -24,6 +24,8 @@ CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpo
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -21,7 +21,8 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -237,22 +238,30 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
documents: list[Document] = []
for edge in edges:
node = edge["node"]
# Create sections for description and comments
sections = [
TextSection(
link=node["url"],
text=node["description"] or "",
)
]
# Add comment sections
for comment in node["comments"]["nodes"]:
sections.append(
TextSection(
link=node["url"],
text=comment["body"] or "",
)
)
# Cast the sections list to the expected type
typed_sections = cast(list[TextSection | ImageSection], sections)
documents.append(
Document(
id=node["id"],
sections=[
Section(
link=node["url"],
text=node["description"] or "",
)
]
+ [
Section(
link=node["url"],
text=comment["body"] or "",
)
for comment in node["comments"]["nodes"]
],
sections=typed_sections,
source=DocumentSource.LINEAR,
semantic_identifier=f"[{node['identifier']}] {node['title']}",
title=node["title"],

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
from onyx.utils.logger import setup_logger
@@ -162,7 +162,7 @@ class LoopioConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=str(entry["id"]),
sections=[Section(link=link, text=content_text)],
sections=[TextSection(link=link, text=content_text)],
source=DocumentSource.LOOPIO,
semantic_identifier=questions[0],
doc_updated_at=latest_time,

View File

@@ -6,6 +6,7 @@ import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
@@ -20,7 +21,8 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -60,14 +62,14 @@ def get_doc_from_page(
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
sections = [
Section(
TextSection(
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
text=section.title + section.content,
)
for section in sections_extracted.sections
]
sections.append(
Section(
TextSection(
link=page.full_url(),
text=sections_extracted.header,
)
@@ -79,7 +81,7 @@ def get_doc_from_page(
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
page.latest_revision.timestamp
),
sections=sections,
sections=cast(list[TextSection | ImageSection], sections),
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",

View File

@@ -28,9 +28,25 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
"""Base section class with common attributes"""
link: str | None = None
text: str | None = None
image_file_name: str | None = None
class TextSection(Section):
"""Section containing text content"""
text: str
link: str | None = None
image_file_name: str | None = None
class ImageSection(Section):
"""Section containing an image reference"""
image_file_name: str
link: str | None = None
class BasicExpertInfo(BaseModel):
@@ -100,7 +116,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[Section]
sections: list[TextSection | ImageSection]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, str | list[str]]
@@ -150,18 +166,10 @@ class DocumentBase(BaseModel):
class Document(DocumentBase):
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
"""Used for Onyx ingestion api, the ID is required"""
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
id: str
source: DocumentSource
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
@@ -185,6 +193,32 @@ class Document(DocumentBase):
)
class IndexingDocument(Document):
"""Document with processed sections for indexing"""
processed_sections: list[Section] = []
def get_total_char_length(self) -> int:
"""Get the total character length of the document including processed sections"""
title_len = len(self.title or self.semantic_identifier)
# Use processed_sections if available, otherwise fall back to original sections
if self.processed_sections:
section_len = sum(
len(section.text) if section.text is not None else 0
for section in self.processed_sections
)
else:
section_len = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in self.sections
)
return title_len + section_len
class SlimDocument(BaseModel):
id: str
perm_sync_data: Any | None = None

View File

@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -475,7 +475,7 @@ class NotionConnector(LoadConnector, PollConnector):
Document(
id=page.id,
sections=[
Section(
TextSection(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)

View File

@@ -23,8 +23,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
@@ -145,7 +145,7 @@ def fetch_jira_issues_batch(
yield Document(
id=page_url,
sections=[Section(link=page_url, text=ticket_content)],
sections=[TextSection(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
@@ -110,7 +110,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=feature["id"],
sections=[
Section(
TextSection(
link=feature["links"]["html"],
text=self._parse_description_html(feature["description"]),
)
@@ -133,7 +133,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=component["id"],
sections=[
Section(
TextSection(
link=component["links"]["html"],
text=self._parse_description_html(component["description"]),
)
@@ -159,7 +159,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=product["id"],
sections=[
Section(
TextSection(
link=product["links"]["html"],
text=self._parse_description_html(product["description"]),
)
@@ -189,7 +189,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=objective["id"],
sections=[
Section(
TextSection(
link=objective["links"]["html"],
text=self._parse_description_html(objective["description"]),
)

View File

@@ -13,6 +13,7 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
@@ -48,10 +49,12 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
domain = "test" if credentials.get("is_sandbox") else None
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
domain=domain,
)
return None
@@ -216,7 +219,8 @@ if __name__ == "__main__":
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
text_count += len(section.text)
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")

View File

@@ -1,10 +1,12 @@
import re
from typing import cast
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
@@ -114,8 +116,8 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> TextSection:
return TextSection(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
@@ -175,7 +177,7 @@ def convert_sf_object_to_doc(
doc = Document(
id=onyx_salesforce_id,
sections=sections,
sections=cast(list[TextSection | ImageSection], sections),
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -55,7 +55,7 @@ def _convert_driveitem_to_document(
doc = Document(
id=driveitem.id,
sections=[Section(link=driveitem.web_url, text=file_text)],
sections=[TextSection(link=driveitem.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem.name,
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),

View File

@@ -19,8 +19,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -212,7 +212,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
doc_batch.append(
Document(
id=post_id, # can't be url as this changes with the post title
sections=[Section(link=page_url, text=content_text)],
sections=[TextSection(link=page_url, text=content_text)],
source=DocumentSource.SLAB,
semantic_identifier=post["title"],
metadata={},

View File

@@ -34,8 +34,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
@@ -211,7 +211,7 @@ def thread_to_doc(
return Document(
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
Section(
TextSection(
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
@@ -220,7 +220,6 @@ def thread_to_doc(
source=DocumentSource.SLACK,
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
title="", # slack docs don't really have a "title"
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
)

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -165,7 +165,7 @@ def _convert_thread_to_document(
doc = Document(
id=post_id,
sections=[Section(link=web_url, text=thread_text)],
sections=[TextSection(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
title="", # teams threads don't really have a "title"

View File

@@ -1,45 +0,0 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
class VisionEnabledConnector:
"""
Mixin for connectors that need vision capabilities.
This mixin provides a standard way to initialize a vision-capable LLM
for image analysis during indexing.
Usage:
class MyConnector(LoadConnector, VisionEnabledConnector):
def __init__(self, ...):
super().__init__(...)
self.initialize_vision_llm()
"""
def initialize_vision_llm(self) -> None:
"""
Initialize a vision-capable LLM if enabled by configuration.
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:
logger.warning(
"No LLM with vision found; image summarization will be disabled"
)
except Exception as e:
logger.warning(
f"Failed to initialize vision LLM due to an error: {str(e)}. "
"Image summarization will be disabled."
)
self.image_analysis_llm = None

View File

@@ -32,7 +32,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
@@ -341,7 +341,7 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=initial_url,
sections=[Section(link=initial_url, text=page_text)],
sections=[TextSection(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
@@ -443,7 +443,7 @@ class WebConnector(LoadConnector):
Document(
id=initial_url,
sections=[
Section(link=initial_url, text=parsed_html.cleaned_text)
TextSection(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or initial_url,

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -104,7 +104,7 @@ def scrape_page_posts(
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[Section(link=url, text=post_text)],
sections=[TextSection(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,

View File

@@ -1,5 +1,6 @@
from collections.abc import Iterator
from typing import Any
from typing import cast
import requests
@@ -17,8 +18,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -168,8 +169,8 @@ def _article_to_document(
return new_author_mapping, Document(
id=f"article:{article['id']}",
sections=[
Section(
link=article.get("html_url"),
TextSection(
link=cast(str, article.get("html_url")),
text=parse_html_page_basic(article["body"]),
)
],
@@ -268,7 +269,7 @@ def _ticket_to_document(
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[Section(link=ticket_display_url, text=full_text)],
sections=[TextSection(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.connectors.zulip.schemas import GetMessagesResponse
from onyx.connectors.zulip.schemas import Message
from onyx.connectors.zulip.utils import build_search_narrow
@@ -161,7 +161,7 @@ class ZulipConnector(LoadConnector, PollConnector):
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
Section(
TextSection(
link=self._message_to_narrow_link(message),
text=text,
)

View File

@@ -10,6 +10,7 @@ from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@@ -31,7 +32,6 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import RerankingModel
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall

63
backend/onyx/db/chunk.py Normal file
View File

@@ -0,0 +1,63 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from onyx.db.models import ChunkStats
from onyx.indexing.models import UpdatableChunkData
def update_chunk_boost_components__no_commit(
chunk_data: list[UpdatableChunkData],
db_session: Session,
) -> None:
"""Updates the chunk_boost_components for chunks in the database.
Args:
chunk_data: List of dicts containing chunk_id, document_id, and boost_score
db_session: SQLAlchemy database session
"""
if not chunk_data:
return
for data in chunk_data:
chunk_in_doc_id = int(data.chunk_id)
if chunk_in_doc_id < 0:
raise ValueError(f"Chunk ID is empty for chunk {data}")
chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}"
chunk_stats = (
db_session.query(ChunkStats)
.filter(
ChunkStats.id == chunk_document_id,
)
.first()
)
score = data.boost_score
if chunk_stats:
chunk_stats.information_content_boost = score
chunk_stats.last_modified = datetime.now(timezone.utc)
db_session.add(chunk_stats)
else:
# do not save new chunks with a neutral boost score
if score == 1.0:
continue
# Create new record
chunk_stats = ChunkStats(
document_id=data.document_id,
chunk_in_doc_id=chunk_in_doc_id,
information_content_boost=score,
)
db_session.add(chunk_stats)
def delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This deletes just chunk stats in postgres."""
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
db_session.execute(stmt)

View File

@@ -23,6 +23,7 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
@@ -562,6 +563,18 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@@ -13,6 +13,7 @@ from onyx.db.models import SearchSettings
from onyx.db.models import Tool as ToolModel
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
@@ -187,6 +188,17 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
return FullLLMProvider.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
@@ -246,3 +258,39 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default.is_default_provider = True
db_session.commit()
def update_default_vision_provider(
provider_id: int, vision_model: str | None, db_session: Session
) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
# Validate that the specified vision model supports image input
model_to_validate = vision_model or new_default.default_model_name
if model_to_validate:
if not model_supports_image_input(model_to_validate, new_default.provider):
raise ValueError(
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
)
else:
raise ValueError(
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
)
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_vision_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_vision_provider = True
new_default.default_vision_model = vision_model
db_session.commit()

View File

@@ -591,6 +591,55 @@ class Document(Base):
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
default=lambda context: (
f"{context.get_current_parameters()['document_id']}"
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
),
index=True,
)
# Reference to parent document
document_id: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
)
chunk_in_doc_id: Mapped[int] = mapped_column(
Integer,
nullable=False,
)
information_content_boost: Mapped[float | None] = mapped_column(
Float, nullable=True
)
last_modified: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
__table_args__ = (
Index(
"ix_chunk_sync_status",
last_modified,
last_synced,
),
UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
class Tag(Base):
__tablename__ = "tag"
@@ -1489,6 +1538,8 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
@@ -2309,6 +2360,17 @@ class UserTenantMapping(Base):
return value.lower() if value else value
class AvailableTenant(Base):
__tablename__ = "available_tenant"
"""
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
"""
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):
__tablename__ = "tenant_anonymous_user_path"

View File

@@ -67,6 +67,9 @@ def read_lobj(
use_tempfile: bool = False,
) -> IO:
pg_conn = get_pg_conn_from_session(db_session)
# Ensure we're using binary mode by default for large objects
if mode is None:
mode = "rb"
large_object = (
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
)
@@ -81,6 +84,7 @@ def read_lobj(
temp_file.seek(0)
return temp_file
else:
# Ensure we're getting raw bytes without text decoding
return BytesIO(large_object.read())

View File

@@ -101,6 +101,7 @@ class VespaDocumentFields:
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
aggregated_chunk_boost_factor: float | None = None
@dataclass

View File

@@ -80,6 +80,11 @@ schema DANSWER_CHUNK_NAME {
indexing: summary | attribute
rank: filter
}
# Field to indicate whether a short chunk is a low content chunk
field aggregated_chunk_boost_factor type float {
indexing: attribute
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute
@@ -142,6 +147,11 @@ schema DANSWER_CHUNK_NAME {
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
}
function inline aggregated_chunk_boost() {
# Aggregated boost factor, currently only used for information content classification
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
}
# Document score decays from 1 to 0.75 as age of last updated time increases
function inline recency_bias() {
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@@ -199,6 +209,8 @@ schema DANSWER_CHUNK_NAME {
* document_boost
# Decay factor based on time document was last updated
* recency_bias
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
rerank-count: 1000
}
@@ -210,6 +222,7 @@ schema DANSWER_CHUNK_NAME {
closeness(field, embeddings)
document_boost
recency_bias
aggregated_chunk_boost
closest(embeddings)
}
}

View File

@@ -22,6 +22,7 @@ from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
@@ -201,6 +202,7 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}
if multitenant:

View File

@@ -72,6 +72,7 @@ METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
@@ -97,6 +98,7 @@ YQL_BASE = (
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "
f"{PRIMARY_OWNERS}, "

View File

@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from PIL import Image
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
image_data: The raw image bytes
context_name: Name or title of the image for context
system_prompt: System prompt to use for the LLM
user_prompt_template: Template for the user prompt, should contain {title} placeholder
user_prompt_template: User prompt to use (without title)
Returns:
The image summary text, or None if summarization failed or is disabled
@@ -70,7 +70,10 @@ def summarize_image_with_error_handling(
if llm is None:
return None
user_prompt = user_prompt_template.format(title=context_name)
# Prepend the image filename to the user prompt
user_prompt = (
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
)
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)

View File

@@ -2,12 +2,9 @@ from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import Section
from onyx.connectors.models import ImageSection
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -18,12 +15,12 @@ def store_image_and_create_section(
image_data: bytes,
file_name: str,
display_name: str,
media_type: str = "image/unknown",
llm: LLM | None = None,
link: str | None = None,
media_type: str = "application/octet-stream",
file_origin: FileOrigin = FileOrigin.OTHER,
) -> Tuple[Section, str | None]:
) -> Tuple[ImageSection, str | None]:
"""
Stores an image in PGFileStore and creates a Section object with optional summarization.
Stores an image in PGFileStore and creates an ImageSection object without summarization.
Args:
db_session: Database session
@@ -31,12 +28,11 @@ def store_image_and_create_section(
file_name: Base identifier for the file
display_name: Human-readable name for the image
media_type: MIME type of the image
llm: Optional LLM with vision capabilities for summarization
file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
Returns:
Tuple containing:
- Section object with image reference and optional summary text
- ImageSection object with image reference
- The file_name in PGFileStore or None if storage failed
"""
# Storage logic
@@ -53,18 +49,10 @@ def store_image_and_create_section(
stored_file_name = pgfilestore.file_name
except Exception as e:
logger.error(f"Failed to store image: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return Section(text=""), None
# Summarization logic
summary_text = ""
if llm:
summary_text = (
summarize_image_with_error_handling(llm, image_data, display_name) or ""
)
raise e
# Create an ImageSection with empty text (will be filled by LLM later in the pipeline)
return (
Section(text=summary_text, image_file_name=stored_file_name),
ImageSection(image_file_name=stored_file_name, link=link),
stored_file_name,
)

View File

@@ -9,7 +9,8 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from onyx.connectors.models import Document
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.natural_language_processing.utils import BaseTokenizer
@@ -64,7 +65,7 @@ def _get_metadata_suffix_for_document_index(
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
"""
Combines multiple DocAwareChunks into one large chunk (for multipass mode),
Combines multiple DocAwareChunks into one large chunk (for "multipass" mode),
appending the content and adjusting source_links accordingly.
"""
merged_chunk = DocAwareChunk(
@@ -98,7 +99,7 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
"""
Generates larger grouped chunks by combining sets of smaller chunks.
Generates larger "grouped" chunks by combining sets of smaller chunks.
"""
large_chunks = []
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
@@ -185,7 +186,7 @@ class Chunker:
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
"""
For multipass mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
"""
if self.mini_chunk_splitter and chunk_text.strip():
return self.mini_chunk_splitter.split_text(chunk_text)
@@ -194,7 +195,7 @@ class Chunker:
# ADDED: extra param image_url to store in the chunk
def _create_chunk(
self,
document: Document,
document: IndexingDocument,
chunks_list: list[DocAwareChunk],
text: str,
links: dict[int, str],
@@ -225,7 +226,29 @@ class Chunker:
def _chunk_document(
self,
document: Document,
document: IndexingDocument,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Legacy method for backward compatibility.
Calls _chunk_document_with_sections with document.sections.
"""
return self._chunk_document_with_sections(
document,
document.processed_sections,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
def _chunk_document_with_sections(
self,
document: IndexingDocument,
sections: list[Section],
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
@@ -233,17 +256,16 @@ class Chunker:
) -> list[DocAwareChunk]:
"""
Loops through sections of the document, converting them into one or more chunks.
If a section has an image_link, we treat it as a dedicated chunk.
Works with processed sections that are base Section objects.
"""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(section.text or "")
section_link_text = section.link or ""
# ADDED: if the Section has an image link
image_url = section.image_file_name
# If there is no useful content, skip
@@ -254,7 +276,7 @@ class Chunker:
)
continue
# CASE 1: If this is an image section, force a separate chunk
# CASE 1: If this section has an image, force a separate chunk
if image_url:
# First, if we have any partially built text chunk, finalize it
if chunk_text.strip():
@@ -271,15 +293,13 @@ class Chunker:
chunk_text = ""
link_offsets = {}
# Create a chunk specifically for this image
# (If the section has text describing the image, use that as content)
# Create a chunk specifically for this image section
# (Using the text summary that was generated during processing)
self._create_chunk(
document,
chunks,
section_text,
links={0: section_link_text}
if section_link_text
else {}, # No text offsets needed for images
links={0: section_link_text} if section_link_text else {},
image_file_name=image_url,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
@@ -384,7 +404,9 @@ class Chunker:
)
return chunks
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
def _handle_single_document(
self, document: IndexingDocument
) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -420,26 +442,31 @@ class Chunker:
title_prefix = ""
metadata_suffix_semantic = ""
# Chunk the document
normal_chunks = self._chunk_document(
# Use processed_sections if available (IndexingDocument), otherwise use original sections
sections_to_chunk = document.processed_sections
normal_chunks = self._chunk_document_with_sections(
document,
sections_to_chunk,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
# Optional multipass large chunk creation
# Optional "multipass" large chunk creation
if self.enable_multipass and self.enable_large_chunks:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
"""
Takes in a list of documents and chunks them into smaller chunks for indexing
while persisting the document metadata.
Works with both standard Document objects and IndexingDocument objects with processed_sections.
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:

View File

@@ -10,13 +10,20 @@ from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.chunk import update_chunk_boost_components__no_commit
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
@@ -27,7 +34,10 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
@@ -37,15 +47,26 @@ from onyx.document_index.document_index_utils import (
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.models import IndexChunk
from onyx.indexing.models import UpdatableChunkData
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
logger = setup_logger()
@@ -53,6 +74,7 @@ logger = setup_logger()
class DocumentBatchPrepareContext(BaseModel):
updatable_docs: list[Document]
id_to_db_doc_map: dict[str, DBDocument]
indexable_docs: list[IndexingDocument] = []
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -125,6 +147,72 @@ def _upsert_documents_in_db(
)
def _get_aggregated_chunk_boost_factor(
chunks: list[IndexChunk],
information_content_classification_model: InformationContentClassificationModel,
) -> list[float]:
"""Calculates the aggregated boost factor for a chunk based on its content."""
short_chunk_content_dict = {
chunk_num: chunk.content
for chunk_num, chunk in enumerate(chunks)
if len(chunk.content.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
}
short_chunk_contents = list(short_chunk_content_dict.values())
short_chunk_keys = list(short_chunk_content_dict.keys())
try:
predictions = information_content_classification_model.predict(
short_chunk_contents
)
# Create a mapping of chunk positions to their scores
score_map = {
short_chunk_keys[i]: prediction.content_boost_factor
for i, prediction in enumerate(predictions)
}
# Default to 1.0 for longer chunks, use predicted score for short chunks
chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
return chunk_content_scores
except Exception as e:
logger.exception(
f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
)
chunks_with_scores: list[IndexChunk] = []
chunk_content_scores = []
for chunk in chunks:
if (
len(chunk.content.split())
> INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
chunk_content_scores.append(1.0)
chunks_with_scores.append(chunk)
continue
try:
chunk_content_scores.append(
information_content_classification_model.predict([chunk.content])[
0
].content_boost_factor
)
chunks_with_scores.append(chunk)
except Exception as e:
logger.exception(
f"Error predicting content classification for chunk: {e}."
)
raise Exception(
f"Failed to predict content classification for chunk {chunk.chunk_id} "
f"from document {chunk.source_document.id}"
) from e
return chunk_content_scores
def get_doc_ids_to_update(
documents: list[Document], db_docs: list[DBDocument]
) -> list[Document]:
@@ -154,6 +242,7 @@ def index_doc_batch_with_handler(
*,
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
document_batch: list[Document],
index_attempt_metadata: IndexAttemptMetadata,
@@ -165,6 +254,7 @@ def index_doc_batch_with_handler(
index_pipeline_result = index_doc_batch(
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
document_batch=document_batch,
index_attempt_metadata=index_attempt_metadata,
@@ -265,7 +355,12 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(section.text.strip() for section in document.sections)
empty_contents = not any(
isinstance(section, TextSection)
and section.text is not None
and section.text.strip()
for section in document.sections
)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
@@ -288,7 +383,12 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
)
continue
section_chars = sum(len(section.text) for section in document.sections)
section_chars = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in document.sections
)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
@@ -308,12 +408,128 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
return documents
def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
    """
    Process all sections in documents by:
    1. Converting both TextSection and ImageSection objects to base Section objects
    2. Processing ImageSections to generate text summaries using a vision-capable LLM
    3. Returning IndexingDocument objects with both original and processed sections

    Args:
        documents: List of documents with TextSection | ImageSection objects

    Returns:
        List of IndexingDocument objects with processed_sections as list[Section]
    """
    # Check if image extraction and analysis is enabled before trying to get a vision LLM
    if not get_image_extraction_and_analysis_enabled():
        llm = None
    else:
        # Only get the vision LLM if image processing is enabled
        llm = get_default_llm_with_vision()

    # No LLM either because the feature is off or no vision-capable provider exists.
    if not llm:
        logger.warning(
            "No vision-capable LLM available. Image sections will not be processed."
        )

        # Even without LLM, we still convert to IndexingDocument with base Sections
        # (text carried over for TextSections, image_file_name for ImageSections).
        return [
            IndexingDocument(
                **document.dict(),
                processed_sections=[
                    Section(
                        text=section.text if isinstance(section, TextSection) else None,
                        link=section.link,
                        image_file_name=section.image_file_name
                        if isinstance(section, ImageSection)
                        else None,
                    )
                    for section in document.sections
                ],
            )
            for document in documents
        ]

    indexed_documents: list[IndexingDocument] = []

    for document in documents:
        processed_sections: list[Section] = []

        for section in document.sections:
            # For ImageSection, process and create base Section with both text and image_file_name
            if isinstance(section, ImageSection):
                # Default section with image path preserved
                processed_section = Section(
                    link=section.link,
                    image_file_name=section.image_file_name,
                    text=None,  # Will be populated if summarization succeeds
                )

                # Try to get image summary
                # NOTE(review): a fresh DB session is opened per image section;
                # presumably acceptable at indexing volume — confirm if batching matters.
                try:
                    with get_session_with_current_tenant() as db_session:
                        pgfilestore = get_pgfilestore_by_file_name(
                            file_name=section.image_file_name, db_session=db_session
                        )

                        if not pgfilestore:
                            logger.warning(
                                f"Image file {section.image_file_name} not found in PGFileStore"
                            )

                            processed_section.text = "[Image could not be processed]"
                        else:
                            # Get the image data (raw bytes from the large object store)
                            image_data_io = read_lobj(
                                pgfilestore.lobj_oid, db_session, mode="rb"
                            )
                            pgfilestore_data = image_data_io.read()
                            summary = summarize_image_with_error_handling(
                                llm=llm,
                                image_data=pgfilestore_data,
                                context_name=pgfilestore.display_name or "Image",
                            )

                            if summary:
                                processed_section.text = summary
                            else:
                                # Summarizer returned nothing — keep a sentinel so the
                                # chunk still has text content.
                                processed_section.text = (
                                    "[Image could not be summarized]"
                                )
                except Exception as e:
                    # Best-effort: an image failure must not break document indexing.
                    logger.error(f"Error processing image section: {e}")
                    processed_section.text = "[Error processing image]"

                processed_sections.append(processed_section)

            # For TextSection, create a base Section with text and link
            elif isinstance(section, TextSection):
                processed_section = Section(
                    text=section.text, link=section.link, image_file_name=None
                )
                processed_sections.append(processed_section)

            # If it's already a base Section (unlikely), just append it
            else:
                processed_sections.append(section)

        # Create IndexingDocument with original sections and processed_sections
        indexed_document = IndexingDocument(
            **document.dict(), processed_sections=processed_sections
        )
        indexed_documents.append(indexed_document)

    return indexed_documents
@log_function_time(debug_only=True)
def index_doc_batch(
*,
document_batch: list[Document],
chunker: Chunker,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
index_attempt_metadata: IndexAttemptMetadata,
db_session: Session,
@@ -362,19 +578,23 @@ def index_doc_batch(
failures=[],
)
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
ctx.indexable_docs = process_image_sections(ctx.updatable_docs)
doc_descriptors = [
{
"doc_id": doc.id,
"doc_length": doc.get_total_char_length(),
}
for doc in ctx.updatable_docs
for doc in ctx.indexable_docs
]
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
# NOTE: no special handling for failures here, since the chunker is not
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (
@@ -386,7 +606,23 @@ def index_doc_batch(
else ([], [])
)
chunk_content_scores = (
_get_aggregated_chunk_boost_factor(
chunks_with_embeddings, information_content_classification_model
)
if USE_INFORMATION_CONTENT_CLASSIFICATION
else [1.0] * len(chunks_with_embeddings)
)
updatable_ids = [doc.id for doc in ctx.updatable_docs]
updatable_chunk_data = [
UpdatableChunkData(
chunk_id=chunk.chunk_id,
document_id=chunk.source_document.id,
boost_score=score,
)
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
]
# Acquires a lock on the documents so that no other process can modify them
# NOTE: don't need to acquire till here, since this is when the actual race condition
@@ -439,8 +675,9 @@ def index_doc_batch(
else DEFAULT_BOOST
),
tenant_id=tenant_id,
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
)
for chunk in chunks_with_embeddings
for chunk_num, chunk in enumerate(chunks_with_embeddings)
]
logger.debug(
@@ -525,6 +762,11 @@ def index_doc_batch(
db_session=db_session,
)
# save the chunk boost components to postgres
update_chunk_boost_components__no_commit(
chunk_data=updatable_chunk_data, db_session=db_session
)
db_session.commit()
result = IndexingPipelineResult(
@@ -540,6 +782,7 @@ def index_doc_batch(
def build_indexing_pipeline(
*,
embedder: IndexingEmbedder,
information_content_classification_model: InformationContentClassificationModel,
document_index: DocumentIndex,
db_session: Session,
tenant_id: str,
@@ -563,6 +806,7 @@ def build_indexing_pipeline(
index_doc_batch_with_handler,
chunker=chunker,
embedder=embedder,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=ignore_time_skip,
db_session=db_session,

View File

@@ -83,13 +83,16 @@ class DocMetadataAwareIndexChunk(IndexChunk):
document_sets: all document sets the source document for this chunk is a part
of. This is used for filtering / personas.
boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
negative -> ranked lower.
negative -> ranked lower. Not included in aggregated boost calculation
for legacy reasons.
aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content)
"""
tenant_id: str
access: "DocumentAccess"
document_sets: set[str]
boost: int
aggregated_chunk_boost_factor: float
@classmethod
def from_index_chunk(
@@ -98,6 +101,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access: "DocumentAccess",
document_sets: set[str],
boost: int,
aggregated_chunk_boost_factor: float,
tenant_id: str,
) -> "DocMetadataAwareIndexChunk":
index_chunk_data = index_chunk.model_dump()
@@ -106,6 +110,7 @@ class DocMetadataAwareIndexChunk(IndexChunk):
access=access,
document_sets=document_sets,
boost=boost,
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
tenant_id=tenant_id,
)
@@ -179,3 +184,9 @@ class IndexingSetting(EmbeddingModelDetail):
class MultipassConfig(BaseModel):
multipass_indexing: bool
enable_large_chunks: bool
class UpdatableChunkData(BaseModel):
    """Per-chunk record persisted to Postgres after indexing (chunk boost bookkeeping)."""

    # Chunk identifier within its source document (matches IndexChunk.chunk_id)
    chunk_id: int
    # ID of the source document this chunk came from
    document_id: str
    # Aggregated chunk-level boost score (e.g. information-content factor; 1.0 = neutral)
    boost_score: float

View File

@@ -5,7 +5,9 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.models import Persona
@@ -14,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -94,40 +97,61 @@ def get_default_llm_with_vision(
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM | None:
"""Get an LLM that supports image input, with the following priority:
1. Use the designated default vision provider if it exists and supports image input
2. Fall back to the first LLM provider that supports image input
Returns None if no providers exist or if no provider supports images.
"""
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
with get_session_context_manager() as db_session:
llm_providers = fetch_existing_llm_providers(db_session)
if not llm_providers:
return None
for provider in llm_providers:
model_name = provider.default_model_name
fast_model_name = (
provider.fast_default_model_name or provider.default_model_name
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input."""
return get_llm(
provider=provider.provider,
model=model,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
if not model_name or not fast_model_name:
continue
if model_supports_image_input(model_name, provider.provider):
return get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if (
default_provider
and default_provider.default_vision_model
and model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
)
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
raise ValueError("No LLM provider found that supports image input")
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
if not providers:
return None
# Find the first provider that supports image input
for provider in providers:
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model
)
return None
def get_default_llms(

View File

@@ -29,6 +29,8 @@ from onyx.natural_language_processing.exceptions import (
from onyx.natural_language_processing.utils import get_tokenizer
from onyx.natural_language_processing.utils import tokenizer_trim_content
from onyx.utils.logger import setup_logger
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbeddingProvider
@@ -36,9 +38,11 @@ from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import Embedding
from shared_configs.model_server_models import EmbedRequest
from shared_configs.model_server_models import EmbedResponse
from shared_configs.model_server_models import InformationContentClassificationResponses
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
from shared_configs.model_server_models import RerankRequest
@@ -377,6 +381,31 @@ class QueryAnalysisModel:
return response_model.is_keyword, response_model.keywords
class InformationContentClassificationModel:
    """Thin HTTP client for the model server's content-classification endpoint.

    Used during indexing to score how information-dense short chunks are.
    """

    def __init__(
        self,
        model_server_host: str = INDEXING_MODEL_SERVER_HOST,
        model_server_port: int = INDEXING_MODEL_SERVER_PORT,
    ) -> None:
        # Resolve the endpoint URL once; predict() reuses it for every request.
        base_url = build_model_server_url(model_server_host, model_server_port)
        self.content_server_endpoint = base_url + "/custom/content-classification"

    def predict(
        self,
        queries: list[str],
    ) -> list[ContentClassificationPrediction]:
        """POST the texts to the model server and return one prediction per text."""
        raw_response = requests.post(self.content_server_endpoint, json=queries)
        raw_response.raise_for_status()
        parsed = InformationContentClassificationResponses(
            information_content_classifications=raw_response.json()
        )
        return parsed.information_content_classifications
class ConnectorClassificationModel:
def __init__(
self,

View File

@@ -1,5 +1,5 @@
# Used for creating embeddings of images for vector search
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
You are an assistant for summarizing images for retrieval.
Summarize the content of the following image and be as precise as possible.
The summary will be embedded and used to retrieve the original image.
@@ -7,14 +7,13 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
"""
# Prompt for generating image descriptions with filename context
IMAGE_SUMMARIZATION_USER_PROMPT = """
The image has the file name '{title}'.
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
Describe precisely and concisely what the image shows.
"""
# Used for analyzing images in response to user queries at search time
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
"You are an AI assistant specialized in describing images.\n"
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
"Focus on aspects of the image that are relevant to the user's question.\n"

View File

@@ -160,6 +160,20 @@ class RedisPool:
def get_replica_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
    def get_raw_client(self) -> Redis:
        """
        Returns a Redis client with direct access to the primary connection pool,
        without tenant prefixing.

        Unlike get_client(), keys are used exactly as given — callers are
        responsible for any namespacing/isolation they need.
        """
        return redis.Redis(connection_pool=self._pool)
    def get_raw_replica_client(self) -> Redis:
        """
        Returns a Redis client with direct access to the replica connection pool,
        without tenant prefixing.

        Replica counterpart of get_raw_client(); intended for read-heavy use.
        """
        return redis.Redis(connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@@ -224,6 +238,15 @@ def get_redis_client(
# This argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis client with tenant-specific key prefixing.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this when working with tenant-specific data that should be
isolated from other tenants.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@@ -235,6 +258,15 @@ def get_redis_replica_client(
# this argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis replica client with tenant-specific key prefixing.
Similar to get_redis_client(), but connects to a read replica when available.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this for read-heavy operations on tenant-specific data.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@@ -242,13 +274,57 @@ def get_redis_replica_client(
def get_shared_redis_client() -> Redis:
    """
    Returns a Redis client with a shared namespace prefix.

    Unlike tenant-specific clients, this uses a common prefix for all keys,
    creating a shared namespace accessible across all tenants.

    Use this for data that should be shared across the application and
    isn't specific to any individual tenant.
    """
    # DEFAULT_REDIS_PREFIX is the common namespace used in place of a tenant id.
    return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
def get_shared_redis_replica_client() -> Redis:
    """
    Returns a Redis replica client with a shared namespace prefix.

    Similar to get_shared_redis_client(), but connects to a read replica when available.
    Uses a common prefix for all keys, creating a shared namespace.

    Use this for read-heavy operations on data that should be shared
    across the application.
    """
    # Same shared namespace as the primary variant, routed to the replica pool.
    return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
def get_raw_redis_client() -> Redis:
    """
    Returns a Redis client that doesn't apply tenant prefixing to keys.

    Use this only when you need to access Redis directly without tenant isolation
    or any key prefixing. Typically needed for integrating with external systems
    or libraries that have inflexible key requirements.

    Warning: Be careful with this client as it bypasses tenant isolation.
    """
    return redis_pool.get_raw_client()
def get_raw_redis_replica_client() -> Redis:
    """
    Returns a Redis replica client that doesn't apply tenant prefixing to keys.

    Similar to get_raw_redis_client(), but connects to a read replica when available.
    Use this for read-heavy operations that need direct Redis access without
    tenant isolation or key prefixing.

    Warning: Be careful with this client as it bypasses tenant isolation.
    """
    return redis_pool.get_raw_replica_client()
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

View File

@@ -15,7 +15,7 @@ from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import InputType
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.connector import check_connectors_exist
from onyx.db.connector import create_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
@@ -55,7 +55,7 @@ def _create_indexable_chunks(
# The section is not really used past this point since we have already done the other processing
# for the chunking and embedding.
sections=[
Section(
TextSection(
text=preprocessed_doc["content"],
link=preprocessed_doc["url"],
image_file_name=None,
@@ -98,6 +98,7 @@ def _create_indexable_chunks(
boost=DEFAULT_BOOST,
large_chunk_id=None,
image_file_name=None,
aggregated_chunk_boost_factor=1.0,
)
chunks.append(chunk)

View File

@@ -14,6 +14,7 @@ from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import User
from onyx.llm.factory import get_default_llms
@@ -21,11 +22,13 @@ from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -186,6 +189,62 @@ def set_provider_as_default(
update_default_provider(provider_id=provider_id, db_session=db_session)
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
    provider_id: int,
    vision_model: str
    | None = Query(None, description="The default vision model to use"),
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> None:
    """Admin-only: mark the given LLM provider as the default vision provider,
    optionally pinning which of its models is used for vision tasks."""
    # Delegates all validation/persistence to the DB layer.
    update_default_vision_provider(
        provider_id=provider_id, vision_model=vision_model, db_session=db_session
    )
@admin_router.get("/vision-providers")
def get_vision_capable_providers(
    _: User | None = Depends(current_admin_user),
    db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
    """Return a list of LLM providers and their models that support image input"""
    logger.info("Fetching vision-capable providers")

    results: list[VisionProviderResponse] = []
    for llm_provider in fetch_existing_llm_providers(db_session):
        # Pick the candidate model list in priority order: explicit model_names,
        # then display names, then the single default model.
        if llm_provider.model_names:
            candidates = llm_provider.model_names
        elif llm_provider.display_model_names:
            candidates = llm_provider.display_model_names
        elif llm_provider.default_model_name:
            candidates = [llm_provider.default_model_name]
        else:
            candidates = []

        # Keep only the models that actually accept image input.
        supported = []
        for candidate in candidates:
            if model_supports_image_input(candidate, llm_provider.provider):
                supported.append(candidate)
                logger.debug(f"Vision model found: {llm_provider.provider}/{candidate}")

        # Providers with no vision-capable model are omitted entirely.
        if not supported:
            continue

        payload = FullLLMProvider.from_model(llm_provider).model_dump()
        payload["vision_models"] = supported
        logger.info(
            f"Vision provider: {llm_provider.provider} with models: {supported}"
        )
        results.append(VisionProviderResponse(**payload))

    logger.info(f"Found {len(results)} vision-capable providers")
    return results
"""Endpoints for all"""

View File

@@ -34,6 +34,8 @@ class LLMProviderDescriptor(BaseModel):
default_model_name: str
fast_default_model_name: str | None
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
display_model_names: list[str] | None
@classmethod
@@ -46,11 +48,10 @@ class LLMProviderDescriptor(BaseModel):
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
model_names=(
llm_provider_model.model_names
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
model_names=llm_provider_model.model_names
or fetch_models_for_provider(llm_provider_model.provider),
display_model_names=llm_provider_model.display_model_names,
)
@@ -68,6 +69,7 @@ class LLMProvider(BaseModel):
groups: list[int] = Field(default_factory=list)
display_model_names: list[str] | None = None
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -79,6 +81,7 @@ class LLMProviderUpsertRequest(LLMProvider):
class FullLLMProvider(LLMProvider):
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_names: list[str]
@classmethod
@@ -94,6 +97,8 @@ class FullLLMProvider(LLMProvider):
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
display_model_names=llm_provider_model.display_model_names,
model_names=(
llm_provider_model.model_names
@@ -104,3 +109,9 @@ class FullLLMProvider(LLMProvider):
groups=[group.id for group in llm_provider_model.groups],
deployment_name=llm_provider_model.deployment_name,
)
class VisionProviderResponse(FullLLMProvider):
    """Response model for vision providers endpoint, including vision-specific fields."""

    # Names of this provider's models that support image input
    vision_models: list[str]

View File

@@ -1,6 +1,8 @@
import re
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from typing import cast
import jwt
from email_validator import EmailNotValidError
@@ -31,9 +33,12 @@ from onyx.auth.users import current_admin_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import optional_user
from onyx.configs.app_configs import AUTH_BACKEND
from onyx.configs.app_configs import AUTH_TYPE
from onyx.configs.app_configs import AuthBackend
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
from onyx.configs.constants import AuthType
@@ -50,6 +55,7 @@ from onyx.db.users import get_total_filtered_users_count
from onyx.db.users import get_user_by_email
from onyx.db.users import validate_user_role_update
from onyx.key_value_store.factory import get_kv_store
from onyx.redis.redis_pool import get_raw_redis_client
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
@@ -477,7 +483,7 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
return UserRoleResponse(role=user.role)
def get_current_token_expiration_jwt(
def get_current_auth_token_expiration_jwt(
user: User | None, request: Request
) -> datetime | None:
if user is None:
@@ -506,6 +512,48 @@ def get_current_token_expiration_jwt(
return None
def get_current_auth_token_creation_redis(
    user: User | None, request: Request
) -> datetime | None:
    """Calculate the auth token's creation time from Redis TTL information.

    Retrieves the authentication token from the request cookies, reads its
    remaining TTL in Redis, and derives the creation time as
    now - (full session length - remaining TTL).

    Returns None if there is no user, no auth cookie, the token is missing or
    expired in Redis, or any unexpected error occurs.
    """
    if user is None:
        return None

    try:
        # Get the token from the request
        token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
        if not token:
            logger.debug("No auth token cookie found")
            return None

        # Get the Redis client (no tenant prefixing — auth keys are global)
        redis = get_raw_redis_client()
        redis_key = REDIS_AUTH_KEY_PREFIX + token

        # Get the TTL of the token; ttl <= 0 covers both "expired" and
        # Redis's -1/-2 sentinel returns for no-TTL/missing keys.
        ttl = cast(int, redis.ttl(redis_key))
        if ttl <= 0:
            logger.error("Token has expired or doesn't exist in Redis")
            return None

        # Calculate the creation time based on TTL and session expiry
        # Current time minus (total session length minus remaining TTL)
        # NOTE(review): assumes the key's TTL was set to exactly
        # SESSION_EXPIRE_TIME_SECONDS at login and never refreshed — confirm.
        current_time = datetime.now(timezone.utc)
        token_creation_time = current_time - timedelta(
            seconds=(SESSION_EXPIRE_TIME_SECONDS - ttl)
        )
        return token_creation_time
    except Exception as e:
        logger.error(f"Error retrieving token expiration from Redis: {e}")
        return None
def get_current_token_creation(
user: User | None, db_session: Session
) -> datetime | None:
@@ -533,6 +581,7 @@ def get_current_token_creation(
@router.get("/me")
def verify_user_logged_in(
request: Request,
user: User | None = Depends(optional_user),
db_session: Session = Depends(get_session),
) -> UserInfo:
@@ -558,7 +607,9 @@ def verify_user_logged_in(
)
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
get_current_auth_token_creation_redis(user, request)
if AUTH_BACKEND == AuthBackend.REDIS
else get_current_token_creation(user, db_session)
)
team_name = fetch_ee_implementation_or_noop(

View File

@@ -19,6 +19,9 @@ from onyx.db.search_settings import get_secondary_search_settings
from onyx.document_index.factory import get_default_document_index
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.server.onyx_api.models import DocMinimalInfo
from onyx.server.onyx_api.models import IngestionDocument
from onyx.server.onyx_api.models import IngestionResult
@@ -102,8 +105,11 @@ def upsert_ingestion_doc(
search_settings=search_settings
)
information_content_classification_model = InformationContentClassificationModel()
indexing_pipeline = build_indexing_pipeline(
embedder=index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=curr_doc_index,
ignore_time_skip=True,
db_session=db_session,
@@ -138,6 +144,7 @@ def upsert_ingestion_doc(
sec_ind_pipeline = build_indexing_pipeline(
embedder=new_index_embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=sec_doc_index,
ignore_time_skip=True,
db_session=db_session,

View File

@@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
httpcore==1.0.5
httpx[http2]==0.27.0
httpx-oauth==0.15.1
huggingface-hub==0.20.1
huggingface-hub==0.29.0
inflection==0.5.1
jira==3.5.1
jsonref==1.1.0
@@ -71,6 +71,7 @@ requests==2.32.2
requests-oauthlib==1.3.1
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
rfc3986==1.5.0
setfit==1.1.1
simple-salesforce==1.12.6
slack-sdk==3.20.2
SQLAlchemy[mypy]==2.0.15
@@ -78,7 +79,7 @@ starlette==0.36.3
supervisor==4.2.5
tiktoken==0.7.0
timeago==1.0.16
transformers==4.39.2
transformers==4.49.0
unstructured==0.15.1
unstructured-client==0.25.4
uvicorn==0.21.1

View File

@@ -14,7 +14,7 @@ pytest-asyncio==0.22.0
pytest==7.4.4
reorder-python-imports==3.9.0
ruff==0.0.286
sentence-transformers==2.6.1
sentence-transformers==3.4.1
trafilatura==1.12.2
types-beautifulsoup4==4.12.0.3
types-html5lib==1.1.11.13

View File

@@ -7,9 +7,10 @@ openai==1.61.0
pydantic==2.8.2
retry==0.9.2
safetensors==0.4.2
sentence-transformers==2.6.1
sentence-transformers==3.4.1
setfit==1.1.1
torch==2.2.0
transformers==4.39.2
transformers==4.49.0
uvicorn==0.21.1
voyageai==0.2.3
litellm==1.61.16

View File

@@ -161,17 +161,21 @@ overview_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/overview",
title=overview_title,
content=overview,
title_embedding=model.encode(f"search_document: {overview_title}"),
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
title_embedding=list(model.encode(f"search_document: {overview_title}")),
content_embedding=list(
model.encode(f"search_document: {overview_title}\n{overview}")
),
)
enterprise_search_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_1,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
)
),
)
@@ -179,9 +183,11 @@ enterprise_search_doc_2 = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/enterprise_search",
title=enterprise_search_title,
content=enterprise_search_2,
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
content_embedding=model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
content_embedding=list(
model.encode(
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
)
),
chunk_ind=1,
)
@@ -190,9 +196,9 @@ ai_platform_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/ai_platform",
title=ai_platform_title,
content=ai_platform,
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
content_embedding=model.encode(
f"search_document: {ai_platform_title}\n{ai_platform}"
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
content_embedding=list(
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
),
)
@@ -200,9 +206,9 @@ customer_support_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/support",
title=customer_support_title,
content=customer_support,
title_embedding=model.encode(f"search_document: {customer_support_title}"),
content_embedding=model.encode(
f"search_document: {customer_support_title}\n{customer_support}"
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
content_embedding=list(
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
),
)
@@ -210,17 +216,17 @@ sales_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/sales",
title=sales_title,
content=sales,
title_embedding=model.encode(f"search_document: {sales_title}"),
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
title_embedding=list(model.encode(f"search_document: {sales_title}")),
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
)
operations_doc = SeedPresaveDocument(
url="https://docs.onyx.app/more/use_cases/operations",
title=operations_title,
content=operations,
title_embedding=model.encode(f"search_document: {operations_title}"),
content_embedding=model.encode(
f"search_document: {operations_title}\n{operations}"
title_embedding=list(model.encode(f"search_document: {operations_title}")),
content_embedding=list(
model.encode(f"search_document: {operations_title}\n{operations}")
),
)

View File

@@ -99,6 +99,7 @@ def generate_dummy_chunk(
),
document_sets={document_set for document_set in document_set_names},
boost=random.randint(-1, 1),
aggregated_chunk_boost_factor=random.random(),
tenant_id=POSTGRES_DEFAULT_SCHEMA,
)

View File

@@ -23,9 +23,11 @@ INDEXING_MODEL_SERVER_PORT = int(
# Onyx custom Deep Learning Models
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
# INTENT_MODEL_TAG = "v1.0.3"
INTENT_MODEL_TAG: str | None = None
INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
INFORMATION_CONTENT_MODEL_TAG: str | None = None
# Bi-Encoder, other details
DOC_EMBEDDING_CONTEXT_SIZE = 512
@@ -277,3 +279,20 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
# Tuning knobs for the information-content classification model applied at
# indexing time. Short chunks judged to carry little information get their
# score multiplied by a downgrade factor in
# [INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN,
#  INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX].
# NOTE(review): the `or <default>` fallback treats an explicitly-set empty env
# var the same as an unset one (falls back to the default).

# Maximum (least severe) downgrade factor for chunks above the cutoff
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
)
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
)
# Temperature for the information content classification model
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
)
# Cutoff below which we start using the information content classification model
# (the cutoff length itself is still considered 'short')
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
    os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
)

View File

@@ -4,6 +4,7 @@ from shared_configs.enums import EmbeddingProvider
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
Embedding = list[float]
@@ -73,7 +74,20 @@ class IntentResponse(BaseModel):
keywords: list[str]
class InformationContentClassificationRequests(BaseModel):
    """Batch request for the information-content classification model server.

    Each entry is one text (chunk content) to be scored for information
    content; responses are expected to correspond to these entries.
    """

    # Texts to classify, one per chunk/candidate.
    queries: list[str]
class SupportedEmbeddingModel(BaseModel):
    """Descriptor for an embedding model supported by the system."""

    # Model identifier (presumably a Hugging Face repo name — confirm against
    # SUPPORTED_EMBEDDING_MODELS entries).
    name: str
    # Dimensionality of the vectors this model produces.
    dim: int
    # Name of the search index that stores this model's vectors
    # (e.g. "danswer_chunk_intfloat_multilingual_e5_small").
    index_name: str
class ContentClassificationPrediction(BaseModel):
    """Single prediction from the information-content classification model."""

    # Class label emitted by the classifier (integer encoding; semantics
    # defined by the model — not visible here).
    predicted_label: int
    # Multiplicative boost factor derived from the prediction, applied to
    # the chunk's score during indexing/ranking.
    content_boost_factor: float
class InformationContentClassificationResponses(BaseModel):
    """Batch response wrapping one prediction per requested text."""

    # Predictions corresponding to the request's `queries`
    # (presumably in the same order — confirm with the model server).
    information_content_classifications: list[ContentClassificationPrediction]

Some files were not shown because too many files have changed in this diff Show More