mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-21 01:35:46 +00:00
Compare commits
1 Commits
content-mo
...
invite_use
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
aacdf775da |
@@ -31,8 +31,7 @@ RUN python -c "from transformers import AutoTokenizer; \
|
||||
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
|
||||
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from huggingface_hub import snapshot_download; \
|
||||
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
|
||||
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
|
||||
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
|
||||
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
|
||||
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
|
||||
from sentence_transformers import SentenceTransformer; \
|
||||
|
||||
@@ -1,51 +0,0 @@
|
||||
"""add chunk stats table
|
||||
|
||||
Revision ID: 3781a5eb12cb
|
||||
Revises: df46c75b714e
|
||||
Create Date: 2025-03-10 10:02:30.586666
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3781a5eb12cb"
|
||||
down_revision = "df46c75b714e"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.create_table(
|
||||
"chunk_stats",
|
||||
sa.Column("id", sa.String(), primary_key=True, index=True),
|
||||
sa.Column(
|
||||
"document_id",
|
||||
sa.String(),
|
||||
sa.ForeignKey("document.id"),
|
||||
nullable=False,
|
||||
index=True,
|
||||
),
|
||||
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
|
||||
sa.Column("information_content_boost", sa.Float(), nullable=True),
|
||||
sa.Column(
|
||||
"last_modified",
|
||||
sa.DateTime(timezone=True),
|
||||
nullable=False,
|
||||
index=True,
|
||||
server_default=sa.func.now(),
|
||||
),
|
||||
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
|
||||
sa.UniqueConstraint(
|
||||
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
|
||||
),
|
||||
)
|
||||
|
||||
op.create_index(
|
||||
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
|
||||
op.drop_table("chunk_stats")
|
||||
@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
"""Conflicts on lowercasing will result in the uppercased email getting a
|
||||
unique integer suffix when converted to lowercase."""
|
||||
|
||||
# Get database connection
|
||||
connection = op.get_bind()
|
||||
|
||||
# Fetch all user emails that are not already lowercase
|
||||
user_emails = connection.execute(
|
||||
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
|
||||
).fetchall()
|
||||
|
||||
for user_id, email in user_emails:
|
||||
email = cast(str, email)
|
||||
username, domain = email.rsplit("@", 1)
|
||||
new_email = f"{username.lower()}@{domain.lower()}"
|
||||
attempt = 1
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try updating the email
|
||||
connection.execute(
|
||||
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
|
||||
{"new_email": new_email, "user_id": user_id},
|
||||
)
|
||||
break # Success, exit loop
|
||||
except IntegrityError:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
"""add_default_vision_provider_to_llm_provider
|
||||
|
||||
Revision ID: df46c75b714e
|
||||
Revises: 3934b1bc7b62
|
||||
Create Date: 2025-03-11 16:20:19.038945
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "df46c75b714e"
|
||||
down_revision = "3934b1bc7b62"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_default_vision_provider",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("llm_provider", "default_vision_model")
|
||||
op.drop_column("llm_provider", "is_default_vision_provider")
|
||||
@@ -1,33 +0,0 @@
|
||||
"""add new available tenant table
|
||||
|
||||
Revision ID: 3b45e0018bf1
|
||||
Revises: ac842f85f932
|
||||
Create Date: 2025-03-06 09:55:18.229910
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b45e0018bf1"
|
||||
down_revision = "ac842f85f932"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create new_available_tenant table
|
||||
op.create_table(
|
||||
"available_tenant",
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.Column("alembic_version", sa.String(), nullable=False),
|
||||
sa.Column("date_created", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop new_available_tenant table
|
||||
op.drop_table("available_tenant")
|
||||
@@ -28,12 +28,11 @@ from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
@@ -63,72 +62,42 @@ async def get_or_provision_tenant(
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
# Early return for non-multi-tenant mode
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
# First, check if the user already has a tenant
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
return tenant_id
|
||||
except exceptions.UserNotExists:
|
||||
# User doesn't exist, so we need to create a new tenant or assign an existing one
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try to get a pre-provisioned tenant
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
return tenant_id
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
error_msg = "Failed to provision tenant"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
if not tenant_id:
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="Failed to provision tenant. Please try again later.",
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
This is the fallback method when we can't use a pre-provisioned tenant.
|
||||
|
||||
"""
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
logger.info(f"Creating new tenant {tenant_id} for user {email}")
|
||||
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
try:
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
@@ -142,25 +111,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
)
|
||||
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Create the schema for the tenant
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
await setup_tenant(tenant_id)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Assign the tenant to the user
|
||||
await assign_tenant_to_user(tenant_id, email)
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
@@ -191,74 +189,20 @@ async def notify_control_plane(
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
"""
|
||||
Logic to rollback tenant provisioning on data plane.
|
||||
Handles each step independently to ensure maximum cleanup even if some steps fail.
|
||||
"""
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
|
||||
# Track if any part of the rollback fails
|
||||
rollback_errors = []
|
||||
|
||||
# 1. Try to drop the tenant's schema
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 2. Try to remove tenant mapping
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Successfully removed user mappings for tenant {tenant_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 3. If this tenant was in the available tenants table, remove it
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter(AvailableTenant.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Removed tenant {tenant_id} from available tenants table"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# Log summary of rollback operation
|
||||
if rollback_errors:
|
||||
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
|
||||
else:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
@@ -455,111 +399,3 @@ def get_tenant_by_domain_from_control_plane(
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tenant by domain: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_available_tenant() -> str | None:
|
||||
"""
|
||||
Get an available pre-provisioned tenant from the NewAvailableTenant table.
|
||||
Returns the tenant_id if one is available, None otherwise.
|
||||
Uses row-level locking to prevent race conditions when multiple processes
|
||||
try to get an available tenant simultaneously.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
db_session.begin()
|
||||
|
||||
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.order_by(AvailableTenant.date_created)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
tenant_id = available_tenant.tenant_id
|
||||
# Remove the tenant from the available tenants table
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(f"Using pre-provisioned tenant {tenant_id}")
|
||||
return tenant_id
|
||||
else:
|
||||
db_session.rollback()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error getting available tenant")
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
|
||||
async def setup_tenant(tenant_id: str) -> None:
|
||||
"""
|
||||
Set up a tenant with all necessary configurations.
|
||||
This is a centralized function that handles all tenant setup logic.
|
||||
"""
|
||||
token = None
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Configure default API keys
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
# Set up Onyx with appropriate settings
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to set up tenant {tenant_id}")
|
||||
raise e
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
Uses transaction handling to ensure atomicity and includes retry logic
|
||||
for control plane notifications.
|
||||
"""
|
||||
# First, add the user to the tenant in a transaction
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
|
||||
|
||||
def get_current_alembic_version(tenant_id: str) -> str:
|
||||
"""Get the current Alembic version for a tenant."""
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import text
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Set the search path to the tenant's schema
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Get the current version from the alembic_version table
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
return current_rev or "head"
|
||||
|
||||
@@ -67,39 +67,15 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
db_session.add(
|
||||
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
|
||||
)
|
||||
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.rollback()
|
||||
raise
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
|
||||
@@ -3,7 +3,6 @@ from shared_configs.enums import EmbedTextType
|
||||
|
||||
|
||||
MODEL_WARM_UP_STRING = "hi " * 512
|
||||
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
|
||||
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
|
||||
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
|
||||
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
|
||||
|
||||
@@ -1,14 +1,11 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fastapi import APIRouter
|
||||
from huggingface_hub import snapshot_download # type: ignore
|
||||
from setfit import SetFitModel # type: ignore[import]
|
||||
from transformers import AutoTokenizer # type: ignore
|
||||
from transformers import BatchEncoding # type: ignore
|
||||
from transformers import PreTrainedTokenizer # type: ignore
|
||||
|
||||
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
|
||||
from model_server.constants import MODEL_WARM_UP_STRING
|
||||
from model_server.onyx_torch_model import ConnectorClassifier
|
||||
from model_server.onyx_torch_model import HybridClassifier
|
||||
@@ -16,22 +13,11 @@ from model_server.utils import simple_log_function_time
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
|
||||
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
|
||||
)
|
||||
from shared_configs.configs import INDEXING_ONLY
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
|
||||
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
|
||||
from shared_configs.configs import INTENT_MODEL_TAG
|
||||
from shared_configs.configs import INTENT_MODEL_VERSION
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
|
||||
@@ -45,10 +31,6 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
|
||||
_INTENT_TOKENIZER: AutoTokenizer | None = None
|
||||
_INTENT_MODEL: HybridClassifier | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
|
||||
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
|
||||
|
||||
|
||||
def get_connector_classifier_tokenizer() -> AutoTokenizer:
|
||||
global _CONNECTOR_CLASSIFIER_TOKENIZER
|
||||
@@ -103,7 +85,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
|
||||
|
||||
def get_local_intent_model(
|
||||
model_name_or_path: str = INTENT_MODEL_VERSION,
|
||||
tag: str | None = INTENT_MODEL_TAG,
|
||||
tag: str = INTENT_MODEL_TAG,
|
||||
) -> HybridClassifier:
|
||||
global _INTENT_MODEL
|
||||
if _INTENT_MODEL is None:
|
||||
@@ -120,9 +102,7 @@ def get_local_intent_model(
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
|
||||
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
@@ -132,44 +112,6 @@ def get_local_intent_model(
|
||||
return _INTENT_MODEL
|
||||
|
||||
|
||||
def get_local_information_content_model(
|
||||
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
|
||||
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
|
||||
) -> SetFitModel:
|
||||
global _INFORMATION_CONTENT_MODEL
|
||||
if _INFORMATION_CONTENT_MODEL is None:
|
||||
try:
|
||||
# Calculate where the cache should be, then load from local if available
|
||||
logger.notice(
|
||||
f"Loading content information model from local cache: {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=True
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
logger.notice(
|
||||
f"Loaded content information model from local cache: {local_path}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load content information model directly: {e}")
|
||||
try:
|
||||
# Attempt to download the model snapshot
|
||||
logger.notice(
|
||||
f"Downloading content information model snapshot for {model_name_or_path}"
|
||||
)
|
||||
local_path = snapshot_download(
|
||||
repo_id=model_name_or_path, revision=tag, local_files_only=False
|
||||
)
|
||||
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to load content information model even after attempted snapshot download: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
return _INFORMATION_CONTENT_MODEL
|
||||
|
||||
|
||||
def tokenize_connector_classification_query(
|
||||
connectors: list[str],
|
||||
query: str,
|
||||
@@ -253,13 +195,6 @@ def warm_up_intent_model() -> None:
|
||||
)
|
||||
|
||||
|
||||
def warm_up_information_content_model() -> None:
|
||||
logger.notice("Warming up Content Model") # TODO: add version if needed
|
||||
|
||||
information_content_model = get_local_information_content_model()
|
||||
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
intent_model = get_local_intent_model()
|
||||
@@ -283,117 +218,6 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
|
||||
return intent_probabilities.tolist(), token_positive_probs
|
||||
|
||||
|
||||
@simple_log_function_time()
|
||||
def run_content_classification_inference(
|
||||
text_inputs: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
"""
|
||||
Assign a score to the segments in question. The model stored in get_local_information_content_model()
|
||||
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
|
||||
In the code outside of the model/inference model servers that score will be converted into the actual
|
||||
boost factor.
|
||||
"""
|
||||
|
||||
def _prob_to_score(prob: float) -> float:
|
||||
"""
|
||||
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
|
||||
"""
|
||||
_MIN_BASE_SCORE = 0.25
|
||||
_MAX_BASE_SCORE = 0.75
|
||||
if prob < _MIN_BASE_SCORE:
|
||||
raw_score = 0.0
|
||||
elif prob < _MAX_BASE_SCORE:
|
||||
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
|
||||
else:
|
||||
raw_score = 1.0
|
||||
return (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
+ (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
|
||||
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
|
||||
)
|
||||
* raw_score
|
||||
)
|
||||
|
||||
_BATCH_SIZE = 32
|
||||
content_model = get_local_information_content_model()
|
||||
|
||||
# Process inputs in batches
|
||||
all_output_classes: list[int] = []
|
||||
all_base_output_probabilities: list[float] = []
|
||||
|
||||
for i in range(0, len(text_inputs), _BATCH_SIZE):
|
||||
batch = text_inputs[i : i + _BATCH_SIZE]
|
||||
batch_with_prefix = []
|
||||
batch_indices = []
|
||||
|
||||
# Pre-allocate results for this batch
|
||||
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
|
||||
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
|
||||
|
||||
# Pre-process batch to handle long input exceptions
|
||||
for j, text in enumerate(batch):
|
||||
if len(text) == 0:
|
||||
# if no input, treat as non-informative from the model's perspective
|
||||
batch_output_classes[j] = np.array(0)
|
||||
batch_probabilities[j] = np.array(0.0)
|
||||
logger.warning("Input for Content Information Model is empty")
|
||||
|
||||
elif (
|
||||
len(text.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
# if input is short, use the model
|
||||
batch_with_prefix.append(
|
||||
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
|
||||
)
|
||||
batch_indices.append(j)
|
||||
else:
|
||||
# if longer than cutoff, treat as informative (stay with default), but issue warning
|
||||
logger.warning("Input for Content Information Model too long")
|
||||
|
||||
if batch_with_prefix: # Only run model if we have valid inputs
|
||||
# Get predictions for the batch
|
||||
model_output_classes = content_model(batch_with_prefix)
|
||||
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
|
||||
|
||||
# Place results in the correct positions
|
||||
for idx, batch_idx in enumerate(batch_indices):
|
||||
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
|
||||
batch_probabilities[batch_idx] = model_output_probabilities[idx][
|
||||
1
|
||||
].numpy() # x[1] is prob of the positive class
|
||||
|
||||
all_output_classes.extend([int(x) for x in batch_output_classes])
|
||||
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
|
||||
|
||||
logits = [
|
||||
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
|
||||
for p in all_base_output_probabilities
|
||||
]
|
||||
scaled_logits = [
|
||||
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
|
||||
for logit in logits
|
||||
]
|
||||
output_probabilities_with_temp = [
|
||||
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
|
||||
for scaled_logit in scaled_logits
|
||||
]
|
||||
|
||||
prediction_scores = [
|
||||
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
|
||||
]
|
||||
|
||||
content_classification_predictions = [
|
||||
ContentClassificationPrediction(
|
||||
predicted_label=predicted_label, content_boost_factor=output_score
|
||||
)
|
||||
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
|
||||
]
|
||||
|
||||
return content_classification_predictions
|
||||
|
||||
|
||||
def map_keywords(
|
||||
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
|
||||
) -> list[str]:
|
||||
@@ -538,10 +362,3 @@ async def process_analysis_request(
|
||||
|
||||
is_keyword, keywords = run_analysis(intent_request)
|
||||
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
|
||||
|
||||
|
||||
@router.post("/content-classification")
|
||||
async def process_content_classification_request(
|
||||
content_classification_requests: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
return run_content_classification_inference(content_classification_requests)
|
||||
|
||||
@@ -13,7 +13,6 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
|
||||
from transformers import logging as transformer_logging # type:ignore
|
||||
|
||||
from model_server.custom_models import router as custom_models_router
|
||||
from model_server.custom_models import warm_up_information_content_model
|
||||
from model_server.custom_models import warm_up_intent_model
|
||||
from model_server.encoders import router as encoders_router
|
||||
from model_server.management_endpoints import router as management_router
|
||||
@@ -75,15 +74,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
|
||||
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
|
||||
|
||||
if not INDEXING_ONLY:
|
||||
logger.notice(
|
||||
"The intent model should run on the model server. The information content model should not run here."
|
||||
)
|
||||
warm_up_intent_model()
|
||||
else:
|
||||
logger.notice(
|
||||
"The content information model should run on the indexing model server. The intent model should not run here."
|
||||
)
|
||||
warm_up_information_content_model()
|
||||
logger.notice("This model server should only run document indexing.")
|
||||
|
||||
yield
|
||||
|
||||
|
||||
@@ -112,6 +112,5 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -92,6 +92,5 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# tasks that only run self hosted
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
"""
|
||||
Periodic tasks for tenant pre-provisioning.
|
||||
"""
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_available_tenants(self: Task) -> None:
|
||||
"""
|
||||
Check if we have enough pre-provisioned tenants available.
|
||||
If not, trigger the pre-provisioning of new tenants.
|
||||
"""
|
||||
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
|
||||
if not MULTI_TENANT:
|
||||
task_logger.info(
|
||||
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
|
||||
)
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
if not lock_check.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
"Skipping check_available_tenants task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the current count of available tenants
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
available_tenants_count = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
target_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
tenants_to_provision = max(
|
||||
0, target_available_tenants - available_tenants_count
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Available tenants: {available_tenants_count}, "
|
||||
f"Target: {target_available_tenants}, "
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
# Trigger pre-provisioning tasks for each tenant needed
|
||||
for _ in range(tenants_to_provision):
|
||||
from celery import current_app
|
||||
|
||||
current_app.send_task(
|
||||
OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
lock_check.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def pre_provision_tenant(self: Task) -> None:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
so it's ready to be assigned to a user immediately.
|
||||
"""
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
|
||||
r = get_redis_client()
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
if not lock_provision.acquire(blocking=False):
|
||||
task_logger.debug(
|
||||
"Skipping pre_provision_tenant task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
# Generate a new tenant ID
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
|
||||
|
||||
# Create the schema for the new tenant
|
||||
schema_created = create_schema_if_not_exists(tenant_id)
|
||||
if schema_created:
|
||||
task_logger.debug(f"Created schema for tenant: {tenant_id}")
|
||||
else:
|
||||
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
|
||||
asyncio.run(setup_tenant(tenant_id))
|
||||
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
|
||||
|
||||
# Get the current Alembic version
|
||||
alembic_version = get_current_alembic_version(tenant_id)
|
||||
task_logger.debug(
|
||||
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
|
||||
)
|
||||
|
||||
# Store the pre-provisioned tenant in the database
|
||||
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Use a transaction to ensure atomicity
|
||||
db_session.begin()
|
||||
try:
|
||||
new_tenant = AvailableTenant(
|
||||
tenant_id=tenant_id,
|
||||
alembic_version=alembic_version,
|
||||
date_created=datetime.datetime.now(),
|
||||
)
|
||||
db_session.add(new_tenant)
|
||||
db_session.commit()
|
||||
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
task_logger.error(
|
||||
f"Failed to store pre-provisioned tenant: {tenant_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
|
||||
# If we have a tenant_id, attempt to rollback any partially completed provisioning
|
||||
if tenant_id:
|
||||
task_logger.info(
|
||||
f"Rolling back failed tenant provisioning for: {tenant_id}"
|
||||
)
|
||||
try:
|
||||
from ee.onyx.server.tenants.provisioning import (
|
||||
rollback_tenant_provisioning,
|
||||
)
|
||||
|
||||
asyncio.run(rollback_tenant_provisioning(tenant_id))
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
lock_provision.release()
|
||||
@@ -563,7 +563,6 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
|
||||
access=doc_access,
|
||||
boost=doc.boost,
|
||||
hidden=doc.hidden,
|
||||
# aggregated_boost_factor=doc.aggregated_boost_factor,
|
||||
)
|
||||
|
||||
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.
|
||||
|
||||
@@ -28,7 +28,6 @@ from onyx.connectors.models import ConnectorCheckpoint
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
|
||||
from onyx.db.connector_credential_pair import update_connector_credential_pair
|
||||
@@ -53,9 +52,6 @@ from onyx.httpx.httpx_pool import HttpxPool
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.logger import TaskAttemptSingleton
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
@@ -158,12 +154,14 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
|
||||
)
|
||||
|
||||
for section in cleaned_doc.sections:
|
||||
if section.link is not None:
|
||||
if section.link and "\x00" in section.link:
|
||||
logger.warning(
|
||||
f"NUL characters found in document link for document: {cleaned_doc.id}"
|
||||
)
|
||||
section.link = section.link.replace("\x00", "")
|
||||
|
||||
# since text can be longer, just replace to avoid double scan
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
section.text = section.text.replace("\x00", "")
|
||||
section.text = section.text.replace("\x00", "")
|
||||
|
||||
cleaned_batch.append(cleaned_doc)
|
||||
|
||||
@@ -351,8 +349,6 @@ def _run_indexing(
|
||||
callback=callback,
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
document_index = get_default_document_index(
|
||||
index_attempt_start.search_settings,
|
||||
None,
|
||||
@@ -361,7 +357,6 @@ def _run_indexing(
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=(
|
||||
ctx.from_beginning
|
||||
@@ -484,11 +479,7 @@ def _run_indexing(
|
||||
|
||||
doc_size = 0
|
||||
for section in doc.sections:
|
||||
if (
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
):
|
||||
doc_size += len(section.text)
|
||||
doc_size += len(section.text)
|
||||
|
||||
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
|
||||
logger.warning(
|
||||
|
||||
@@ -8,9 +8,6 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
@@ -646,24 +643,3 @@ MOCK_LLM_RESPONSE = (
|
||||
|
||||
|
||||
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
||||
|
||||
# Number of pre-provisioned tenants to maintain
|
||||
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
||||
|
||||
|
||||
# Image summarization configuration
|
||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
# The user prompt for image summarization - the image filename will be automatically prepended
|
||||
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
|
||||
"IMAGE_SUMMARIZATION_USER_PROMPT",
|
||||
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
|
||||
)
|
||||
|
||||
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
@@ -322,8 +322,6 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
@@ -386,7 +384,6 @@ class OnyxCeleryTask:
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -403,9 +400,6 @@ class OnyxCeleryTask:
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = "pre_provision_tenant"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
||||
@@ -132,10 +132,3 @@ if _LITELLM_EXTRA_BODY_RAW:
|
||||
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Whether and how to lower scores for short chunks w/o relevant context
|
||||
# Evaluated via custom ML model
|
||||
|
||||
USE_INFORMATION_CONTENT_CLASSIFICATION = (
|
||||
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from concurrent.futures import Future
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from io import BytesIO
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pyairtable import Api as AirtableApi
|
||||
@@ -17,8 +16,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -269,7 +267,7 @@ class AirtableConnector(LoadConnector):
|
||||
table_id: str,
|
||||
view_id: str | None,
|
||||
record_id: str,
|
||||
) -> tuple[list[TextSection], dict[str, str | list[str]]]:
|
||||
) -> tuple[list[Section], dict[str, str | list[str]]]:
|
||||
"""
|
||||
Process a single Airtable field and return sections or metadata.
|
||||
|
||||
@@ -307,7 +305,7 @@ class AirtableConnector(LoadConnector):
|
||||
|
||||
# Otherwise, create relevant sections
|
||||
sections = [
|
||||
TextSection(
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"{field_name}:\n"
|
||||
@@ -342,7 +340,7 @@ class AirtableConnector(LoadConnector):
|
||||
table_name = table_schema.name
|
||||
record_id = record["id"]
|
||||
fields = record["fields"]
|
||||
sections: list[TextSection] = []
|
||||
sections: list[Section] = []
|
||||
metadata: dict[str, str | list[str]] = {}
|
||||
|
||||
# Get primary field value if it exists
|
||||
@@ -386,7 +384,7 @@ class AirtableConnector(LoadConnector):
|
||||
|
||||
return Document(
|
||||
id=f"airtable__{record_id}",
|
||||
sections=(cast(list[TextSection | ImageSection], sections)),
|
||||
sections=sections,
|
||||
source=DocumentSource.AIRTABLE,
|
||||
semantic_identifier=semantic_id,
|
||||
metadata=metadata,
|
||||
|
||||
@@ -10,7 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -82,7 +82,7 @@ class AsanaConnector(LoadConnector, PollConnector):
|
||||
logger.debug(f"Converting Asana task {task.id} to Document")
|
||||
return Document(
|
||||
id=task.id,
|
||||
sections=[TextSection(link=task.link, text=task.text)],
|
||||
sections=[Section(link=task.link, text=task.text)],
|
||||
doc_updated_at=task.last_modified,
|
||||
source=DocumentSource.ASANA,
|
||||
semantic_identifier=task.title,
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -221,7 +221,7 @@ def _get_forums(
|
||||
def _translate_forum_to_doc(af: AxeroForum) -> Document:
|
||||
doc = Document(
|
||||
id=af.doc_id,
|
||||
sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
|
||||
sections=[Section(link=af.link, text=reply) for reply in af.responses],
|
||||
source=DocumentSource.AXERO,
|
||||
semantic_identifier=af.title,
|
||||
doc_updated_at=af.last_update,
|
||||
@@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document:
|
||||
|
||||
doc = Document(
|
||||
id="AXERO_" + str(content["ContentID"]),
|
||||
sections=[TextSection(link=content["ContentURL"], text=page_text)],
|
||||
sections=[Section(link=content["ContentURL"], text=page_text)],
|
||||
source=DocumentSource.AXERO,
|
||||
semantic_identifier=content["ContentTitle"],
|
||||
doc_updated_at=time_str_to_utc(content["DateUpdated"]),
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -208,7 +208,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
|
||||
sections=[TextSection(link=link, text=text)],
|
||||
sections=[Section(link=link, text=text)],
|
||||
source=DocumentSource(self.bucket_type.value),
|
||||
semantic_identifier=name,
|
||||
doc_updated_at=last_modified,
|
||||
@@ -341,14 +341,7 @@ if __name__ == "__main__":
|
||||
print("Sections:")
|
||||
for section in doc.sections:
|
||||
print(f" - Link: {section.link}")
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
print(f" - Text: {section.text[:100]}...")
|
||||
elif (
|
||||
hasattr(section, "image_file_name") and section.image_file_name
|
||||
):
|
||||
print(f" - Image: {section.image_file_name}")
|
||||
else:
|
||||
print("Error: Unknown section type")
|
||||
print(f" - Text: {section.text[:100]}...")
|
||||
print("---")
|
||||
break
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
|
||||
|
||||
@@ -81,7 +81,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return Document(
|
||||
id="book__" + str(book.get("id")),
|
||||
sections=[TextSection(link=url, text=text)],
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Book: " + title,
|
||||
title=title,
|
||||
@@ -110,7 +110,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return Document(
|
||||
id="chapter__" + str(chapter.get("id")),
|
||||
sections=[TextSection(link=url, text=text)],
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Chapter: " + title,
|
||||
title=title,
|
||||
@@ -134,7 +134,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
return Document(
|
||||
id="shelf:" + str(shelf.get("id")),
|
||||
sections=[TextSection(link=url, text=text)],
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Shelf: " + title,
|
||||
title=title,
|
||||
@@ -167,7 +167,7 @@ class BookstackConnector(LoadConnector, PollConnector):
|
||||
time.sleep(0.1)
|
||||
return Document(
|
||||
id="page:" + page_id,
|
||||
sections=[TextSection(link=url, text=text)],
|
||||
sections=[Section(link=url, text=text)],
|
||||
source=DocumentSource.BOOKSTACK,
|
||||
semantic_identifier="Page: " + str(title),
|
||||
title=str(title),
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
|
||||
@@ -62,11 +62,11 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
|
||||
return response.json()
|
||||
|
||||
def _get_task_comments(self, task_id: str) -> list[TextSection]:
|
||||
def _get_task_comments(self, task_id: str) -> list[Section]:
|
||||
url_endpoint = f"/task/{task_id}/comment"
|
||||
response = self._make_request(url_endpoint)
|
||||
comments = [
|
||||
TextSection(
|
||||
Section(
|
||||
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
|
||||
text=comment_dict["comment_text"],
|
||||
)
|
||||
@@ -133,7 +133,7 @@ class ClickupConnector(LoadConnector, PollConnector):
|
||||
],
|
||||
title=task["name"],
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=task["url"],
|
||||
text=(
|
||||
task["markdown_description"]
|
||||
|
||||
@@ -33,9 +33,9 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -85,6 +85,7 @@ class ConfluenceConnector(
|
||||
PollConnector,
|
||||
SlimConnector,
|
||||
CredentialsConnector,
|
||||
VisionEnabledConnector,
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -115,6 +116,9 @@ class ConfluenceConnector(
|
||||
self._confluence_client: OnyxConfluence | None = None
|
||||
self._fetched_titles: set[str] = set()
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
# Remove trailing slash from wiki_base if present
|
||||
self.wiki_base = wiki_base.rstrip("/")
|
||||
"""
|
||||
@@ -241,16 +245,12 @@ class ConfluenceConnector(
|
||||
)
|
||||
|
||||
# Create the main section for the page content
|
||||
sections: list[TextSection | ImageSection] = [
|
||||
TextSection(text=page_content, link=page_url)
|
||||
]
|
||||
sections = [Section(text=page_content, link=page_url)]
|
||||
|
||||
# Process comments if available
|
||||
comment_text = self._get_comment_string_for_page_id(page_id)
|
||||
if comment_text:
|
||||
sections.append(
|
||||
TextSection(text=comment_text, link=f"{page_url}#comments")
|
||||
)
|
||||
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
|
||||
|
||||
# Process attachments
|
||||
if "children" in page and "attachment" in page["children"]:
|
||||
@@ -263,27 +263,21 @@ class ConfluenceConnector(
|
||||
result = process_attachment(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
page_title,
|
||||
self.image_analysis_llm,
|
||||
)
|
||||
|
||||
if result and result.text:
|
||||
if result.text:
|
||||
# Create a section for the attachment text
|
||||
attachment_section = TextSection(
|
||||
attachment_section = Section(
|
||||
text=result.text,
|
||||
link=f"{page_url}#attachment-{attachment['id']}",
|
||||
)
|
||||
sections.append(attachment_section)
|
||||
elif result and result.file_name:
|
||||
# Create an ImageSection for image attachments
|
||||
image_section = ImageSection(
|
||||
link=f"{page_url}#attachment-{attachment['id']}",
|
||||
image_file_name=result.file_name,
|
||||
)
|
||||
sections.append(image_section)
|
||||
else:
|
||||
sections.append(attachment_section)
|
||||
elif result.error:
|
||||
logger.warning(
|
||||
f"Error processing attachment '{attachment.get('title')}':",
|
||||
f"{result.error if result else 'Unknown error'}",
|
||||
f"Error processing attachment '{attachment.get('title')}': {result.error}"
|
||||
)
|
||||
|
||||
# Extract metadata
|
||||
@@ -354,7 +348,7 @@ class ConfluenceConnector(
|
||||
# Now get attachments for that page:
|
||||
attachment_query = self._construct_attachment_query(page["id"])
|
||||
# We'll use the page's XML to provide context if we summarize an image
|
||||
page.get("body", {}).get("storage", {}).get("value", "")
|
||||
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
|
||||
|
||||
for attachment in self.confluence_client.paginated_cql_retrieval(
|
||||
cql=attachment_query,
|
||||
@@ -362,7 +356,7 @@ class ConfluenceConnector(
|
||||
):
|
||||
attachment["metadata"].get("mediaType", "")
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
attachment, self.image_analysis_llm
|
||||
):
|
||||
continue
|
||||
|
||||
@@ -372,26 +366,23 @@ class ConfluenceConnector(
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
page_context=confluence_xml,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
if response is None:
|
||||
continue
|
||||
|
||||
content_text, file_storage_name = response
|
||||
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
if content_text:
|
||||
doc.sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
text=content_text,
|
||||
link=object_url,
|
||||
)
|
||||
)
|
||||
elif file_storage_name:
|
||||
doc.sections.append(
|
||||
ImageSection(
|
||||
link=object_url,
|
||||
image_file_name=file_storage_name,
|
||||
)
|
||||
)
|
||||
@@ -471,7 +462,7 @@ class ConfluenceConnector(
|
||||
# If you skip images, you'll skip them in the permission sync
|
||||
attachment["metadata"].get("mediaType", "")
|
||||
if not validate_attachment_filetype(
|
||||
attachment,
|
||||
attachment, self.image_analysis_llm
|
||||
):
|
||||
continue
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -18,11 +19,17 @@ from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
|
||||
@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -36,6 +35,7 @@ from onyx.db.pg_file_store import upsert_pgfilestore
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -53,16 +53,17 @@ class TokenResponse(BaseModel):
|
||||
|
||||
|
||||
def validate_attachment_filetype(
|
||||
attachment: dict[str, Any],
|
||||
attachment: dict[str, Any], llm: LLM | None = None
|
||||
) -> bool:
|
||||
"""
|
||||
Validates if the attachment is a supported file type.
|
||||
If LLM is provided, also checks if it's an image that can be processed.
|
||||
"""
|
||||
attachment.get("metadata", {})
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
if media_type.startswith("image/"):
|
||||
return is_valid_image_type(media_type)
|
||||
return llm is not None and is_valid_image_type(media_type)
|
||||
|
||||
# For non-image files, check if we support the extension
|
||||
title = attachment.get("title", "")
|
||||
@@ -83,103 +84,55 @@ class AttachmentProcessingResult(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def _make_attachment_link(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
download_link = ""
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
def _download_attachment(
|
||||
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
|
||||
"""
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
resp = confluence_client._session.get(download_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with status code {resp.status_code}"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
return download_link
|
||||
return None
|
||||
return resp.content
|
||||
|
||||
|
||||
def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""
|
||||
Processes a Confluence attachment. If it's a document, extracts text,
|
||||
or if it's an image, stores it for later analysis. Returns a structured result.
|
||||
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
|
||||
"""
|
||||
try:
|
||||
# Get the media type from the attachment metadata
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
|
||||
# Validate the attachment type
|
||||
if not validate_attachment_filetype(attachment):
|
||||
if not validate_attachment_filetype(attachment, llm):
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Unsupported file type: {media_type}",
|
||||
)
|
||||
|
||||
attachment_link = _make_attachment_link(
|
||||
confluence_client, attachment, parent_content_id
|
||||
)
|
||||
if not attachment_link:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="Failed to make attachment link"
|
||||
)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
|
||||
if not media_type.startswith("image/"):
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {attachment_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment text too long: {attachment_size} chars",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Downloading attachment: "
|
||||
f"title={attachment['title']} "
|
||||
f"length={attachment_size} "
|
||||
f"link={attachment_link}"
|
||||
)
|
||||
|
||||
# Download the attachment
|
||||
resp: requests.Response = confluence_client._session.get(attachment_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
|
||||
)
|
||||
raw_bytes = _download_attachment(confluence_client, attachment)
|
||||
if raw_bytes is None:
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment download status code is {resp.status_code}",
|
||||
text=None, file_name=None, error="Failed to download attachment"
|
||||
)
|
||||
|
||||
raw_bytes = resp.content
|
||||
if not raw_bytes:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="attachment.content is None"
|
||||
)
|
||||
|
||||
# Process image attachments
|
||||
if media_type.startswith("image/"):
|
||||
# Process image attachments with LLM if available
|
||||
if media_type.startswith("image/") and llm:
|
||||
return _process_image_attachment(
|
||||
confluence_client, attachment, raw_bytes, media_type
|
||||
confluence_client, attachment, page_context, llm, raw_bytes, media_type
|
||||
)
|
||||
|
||||
# Process document attachments
|
||||
@@ -212,10 +165,12 @@ def process_attachment(
|
||||
def _process_image_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_context: str,
|
||||
llm: LLM,
|
||||
raw_bytes: bytes,
|
||||
media_type: str,
|
||||
) -> AttachmentProcessingResult:
|
||||
"""Process an image attachment by saving it without generating a summary."""
|
||||
"""Process an image attachment by saving it and generating a summary."""
|
||||
try:
|
||||
# Use the standardized image storage and section creation
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
@@ -225,14 +180,15 @@ def _process_image_attachment(
|
||||
file_name=Path(attachment["id"]).name,
|
||||
display_name=attachment["title"],
|
||||
media_type=media_type,
|
||||
llm=llm,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
logger.info(f"Stored image attachment with file name: {file_name}")
|
||||
|
||||
# Return empty text but include the file_name for later processing
|
||||
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
|
||||
return AttachmentProcessingResult(
|
||||
text=section.text, file_name=file_name, error=None
|
||||
)
|
||||
except Exception as e:
|
||||
msg = f"Image storage failed for {attachment['title']}: {e}"
|
||||
msg = f"Image summarization failed for {attachment['title']}: {e}"
|
||||
logger.error(msg, exc_info=e)
|
||||
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
|
||||
|
||||
@@ -293,15 +249,16 @@ def _process_text_attachment(
|
||||
def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_id: str,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
"""
|
||||
Facade function which:
|
||||
1. Validates attachment type
|
||||
2. Extracts content or stores image for later processing
|
||||
2. Extracts or summarizes content
|
||||
3. Returns (content_text, stored_file_name) or None if we should skip it
|
||||
"""
|
||||
media_type = attachment.get("metadata", {}).get("mediaType", "")
|
||||
media_type = attachment["metadata"]["mediaType"]
|
||||
# Quick check for unsupported types:
|
||||
if media_type.startswith("video/") or media_type == "application/gliffy+json":
|
||||
logger.warning(
|
||||
@@ -309,7 +266,7 @@ def convert_attachment_to_content(
|
||||
)
|
||||
return None
|
||||
|
||||
result = process_attachment(confluence_client, attachment, page_id)
|
||||
result = process_attachment(confluence_client, attachment, page_context, llm)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
@@ -522,10 +479,6 @@ def attachment_to_file_record(
|
||||
download_link, absolute=True, not_json_response=True
|
||||
)
|
||||
|
||||
file_type = attachment.get("metadata", {}).get(
|
||||
"mediaType", "application/octet-stream"
|
||||
)
|
||||
|
||||
# Save image to file store
|
||||
file_name = f"confluence_attachment_{attachment['id']}"
|
||||
lobj_oid = create_populate_lobj(BytesIO(image_data), db_session)
|
||||
@@ -533,7 +486,7 @@ def attachment_to_file_record(
|
||||
file_name=file_name,
|
||||
display_name=attachment["title"],
|
||||
file_origin=FileOrigin.OTHER,
|
||||
file_type=file_type,
|
||||
file_type=attachment["metadata"]["mediaType"],
|
||||
lobj_oid=lobj_oid,
|
||||
db_session=db_session,
|
||||
commit=True,
|
||||
|
||||
@@ -4,7 +4,6 @@ from collections.abc import Iterable
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
from discord import Client
|
||||
from discord.channel import TextChannel
|
||||
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -34,7 +32,7 @@ _SNIPPET_LENGTH = 30
|
||||
|
||||
def _convert_message_to_document(
|
||||
message: DiscordMessage,
|
||||
sections: list[TextSection],
|
||||
sections: list[Section],
|
||||
) -> Document:
|
||||
"""
|
||||
Convert a discord message to a document
|
||||
@@ -80,7 +78,7 @@ def _convert_message_to_document(
|
||||
semantic_identifier=semantic_identifier,
|
||||
doc_updated_at=message.edited_at,
|
||||
title=title,
|
||||
sections=(cast(list[TextSection | ImageSection], sections)),
|
||||
sections=sections,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
@@ -125,8 +123,8 @@ async def _fetch_documents_from_channel(
|
||||
if channel_message.type != MessageType.default:
|
||||
continue
|
||||
|
||||
sections: list[TextSection] = [
|
||||
TextSection(
|
||||
sections: list[Section] = [
|
||||
Section(
|
||||
text=channel_message.content,
|
||||
link=channel_message.jump_url,
|
||||
)
|
||||
@@ -144,7 +142,7 @@ async def _fetch_documents_from_channel(
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
@@ -162,7 +160,7 @@ async def _fetch_documents_from_channel(
|
||||
continue
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
Section(
|
||||
text=thread_message.content,
|
||||
link=thread_message.jump_url,
|
||||
)
|
||||
|
||||
@@ -3,7 +3,6 @@ import urllib.parse
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel
|
||||
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -114,7 +112,7 @@ class DiscourseConnector(PollConnector):
|
||||
responders.append(BasicExpertInfo(display_name=responder_name))
|
||||
|
||||
sections.append(
|
||||
TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"]))
|
||||
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
|
||||
)
|
||||
category_name = self.category_id_map.get(topic["category_id"])
|
||||
|
||||
@@ -131,7 +129,7 @@ class DiscourseConnector(PollConnector):
|
||||
|
||||
doc = Document(
|
||||
id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]),
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=sections,
|
||||
source=DocumentSource.DISCOURSE,
|
||||
semantic_identifier=topic["title"],
|
||||
doc_updated_at=time_str_to_utc(topic["last_posted_at"]),
|
||||
|
||||
@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -158,7 +158,7 @@ class Document360Connector(LoadConnector, PollConnector):
|
||||
|
||||
document = Document(
|
||||
id=article_details["id"],
|
||||
sections=[TextSection(link=doc_link, text=doc_text)],
|
||||
sections=[Section(link=doc_link, text=doc_text)],
|
||||
source=DocumentSource.DOCUMENT360,
|
||||
semantic_identifier=article_details["title"],
|
||||
doc_updated_at=updated_at,
|
||||
|
||||
@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -108,7 +108,7 @@ class DropboxConnector(LoadConnector, PollConnector):
|
||||
batch.append(
|
||||
Document(
|
||||
id=f"doc:{entry.id}",
|
||||
sections=[TextSection(link=link, text=text)],
|
||||
sections=[Section(link=link, text=text)],
|
||||
source=DocumentSource.DROPBOX,
|
||||
semantic_identifier=entry.name,
|
||||
doc_updated_at=modified_time,
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import detect_encoding
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import get_file_ext
|
||||
@@ -111,7 +111,7 @@ def _process_egnyte_file(
|
||||
# Create the document
|
||||
return Document(
|
||||
id=f"egnyte-{file_metadata['entry_id']}",
|
||||
sections=[TextSection(text=file_content_raw.strip(), link=web_url)],
|
||||
sections=[Section(text=file_content_raw.strip(), link=web_url)],
|
||||
source=DocumentSource.EGNYTE,
|
||||
semantic_identifier=file_name,
|
||||
metadata=metadata,
|
||||
|
||||
@@ -16,8 +16,8 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.file_processing.extract_file_text import extract_text_and_images
|
||||
@@ -26,6 +26,7 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -58,44 +59,32 @@ def _read_files_and_metadata(
|
||||
|
||||
|
||||
def _create_image_section(
|
||||
llm: LLM | None,
|
||||
image_data: bytes,
|
||||
db_session: Session,
|
||||
parent_file_name: str,
|
||||
display_name: str,
|
||||
link: str | None = None,
|
||||
idx: int = 0,
|
||||
) -> tuple[ImageSection, str | None]:
|
||||
) -> tuple[Section, str | None]:
|
||||
"""
|
||||
Creates an ImageSection for an image file or embedded image.
|
||||
Stores the image in PGFileStore but does not generate a summary.
|
||||
|
||||
Args:
|
||||
image_data: Raw image bytes
|
||||
db_session: Database session
|
||||
parent_file_name: Name of the parent file (for embedded images)
|
||||
display_name: Display name for the image
|
||||
idx: Index for embedded images
|
||||
Create a Section object for a single image and store the image in PGFileStore.
|
||||
If summarization is enabled and we have an LLM, summarize the image.
|
||||
|
||||
Returns:
|
||||
Tuple of (ImageSection, stored_file_name or None)
|
||||
tuple: (Section object, file_name in PGFileStore or None if storage failed)
|
||||
"""
|
||||
# Create a unique identifier for the image
|
||||
file_name = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name
|
||||
# Create a unique file name for the embedded image
|
||||
file_name = f"{parent_file_name}_embedded_{idx}"
|
||||
|
||||
# Store the image and create a section
|
||||
try:
|
||||
section, stored_file_name = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=image_data,
|
||||
file_name=file_name,
|
||||
display_name=display_name,
|
||||
link=link,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
return section, stored_file_name
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store image {display_name}: {e}")
|
||||
raise e
|
||||
# Use the standardized utility to store the image and create a section
|
||||
return store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=image_data,
|
||||
file_name=file_name,
|
||||
display_name=display_name,
|
||||
llm=llm,
|
||||
file_origin=FileOrigin.OTHER,
|
||||
)
|
||||
|
||||
|
||||
def _process_file(
|
||||
@@ -104,16 +93,12 @@ def _process_file(
|
||||
metadata: dict[str, Any] | None,
|
||||
pdf_pass: str | None,
|
||||
db_session: Session,
|
||||
llm: LLM | None,
|
||||
) -> list[Document]:
|
||||
"""
|
||||
Process a file and return a list of Documents.
|
||||
For images, creates ImageSection objects without summarization.
|
||||
For documents with embedded images, extracts and stores the images.
|
||||
Processes a single file, returning a list of Documents (typically one).
|
||||
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
|
||||
"""
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
|
||||
# Get file extension and determine file type
|
||||
extension = get_file_ext(file_name)
|
||||
|
||||
# Fetch the DB record so we know the ID for internal URL
|
||||
@@ -129,6 +114,8 @@ def _process_file(
|
||||
return []
|
||||
|
||||
# Prepare doc metadata
|
||||
if metadata is None:
|
||||
metadata = {}
|
||||
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
|
||||
|
||||
# Timestamps
|
||||
@@ -171,7 +158,6 @@ def _process_file(
|
||||
"title",
|
||||
"connector_type",
|
||||
"pdf_password",
|
||||
"mime_type",
|
||||
]
|
||||
}
|
||||
|
||||
@@ -184,45 +170,33 @@ def _process_file(
|
||||
title = metadata.get("title") or file_display_name
|
||||
|
||||
# 1) If the file itself is an image, handle that scenario quickly
|
||||
if extension in LoadConnector.IMAGE_EXTENSIONS:
|
||||
# Read the image data
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
|
||||
if extension in IMAGE_EXTENSIONS:
|
||||
# Summarize or produce empty doc
|
||||
image_data = file.read()
|
||||
if not image_data:
|
||||
logger.warning(f"Empty image file: {file_name}")
|
||||
return []
|
||||
|
||||
# Create an ImageSection for the image
|
||||
try:
|
||||
section, _ = _create_image_section(
|
||||
image_data=image_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=pg_record.file_name,
|
||||
display_name=title,
|
||||
image_section, _ = _create_image_section(
|
||||
llm, image_data, db_session, pg_record.file_name, title
|
||||
)
|
||||
return [
|
||||
Document(
|
||||
id=doc_id,
|
||||
sections=[image_section],
|
||||
source=source_type,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
metadata=metadata_tags,
|
||||
)
|
||||
]
|
||||
|
||||
return [
|
||||
Document(
|
||||
id=doc_id,
|
||||
sections=[section],
|
||||
source=source_type,
|
||||
semantic_identifier=file_display_name,
|
||||
title=title,
|
||||
doc_updated_at=final_time_updated,
|
||||
primary_owners=p_owners,
|
||||
secondary_owners=s_owners,
|
||||
metadata=metadata_tags,
|
||||
)
|
||||
]
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image file {file_name}: {e}")
|
||||
return []
|
||||
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images.
|
||||
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
|
||||
# (For example .docx with inline images).
|
||||
file.seek(0)
|
||||
text_content = ""
|
||||
embedded_images: list[tuple[bytes, str]] = []
|
||||
|
||||
# Extract text and images from the file
|
||||
text_content, embedded_images = extract_text_and_images(
|
||||
file=file,
|
||||
file_name=file_name,
|
||||
@@ -230,29 +204,24 @@ def _process_file(
|
||||
)
|
||||
|
||||
# Build sections: first the text as a single Section
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
sections = []
|
||||
link_in_meta = metadata.get("link")
|
||||
if text_content.strip():
|
||||
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
|
||||
sections.append(Section(link=link_in_meta, text=text_content.strip()))
|
||||
|
||||
# Then any extracted images from docx, etc.
|
||||
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
|
||||
# Store each embedded image as a separate file in PGFileStore
|
||||
# and create a section with the image reference
|
||||
try:
|
||||
image_section, _ = _create_image_section(
|
||||
image_data=img_data,
|
||||
db_session=db_session,
|
||||
parent_file_name=pg_record.file_name,
|
||||
display_name=f"{title} - image {idx}",
|
||||
idx=idx,
|
||||
)
|
||||
sections.append(image_section)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to process embedded image {idx} in {file_name}: {e}"
|
||||
)
|
||||
|
||||
# and create a section with the image summary
|
||||
image_section, _ = _create_image_section(
|
||||
llm,
|
||||
img_data,
|
||||
db_session,
|
||||
pg_record.file_name,
|
||||
f"{title} - image {idx}",
|
||||
idx,
|
||||
)
|
||||
sections.append(image_section)
|
||||
return [
|
||||
Document(
|
||||
id=doc_id,
|
||||
@@ -268,10 +237,10 @@ def _process_file(
|
||||
]
|
||||
|
||||
|
||||
class LocalFileConnector(LoadConnector):
|
||||
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
|
||||
"""
|
||||
Connector that reads files from Postgres and yields Documents, including
|
||||
embedded image extraction without summarization.
|
||||
optional embedded image extraction.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -283,6 +252,9 @@ class LocalFileConnector(LoadConnector):
|
||||
self.batch_size = batch_size
|
||||
self.pdf_pass: str | None = None
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
self.pdf_pass = credentials.get("pdf_password")
|
||||
|
||||
@@ -314,6 +286,7 @@ class LocalFileConnector(LoadConnector):
|
||||
metadata=metadata,
|
||||
pdf_pass=self.pdf_pass,
|
||||
db_session=db_session,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
documents.extend(new_docs)
|
||||
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Iterator
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
from typing import List
|
||||
|
||||
import requests
|
||||
@@ -15,8 +14,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -47,7 +45,7 @@ _FIREFLIES_API_QUERY = """
|
||||
|
||||
|
||||
def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
sections: List[TextSection] = []
|
||||
sections: List[Section] = []
|
||||
current_speaker_name = None
|
||||
current_link = ""
|
||||
current_text = ""
|
||||
@@ -59,7 +57,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
if sentence["speaker_name"] != current_speaker_name:
|
||||
if current_speaker_name is not None:
|
||||
sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
link=current_link,
|
||||
text=current_text.strip(),
|
||||
)
|
||||
@@ -73,7 +71,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
|
||||
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
|
||||
sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
link=current_link,
|
||||
text=current_text.strip(),
|
||||
)
|
||||
@@ -96,7 +94,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
|
||||
|
||||
return Document(
|
||||
id=fireflies_id,
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=sections,
|
||||
source=DocumentSource.FIREFLIES,
|
||||
semantic_identifier=meeting_title,
|
||||
metadata={},
|
||||
|
||||
@@ -14,7 +14,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -133,7 +133,7 @@ def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
|
||||
return Document(
|
||||
id=_FRESHDESK_ID_PREFIX + link,
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=link,
|
||||
text=text,
|
||||
)
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -183,7 +183,7 @@ def _convert_page_to_document(
|
||||
return Document(
|
||||
id=f"gitbook-{space_id}-{page_id}",
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=page.get("urls", {}).get("app", ""),
|
||||
text=_extract_text_from_document(page_content),
|
||||
)
|
||||
|
||||
@@ -27,7 +27,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -87,9 +87,7 @@ def _batch_github_objects(
|
||||
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
|
||||
return Document(
|
||||
id=pull_request.html_url,
|
||||
sections=[
|
||||
TextSection(link=pull_request.html_url, text=pull_request.body or "")
|
||||
],
|
||||
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier=pull_request.title,
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
@@ -111,7 +109,7 @@ def _fetch_issue_comments(issue: Issue) -> str:
|
||||
def _convert_issue_to_document(issue: Issue) -> Document:
|
||||
return Document(
|
||||
id=issue.html_url,
|
||||
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
|
||||
sections=[Section(link=issue.html_url, text=issue.body or "")],
|
||||
source=DocumentSource.GITHUB,
|
||||
semantic_identifier=issue.title,
|
||||
# updated_at is UTC time but is timezone unaware
|
||||
|
||||
@@ -21,7 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -56,7 +56,7 @@ def get_author(author: Any) -> BasicExpertInfo:
|
||||
def _convert_merge_request_to_document(mr: Any) -> Document:
|
||||
doc = Document(
|
||||
id=mr.web_url,
|
||||
sections=[TextSection(link=mr.web_url, text=mr.description or "")],
|
||||
sections=[Section(link=mr.web_url, text=mr.description or "")],
|
||||
source=DocumentSource.GITLAB,
|
||||
semantic_identifier=mr.title,
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
@@ -72,7 +72,7 @@ def _convert_merge_request_to_document(mr: Any) -> Document:
|
||||
def _convert_issue_to_document(issue: Any) -> Document:
|
||||
doc = Document(
|
||||
id=issue.web_url,
|
||||
sections=[TextSection(link=issue.web_url, text=issue.description or "")],
|
||||
sections=[Section(link=issue.web_url, text=issue.description or "")],
|
||||
source=DocumentSource.GITLAB,
|
||||
semantic_identifier=issue.title,
|
||||
# updated_at is UTC time but is timezone unaware, explicitly add UTC
|
||||
@@ -99,7 +99,7 @@ def _convert_code_to_document(
|
||||
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
|
||||
doc = Document(
|
||||
id=file["id"],
|
||||
sections=[TextSection(link=file_url, text=file_content)],
|
||||
sections=[Section(link=file_url, text=file_content)],
|
||||
source=DocumentSource.GITLAB,
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=datetime.now().replace(
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from base64 import urlsafe_b64decode
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
|
||||
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
|
||||
@@ -29,9 +28,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -117,7 +115,7 @@ def _get_message_body(payload: dict[str, Any]) -> str:
|
||||
return message_body
|
||||
|
||||
|
||||
def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
|
||||
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
|
||||
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
|
||||
|
||||
payload = message.get("payload", {})
|
||||
@@ -144,7 +142,7 @@ def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str,
|
||||
|
||||
message_body_text: str = _get_message_body(payload)
|
||||
|
||||
return TextSection(link=link, text=message_body_text + message_data), metadata
|
||||
return Section(link=link, text=message_body_text + message_data), metadata
|
||||
|
||||
|
||||
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
|
||||
@@ -194,7 +192,7 @@ def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
|
||||
return Document(
|
||||
id=id,
|
||||
semantic_identifier=semantic_identifier,
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=sections,
|
||||
source=DocumentSource.GMAIL,
|
||||
# This is used to perform permission sync
|
||||
primary_owners=primary_owners,
|
||||
|
||||
@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -243,7 +243,7 @@ class GongConnector(LoadConnector, PollConnector):
|
||||
Document(
|
||||
id=call_id,
|
||||
sections=[
|
||||
TextSection(link=call_metadata["url"], text=transcript_text)
|
||||
Section(link=call_metadata["url"], text=transcript_text)
|
||||
],
|
||||
source=DocumentSource.GONG,
|
||||
# Should not ever be Untitled as a call cannot be made without a Title
|
||||
|
||||
@@ -43,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
|
||||
@@ -66,6 +68,7 @@ def _convert_single_file(
|
||||
creds: Any,
|
||||
primary_admin_email: str,
|
||||
file: dict[str, Any],
|
||||
image_analysis_llm: LLM | None,
|
||||
) -> Any:
|
||||
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
|
||||
user_drive_service = get_drive_service(creds, user_email=user_email)
|
||||
@@ -74,6 +77,7 @@ def _convert_single_file(
|
||||
file=file,
|
||||
drive_service=user_drive_service,
|
||||
docs_service=docs_service,
|
||||
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
|
||||
)
|
||||
|
||||
|
||||
@@ -112,7 +116,9 @@ def _clean_requested_drive_ids(
|
||||
return valid_requested_drive_ids, filtered_folder_ids
|
||||
|
||||
|
||||
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
class GoogleDriveConnector(
|
||||
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
|
||||
):
|
||||
def __init__(
|
||||
self,
|
||||
include_shared_drives: bool = False,
|
||||
@@ -145,6 +151,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
if continue_on_failure is not None:
|
||||
logger.warning("The 'continue_on_failure' parameter is deprecated.")
|
||||
|
||||
# Initialize vision LLM using the mixin
|
||||
self.initialize_vision_llm()
|
||||
|
||||
if (
|
||||
not include_shared_drives
|
||||
and not include_my_drives
|
||||
@@ -530,6 +539,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
_convert_single_file,
|
||||
self.creds,
|
||||
self.primary_admin_email,
|
||||
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
|
||||
)
|
||||
|
||||
# Fetch files in batches
|
||||
|
||||
@@ -1,50 +1,40 @@
|
||||
import io
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from datetime import timezone
|
||||
from tempfile import NamedTemporaryFile
|
||||
|
||||
from googleapiclient.http import MediaIoBaseDownload # type: ignore
|
||||
import openpyxl # type: ignore
|
||||
from googleapiclient.discovery import build # type: ignore
|
||||
from googleapiclient.errors import HttpError # type: ignore
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
|
||||
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
|
||||
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
|
||||
from onyx.connectors.google_drive.models import GDriveMimeType
|
||||
from onyx.connectors.google_drive.models import GoogleDriveFileType
|
||||
from onyx.connectors.google_drive.section_extraction import get_document_sections
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.google_utils.resources import GoogleDriveService
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.file_processing.extract_file_text import docx_to_text_and_images
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.extract_file_text import pptx_to_text
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.extract_file_text import xlsx_to_text
|
||||
from onyx.file_processing.file_validation import is_valid_image_type
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.file_processing.image_utils import store_image_and_create_section
|
||||
from onyx.file_processing.unstructured import get_unstructured_api_key
|
||||
from onyx.file_processing.unstructured import unstructured_to_text
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Mapping of Google Drive mime types to export formats
|
||||
GOOGLE_MIME_TYPES_TO_EXPORT = {
|
||||
GDriveMimeType.DOC.value: "text/plain",
|
||||
GDriveMimeType.SPREADSHEET.value: "text/csv",
|
||||
GDriveMimeType.PPT.value: "text/plain",
|
||||
}
|
||||
|
||||
# Define Google MIME types mapping
|
||||
GOOGLE_MIME_TYPES = {
|
||||
GDriveMimeType.DOC.value: "text/plain",
|
||||
GDriveMimeType.SPREADSHEET.value: "text/csv",
|
||||
GDriveMimeType.PPT.value: "text/plain",
|
||||
}
|
||||
|
||||
|
||||
def _summarize_drive_image(
|
||||
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
|
||||
@@ -76,137 +66,259 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
|
||||
def _extract_sections_basic(
|
||||
file: dict[str, str],
|
||||
service: GoogleDriveService,
|
||||
) -> list[TextSection | ImageSection]:
|
||||
"""Extract text and images from a Google Drive file."""
|
||||
file_id = file["id"]
|
||||
file_name = file["name"]
|
||||
image_analysis_llm: LLM | None = None,
|
||||
) -> list[Section]:
|
||||
"""
|
||||
Extends the existing logic to handle either a docx with embedded images
|
||||
or standalone images (PNG, JPG, etc).
|
||||
"""
|
||||
mime_type = file["mimeType"]
|
||||
link = file.get("webViewLink", "")
|
||||
link = file["webViewLink"]
|
||||
file_name = file.get("name", file["id"])
|
||||
supported_file_types = set(item.value for item in GDriveMimeType)
|
||||
|
||||
try:
|
||||
# For Google Docs, Sheets, and Slides, export as plain text
|
||||
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
|
||||
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
|
||||
# Use the correct API call for exporting files
|
||||
request = service.files().export_media(
|
||||
fileId=file_id, mimeType=export_mime_type
|
||||
)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
|
||||
if is_gdrive_image_mime_type(mime_type):
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
|
||||
return []
|
||||
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
# For other file types, download the file
|
||||
# Use the correct API call for downloading files
|
||||
request = service.files().get_media(fileId=file_id)
|
||||
response_bytes = io.BytesIO()
|
||||
downloader = MediaIoBaseDownload(response_bytes, request)
|
||||
done = False
|
||||
while not done:
|
||||
_, done = downloader.next_chunk()
|
||||
|
||||
response = response_bytes.getvalue()
|
||||
if not response:
|
||||
logger.warning(f"Failed to download {file_name}")
|
||||
return []
|
||||
|
||||
# Process based on mime type
|
||||
if mime_type == "text/plain":
|
||||
text = response.decode("utf-8")
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
||||
):
|
||||
text, _ = docx_to_text_and_images(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
|
||||
):
|
||||
text = xlsx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif (
|
||||
mime_type
|
||||
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
|
||||
):
|
||||
text = pptx_to_text(io.BytesIO(response))
|
||||
return [TextSection(link=link, text=text)]
|
||||
|
||||
elif is_gdrive_image_mime_type(mime_type):
|
||||
# For images, store them for later processing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file_id,
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
link=link,
|
||||
)
|
||||
sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process image {file_name}: {e}")
|
||||
return sections
|
||||
|
||||
elif mime_type == "application/pdf":
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
|
||||
pdf_sections: list[TextSection | ImageSection] = [
|
||||
TextSection(link=link, text=text)
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=response,
|
||||
file_name=file["id"],
|
||||
display_name=file_name,
|
||||
media_type=mime_type,
|
||||
llm=image_analysis_llm,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
return [section]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to fetch or summarize image: {e}")
|
||||
return [
|
||||
Section(
|
||||
link=link,
|
||||
text="",
|
||||
image_file_name=link,
|
||||
)
|
||||
]
|
||||
|
||||
# Process embedded images in the PDF
|
||||
if mime_type not in supported_file_types:
|
||||
# Unsupported file types can still have a title, finding this way is still useful
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
try:
|
||||
# ---------------------------
|
||||
# Google Sheets extraction
|
||||
if mime_type == GDriveMimeType.SPREADSHEET.value:
|
||||
try:
|
||||
sheets_service = build(
|
||||
"sheets", "v4", credentials=service._http.credentials
|
||||
)
|
||||
spreadsheet = (
|
||||
sheets_service.spreadsheets()
|
||||
.get(spreadsheetId=file["id"])
|
||||
.execute()
|
||||
)
|
||||
|
||||
sections = []
|
||||
for sheet in spreadsheet["sheets"]:
|
||||
sheet_name = sheet["properties"]["title"]
|
||||
sheet_id = sheet["properties"]["sheetId"]
|
||||
|
||||
# Get sheet dimensions
|
||||
grid_properties = sheet["properties"].get("gridProperties", {})
|
||||
row_count = grid_properties.get("rowCount", 1000)
|
||||
column_count = grid_properties.get("columnCount", 26)
|
||||
|
||||
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
|
||||
end_column = ""
|
||||
while column_count:
|
||||
column_count, remainder = divmod(column_count - 1, 26)
|
||||
end_column = chr(65 + remainder) + end_column
|
||||
|
||||
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
|
||||
|
||||
try:
|
||||
result = (
|
||||
sheets_service.spreadsheets()
|
||||
.values()
|
||||
.get(spreadsheetId=file["id"], range=range_name)
|
||||
.execute()
|
||||
)
|
||||
values = result.get("values", [])
|
||||
|
||||
if values:
|
||||
text = f"Sheet: {sheet_name}\n"
|
||||
for row in values:
|
||||
text += "\t".join(str(cell) for cell in row) + "\n"
|
||||
sections.append(
|
||||
Section(
|
||||
link=f"{link}#gid={sheet_id}",
|
||||
text=text,
|
||||
)
|
||||
)
|
||||
except HttpError as e:
|
||||
logger.warning(
|
||||
f"Error fetching data for sheet '{sheet_name}': {e}"
|
||||
)
|
||||
continue
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
|
||||
" Falling back to basic extraction."
|
||||
)
|
||||
# ---------------------------
|
||||
# Microsoft Excel (.xlsx or .xls) extraction branch
|
||||
elif mime_type in [
|
||||
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
|
||||
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
|
||||
]:
|
||||
try:
|
||||
response = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
|
||||
tmp.write(response)
|
||||
tmp_path = tmp.name
|
||||
|
||||
section_separator = "\n\n"
|
||||
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
|
||||
|
||||
# Work similarly to the xlsx_to_text function used for file connector
|
||||
# but returns Sections instead of a string
|
||||
sections = [
|
||||
Section(
|
||||
link=link,
|
||||
text=(
|
||||
f"Sheet: {sheet.title}\n\n"
|
||||
+ section_separator.join(
|
||||
",".join(map(str, row))
|
||||
for row in sheet.iter_rows(
|
||||
min_row=1, values_only=True
|
||||
)
|
||||
if row
|
||||
)
|
||||
),
|
||||
)
|
||||
for sheet in workbook.worksheets
|
||||
]
|
||||
|
||||
return sections
|
||||
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error extracting data from Excel file '{file['name']}': {e}"
|
||||
)
|
||||
return [
|
||||
Section(link=link, text="Error extracting data from Excel file")
|
||||
]
|
||||
|
||||
# ---------------------------
|
||||
# Export for Google Docs, PPT, and fallback for spreadsheets
|
||||
if mime_type in [
|
||||
GDriveMimeType.DOC.value,
|
||||
GDriveMimeType.PPT.value,
|
||||
GDriveMimeType.SPREADSHEET.value,
|
||||
]:
|
||||
export_mime_type = (
|
||||
"text/plain"
|
||||
if mime_type != GDriveMimeType.SPREADSHEET.value
|
||||
else "text/csv"
|
||||
)
|
||||
text = (
|
||||
service.files()
|
||||
.export(fileId=file["id"], mimeType=export_mime_type)
|
||||
.execute()
|
||||
.decode("utf-8")
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
# ---------------------------
|
||||
# Plain text and Markdown files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.PLAIN_TEXT.value,
|
||||
GDriveMimeType.MARKDOWN.value,
|
||||
]:
|
||||
text_data = (
|
||||
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
|
||||
)
|
||||
return [Section(link=link, text=text_data)]
|
||||
|
||||
# ---------------------------
|
||||
# Word, PowerPoint, PDF files
|
||||
elif mime_type in [
|
||||
GDriveMimeType.WORD_DOC.value,
|
||||
GDriveMimeType.POWERPOINT.value,
|
||||
GDriveMimeType.PDF.value,
|
||||
]:
|
||||
response_bytes = service.files().get_media(fileId=file["id"]).execute()
|
||||
|
||||
# Optionally use Unstructured
|
||||
if get_unstructured_api_key():
|
||||
text = unstructured_to_text(
|
||||
file=io.BytesIO(response_bytes),
|
||||
file_name=file_name,
|
||||
)
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
if mime_type == GDriveMimeType.WORD_DOC.value:
|
||||
# Use docx_to_text_and_images to get text plus embedded images
|
||||
text, embedded_images = docx_to_text_and_images(
|
||||
file=io.BytesIO(response_bytes),
|
||||
)
|
||||
sections = []
|
||||
if text.strip():
|
||||
sections.append(Section(link=link, text=text.strip()))
|
||||
|
||||
# Process each embedded image using the standardized function
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
for idx, (img_data, img_name) in enumerate(images):
|
||||
section, embedded_id = store_image_and_create_section(
|
||||
for idx, (img_data, img_name) in enumerate(
|
||||
embedded_images, start=1
|
||||
):
|
||||
# Create a unique identifier for the embedded image
|
||||
embedded_id = f"{file['id']}_embedded_{idx}"
|
||||
|
||||
section, _ = store_image_and_create_section(
|
||||
db_session=db_session,
|
||||
image_data=img_data,
|
||||
file_name=f"{file_id}_img_{idx}",
|
||||
file_name=embedded_id,
|
||||
display_name=img_name or f"{file_name} - image {idx}",
|
||||
llm=image_analysis_llm,
|
||||
file_origin=FileOrigin.CONNECTOR,
|
||||
)
|
||||
pdf_sections.append(section)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to process PDF images in {file_name}: {e}")
|
||||
return pdf_sections
|
||||
sections.append(section)
|
||||
return sections
|
||||
|
||||
else:
|
||||
# For unsupported file types, try to extract text
|
||||
try:
|
||||
text = extract_file_text(io.BytesIO(response), file_name)
|
||||
return [TextSection(link=link, text=text)]
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to extract text from {file_name}: {e}")
|
||||
return []
|
||||
elif mime_type == GDriveMimeType.PDF.value:
|
||||
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
|
||||
return [Section(link=link, text=text)]
|
||||
|
||||
elif mime_type == GDriveMimeType.POWERPOINT.value:
|
||||
text_data = pptx_to_text(io.BytesIO(response_bytes))
|
||||
return [Section(link=link, text=text_data)]
|
||||
|
||||
# Catch-all case, should not happen since there should be specific handling
|
||||
# for each of the supported file types
|
||||
error_message = f"Unsupported file type: {mime_type}"
|
||||
logger.error(error_message)
|
||||
raise ValueError(error_message)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing file {file_name}: {e}")
|
||||
return []
|
||||
logger.exception(f"Error extracting sections from file: {e}")
|
||||
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
|
||||
|
||||
|
||||
def convert_drive_item_to_document(
|
||||
file: GoogleDriveFileType,
|
||||
drive_service: GoogleDriveService,
|
||||
docs_service: GoogleDocsService,
|
||||
image_analysis_llm: LLM | None,
|
||||
) -> Document | None:
|
||||
"""
|
||||
Main entry point for converting a Google Drive file => Document object.
|
||||
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
|
||||
"""
|
||||
try:
|
||||
# skip shortcuts or folders
|
||||
@@ -215,50 +327,44 @@ def convert_drive_item_to_document(
|
||||
return None
|
||||
|
||||
# If it's a Google Doc, we might do advanced parsing
|
||||
sections: list[TextSection | ImageSection] = []
|
||||
|
||||
# Try to get sections using the advanced method first
|
||||
sections: list[Section] = []
|
||||
if file.get("mimeType") == GDriveMimeType.DOC.value:
|
||||
try:
|
||||
doc_sections = get_document_sections(
|
||||
docs_service=docs_service, doc_id=file.get("id", "")
|
||||
)
|
||||
if doc_sections:
|
||||
sections = cast(list[TextSection | ImageSection], doc_sections)
|
||||
# get_document_sections is the advanced approach for Google Docs
|
||||
sections = get_document_sections(docs_service, file["id"])
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Error in advanced parsing: {e}. Falling back to basic extraction."
|
||||
f"Failed to pull google doc sections from '{file['name']}': {e}. "
|
||||
"Falling back to basic extraction."
|
||||
)
|
||||
|
||||
# If we don't have sections yet, use the basic extraction method
|
||||
# If not a doc, or if we failed above, do our 'basic' approach
|
||||
if not sections:
|
||||
sections = _extract_sections_basic(file, drive_service)
|
||||
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
|
||||
|
||||
# If we still don't have any sections, skip this file
|
||||
if not sections:
|
||||
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
|
||||
return None
|
||||
|
||||
doc_id = file["webViewLink"]
|
||||
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
|
||||
timezone.utc
|
||||
)
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=doc_id,
|
||||
sections=sections,
|
||||
source=DocumentSource.GOOGLE_DRIVE,
|
||||
semantic_identifier=file.get("name", ""),
|
||||
metadata={
|
||||
"owner_names": ", ".join(
|
||||
owner.get("displayName", "") for owner in file.get("owners", [])
|
||||
),
|
||||
},
|
||||
doc_updated_at=datetime.fromisoformat(
|
||||
file.get("modifiedTime", "").replace("Z", "+00:00")
|
||||
),
|
||||
semantic_identifier=file["name"],
|
||||
doc_updated_at=updated_time,
|
||||
metadata={}, # or any metadata from 'file'
|
||||
additional_info=file.get("id"),
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error converting file {file.get('name')}: {e}")
|
||||
return None
|
||||
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise
|
||||
return None
|
||||
|
||||
|
||||
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
|
||||
|
||||
@@ -3,7 +3,7 @@ from typing import Any
|
||||
from pydantic import BaseModel
|
||||
|
||||
from onyx.connectors.google_utils.resources import GoogleDocsService
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
|
||||
|
||||
class CurrentHeading(BaseModel):
|
||||
@@ -37,7 +37,7 @@ def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
|
||||
def get_document_sections(
|
||||
docs_service: GoogleDocsService,
|
||||
doc_id: str,
|
||||
) -> list[TextSection]:
|
||||
) -> list[Section]:
|
||||
"""Extracts sections from a Google Doc, including their headings and content"""
|
||||
# Fetch the document structure
|
||||
doc = docs_service.documents().get(documentId=doc_id).execute()
|
||||
@@ -45,7 +45,7 @@ def get_document_sections(
|
||||
# Get the content
|
||||
content = doc.get("body", {}).get("content", [])
|
||||
|
||||
sections: list[TextSection] = []
|
||||
sections: list[Section] = []
|
||||
current_section: list[str] = []
|
||||
current_heading: CurrentHeading | None = None
|
||||
|
||||
@@ -70,7 +70,7 @@ def get_document_sections(
|
||||
heading_text = current_heading.text
|
||||
section_text = f"{heading_text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
@@ -96,7 +96,7 @@ def get_document_sections(
|
||||
if current_heading is not None and current_section:
|
||||
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
|
||||
sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
text=section_text.strip(),
|
||||
link=_build_gdoc_section_link(doc_id, current_heading.id),
|
||||
)
|
||||
|
||||
@@ -12,7 +12,7 @@ from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.file_processing.extract_file_text import load_files_from_zip
|
||||
from onyx.file_processing.extract_file_text import read_text_file
|
||||
@@ -118,7 +118,7 @@ class GoogleSitesConnector(LoadConnector):
|
||||
source=DocumentSource.GOOGLE_SITES,
|
||||
semantic_identifier=title,
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
|
||||
if path
|
||||
else "",
|
||||
|
||||
@@ -15,7 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -120,7 +120,7 @@ class GuruConnector(LoadConnector, PollConnector):
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=card["id"],
|
||||
sections=[TextSection(link=link, text=content_text)],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.GURU,
|
||||
semantic_identifier=title,
|
||||
doc_updated_at=latest_time,
|
||||
|
||||
@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
HUBSPOT_BASE_URL = "https://app.hubspot.com/contacts/"
|
||||
@@ -108,7 +108,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=ticket.id,
|
||||
sections=[TextSection(link=link, text=content_text)],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.HUBSPOT,
|
||||
semantic_identifier=title,
|
||||
# Is already in tzutc, just replacing the timezone format
|
||||
|
||||
@@ -24,8 +24,6 @@ CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpo
|
||||
|
||||
class BaseConnector(abc.ABC):
|
||||
REDIS_KEY_PREFIX = "da_connector_data:"
|
||||
# Common image file extensions supported across connectors
|
||||
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
|
||||
|
||||
@@ -21,8 +21,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.retry_wrapper import request_with_retries
|
||||
|
||||
@@ -238,30 +237,22 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
|
||||
documents: list[Document] = []
|
||||
for edge in edges:
|
||||
node = edge["node"]
|
||||
# Create sections for description and comments
|
||||
sections = [
|
||||
TextSection(
|
||||
link=node["url"],
|
||||
text=node["description"] or "",
|
||||
)
|
||||
]
|
||||
|
||||
# Add comment sections
|
||||
for comment in node["comments"]["nodes"]:
|
||||
sections.append(
|
||||
TextSection(
|
||||
link=node["url"],
|
||||
text=comment["body"] or "",
|
||||
)
|
||||
)
|
||||
|
||||
# Cast the sections list to the expected type
|
||||
typed_sections = cast(list[TextSection | ImageSection], sections)
|
||||
|
||||
documents.append(
|
||||
Document(
|
||||
id=node["id"],
|
||||
sections=typed_sections,
|
||||
sections=[
|
||||
Section(
|
||||
link=node["url"],
|
||||
text=node["description"] or "",
|
||||
)
|
||||
]
|
||||
+ [
|
||||
Section(
|
||||
link=node["url"],
|
||||
text=comment["body"] or "",
|
||||
)
|
||||
for comment in node["comments"]["nodes"]
|
||||
],
|
||||
source=DocumentSource.LINEAR,
|
||||
semantic_identifier=f"[{node['identifier']}] {node['title']}",
|
||||
title=node["title"],
|
||||
|
||||
@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -162,7 +162,7 @@ class LoopioConnector(LoadConnector, PollConnector):
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=str(entry["id"]),
|
||||
sections=[TextSection(link=link, text=content_text)],
|
||||
sections=[Section(link=link, text=content_text)],
|
||||
source=DocumentSource.LOOPIO,
|
||||
semantic_identifier=questions[0],
|
||||
doc_updated_at=latest_time,
|
||||
|
||||
@@ -6,7 +6,6 @@ import tempfile
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import ClassVar
|
||||
|
||||
import pywikibot.time # type: ignore[import-untyped]
|
||||
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.mediawiki.family import family_class_dispatch
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -62,14 +60,14 @@ def get_doc_from_page(
|
||||
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
|
||||
|
||||
sections = [
|
||||
TextSection(
|
||||
Section(
|
||||
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
|
||||
text=section.title + section.content,
|
||||
)
|
||||
for section in sections_extracted.sections
|
||||
]
|
||||
sections.append(
|
||||
TextSection(
|
||||
Section(
|
||||
link=page.full_url(),
|
||||
text=sections_extracted.header,
|
||||
)
|
||||
@@ -81,7 +79,7 @@ def get_doc_from_page(
|
||||
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
|
||||
page.latest_revision.timestamp
|
||||
),
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=sections,
|
||||
semantic_identifier=page.title(),
|
||||
metadata={"categories": [category.title() for category in page.categories()]},
|
||||
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",
|
||||
|
||||
@@ -28,25 +28,9 @@ class ConnectorMissingCredentialError(PermissionError):
|
||||
|
||||
|
||||
class Section(BaseModel):
|
||||
"""Base section class with common attributes"""
|
||||
|
||||
link: str | None = None
|
||||
text: str | None = None
|
||||
image_file_name: str | None = None
|
||||
|
||||
|
||||
class TextSection(Section):
|
||||
"""Section containing text content"""
|
||||
|
||||
text: str
|
||||
link: str | None = None
|
||||
|
||||
|
||||
class ImageSection(Section):
|
||||
"""Section containing an image reference"""
|
||||
|
||||
image_file_name: str
|
||||
link: str | None = None
|
||||
image_file_name: str | None = None
|
||||
|
||||
|
||||
class BasicExpertInfo(BaseModel):
|
||||
@@ -116,7 +100,7 @@ class DocumentBase(BaseModel):
|
||||
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
|
||||
|
||||
id: str | None = None
|
||||
sections: list[TextSection | ImageSection]
|
||||
sections: list[Section]
|
||||
source: DocumentSource | None = None
|
||||
semantic_identifier: str # displayed in the UI as the main identifier for the doc
|
||||
metadata: dict[str, str | list[str]]
|
||||
@@ -166,11 +150,19 @@ class DocumentBase(BaseModel):
|
||||
|
||||
|
||||
class Document(DocumentBase):
|
||||
"""Used for Onyx ingestion api, the ID is required"""
|
||||
|
||||
id: str
|
||||
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
|
||||
source: DocumentSource
|
||||
|
||||
def get_total_char_length(self) -> int:
|
||||
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
|
||||
section_length = sum(len(section.text) for section in self.sections)
|
||||
identifier_length = len(self.semantic_identifier) + len(self.title or "")
|
||||
metadata_length = sum(
|
||||
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
|
||||
for k, v in self.metadata.items()
|
||||
)
|
||||
return section_length + identifier_length + metadata_length
|
||||
|
||||
def to_short_descriptor(self) -> str:
|
||||
"""Used when logging the identity of a document"""
|
||||
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
|
||||
@@ -193,32 +185,6 @@ class Document(DocumentBase):
|
||||
)
|
||||
|
||||
|
||||
class IndexingDocument(Document):
|
||||
"""Document with processed sections for indexing"""
|
||||
|
||||
processed_sections: list[Section] = []
|
||||
|
||||
def get_total_char_length(self) -> int:
|
||||
"""Get the total character length of the document including processed sections"""
|
||||
title_len = len(self.title or self.semantic_identifier)
|
||||
|
||||
# Use processed_sections if available, otherwise fall back to original sections
|
||||
if self.processed_sections:
|
||||
section_len = sum(
|
||||
len(section.text) if section.text is not None else 0
|
||||
for section in self.processed_sections
|
||||
)
|
||||
else:
|
||||
section_len = sum(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
for section in self.sections
|
||||
)
|
||||
|
||||
return title_len + section_len
|
||||
|
||||
|
||||
class SlimDocument(BaseModel):
|
||||
id: str
|
||||
perm_sync_data: Any | None = None
|
||||
|
||||
@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.batching import batch_generator
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -475,7 +475,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
Document(
|
||||
id=page.id,
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=f"{page.url}#{block.id.replace('-', '')}",
|
||||
text=block.prefix + block.text,
|
||||
)
|
||||
|
||||
@@ -23,8 +23,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
|
||||
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
|
||||
from onyx.connectors.onyx_jira.utils import build_jira_client
|
||||
@@ -145,7 +145,7 @@ def fetch_jira_issues_batch(
|
||||
|
||||
yield Document(
|
||||
id=page_url,
|
||||
sections=[TextSection(link=page_url, text=ticket_content)],
|
||||
sections=[Section(link=page_url, text=ticket_content)],
|
||||
source=DocumentSource.JIRA,
|
||||
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
|
||||
title=f"{issue.key} {issue.fields.summary}",
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
|
||||
@@ -110,7 +110,7 @@ class ProductboardConnector(PollConnector):
|
||||
yield Document(
|
||||
id=feature["id"],
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=feature["links"]["html"],
|
||||
text=self._parse_description_html(feature["description"]),
|
||||
)
|
||||
@@ -133,7 +133,7 @@ class ProductboardConnector(PollConnector):
|
||||
yield Document(
|
||||
id=component["id"],
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=component["links"]["html"],
|
||||
text=self._parse_description_html(component["description"]),
|
||||
)
|
||||
@@ -159,7 +159,7 @@ class ProductboardConnector(PollConnector):
|
||||
yield Document(
|
||||
id=product["id"],
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=product["links"]["html"],
|
||||
text=self._parse_description_html(product["description"]),
|
||||
)
|
||||
@@ -189,7 +189,7 @@ class ProductboardConnector(PollConnector):
|
||||
yield Document(
|
||||
id=objective["id"],
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=objective["links"]["html"],
|
||||
text=self._parse_description_html(objective["description"]),
|
||||
)
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
|
||||
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
|
||||
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
|
||||
@@ -49,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
domain = "test" if credentials.get("is_sandbox") else None
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
domain=domain,
|
||||
)
|
||||
return None
|
||||
|
||||
@@ -219,8 +216,7 @@ if __name__ == "__main__":
|
||||
for doc in doc_batch:
|
||||
section_count += len(doc.sections)
|
||||
for section in doc.sections:
|
||||
if isinstance(section, TextSection) and section.text is not None:
|
||||
text_count += len(section.text)
|
||||
text_count += len(section.text)
|
||||
end_time = time.time()
|
||||
|
||||
print(f"Doc count: {doc_count}")
|
||||
|
||||
@@ -1,12 +1,10 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
|
||||
from onyx.connectors.salesforce.sqlite_functions import get_record
|
||||
from onyx.connectors.salesforce.utils import SalesforceObject
|
||||
@@ -116,8 +114,8 @@ def _extract_dict_text(raw_dict: dict) -> str:
|
||||
return natural_language_for_dict
|
||||
|
||||
|
||||
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> TextSection:
|
||||
return TextSection(
|
||||
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
|
||||
return Section(
|
||||
text=_extract_dict_text(salesforce_object.data),
|
||||
link=f"{base_url}/{salesforce_object.id}",
|
||||
)
|
||||
@@ -177,7 +175,7 @@ def convert_sf_object_to_doc(
|
||||
|
||||
doc = Document(
|
||||
id=onyx_salesforce_id,
|
||||
sections=cast(list[TextSection | ImageSection], sections),
|
||||
sections=sections,
|
||||
source=DocumentSource.SALESFORCE,
|
||||
semantic_identifier=extracted_semantic_identifier,
|
||||
doc_updated_at=extracted_doc_updated_at,
|
||||
|
||||
@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -55,7 +55,7 @@ def _convert_driveitem_to_document(
|
||||
|
||||
doc = Document(
|
||||
id=driveitem.id,
|
||||
sections=[TextSection(link=driveitem.web_url, text=file_text)],
|
||||
sections=[Section(link=driveitem.web_url, text=file_text)],
|
||||
source=DocumentSource.SHAREPOINT,
|
||||
semantic_identifier=driveitem.name,
|
||||
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),
|
||||
|
||||
@@ -19,8 +19,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -212,7 +212,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=post_id, # can't be url as this changes with the post title
|
||||
sections=[TextSection(link=page_url, text=content_text)],
|
||||
sections=[Section(link=page_url, text=content_text)],
|
||||
source=DocumentSource.SLAB,
|
||||
semantic_identifier=post["title"],
|
||||
metadata={},
|
||||
|
||||
@@ -34,8 +34,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import EntityFailure
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.slack.utils import expert_info_from_slack_id
|
||||
from onyx.connectors.slack.utils import get_message_link
|
||||
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
|
||||
@@ -211,7 +211,7 @@ def thread_to_doc(
|
||||
return Document(
|
||||
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=get_message_link(event=m, client=client, channel_id=channel_id),
|
||||
text=slack_cleaner.index_clean(cast(str, m["text"])),
|
||||
)
|
||||
@@ -220,6 +220,7 @@ def thread_to_doc(
|
||||
source=DocumentSource.SLACK,
|
||||
semantic_identifier=doc_sem_id,
|
||||
doc_updated_at=get_latest_message_time(thread),
|
||||
title="", # slack docs don't really have a "title"
|
||||
primary_owners=valid_experts,
|
||||
metadata={"Channel": channel["name"]},
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -165,7 +165,7 @@ def _convert_thread_to_document(
|
||||
|
||||
doc = Document(
|
||||
id=post_id,
|
||||
sections=[TextSection(link=web_url, text=thread_text)],
|
||||
sections=[Section(link=web_url, text=thread_text)],
|
||||
source=DocumentSource.TEAMS,
|
||||
semantic_identifier=semantic_string,
|
||||
title="", # teams threads don't really have a "title"
|
||||
|
||||
45
backend/onyx/connectors/vision_enabled_connector.py
Normal file
45
backend/onyx/connectors/vision_enabled_connector.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""
|
||||
Mixin for connectors that need vision capabilities.
|
||||
"""
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
class VisionEnabledConnector:
|
||||
"""
|
||||
Mixin for connectors that need vision capabilities.
|
||||
|
||||
This mixin provides a standard way to initialize a vision-capable LLM
|
||||
for image analysis during indexing.
|
||||
|
||||
Usage:
|
||||
class MyConnector(LoadConnector, VisionEnabledConnector):
|
||||
def __init__(self, ...):
|
||||
super().__init__(...)
|
||||
self.initialize_vision_llm()
|
||||
"""
|
||||
|
||||
def initialize_vision_llm(self) -> None:
|
||||
"""
|
||||
Initialize a vision-capable LLM if enabled by configuration.
|
||||
|
||||
Sets self.image_analysis_llm to the LLM instance or None if disabled.
|
||||
"""
|
||||
self.image_analysis_llm: LLM | None = None
|
||||
if get_image_extraction_and_analysis_enabled():
|
||||
try:
|
||||
self.image_analysis_llm = get_default_llm_with_vision()
|
||||
if self.image_analysis_llm is None:
|
||||
logger.warning(
|
||||
"No LLM with vision found; image summarization will be disabled"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Failed to initialize vision LLM due to an error: {str(e)}. "
|
||||
"Image summarization will be disabled."
|
||||
)
|
||||
self.image_analysis_llm = None
|
||||
@@ -32,7 +32,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
|
||||
from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.file_processing.extract_file_text import read_pdf_file
|
||||
from onyx.file_processing.html_utils import web_html_cleanup
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -341,7 +341,7 @@ class WebConnector(LoadConnector):
|
||||
doc_batch.append(
|
||||
Document(
|
||||
id=initial_url,
|
||||
sections=[TextSection(link=initial_url, text=page_text)],
|
||||
sections=[Section(link=initial_url, text=page_text)],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=initial_url.split("/")[-1],
|
||||
metadata=metadata,
|
||||
@@ -443,7 +443,7 @@ class WebConnector(LoadConnector):
|
||||
Document(
|
||||
id=initial_url,
|
||||
sections=[
|
||||
TextSection(link=initial_url, text=parsed_html.cleaned_text)
|
||||
Section(link=initial_url, text=parsed_html.cleaned_text)
|
||||
],
|
||||
source=DocumentSource.WEB,
|
||||
semantic_identifier=parsed_html.title or initial_url,
|
||||
|
||||
@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
|
||||
from onyx.connectors.interfaces import LoadConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -104,7 +104,7 @@ def scrape_page_posts(
|
||||
# id. We may want to de-dupe this stuff inside the indexing service.
|
||||
document = Document(
|
||||
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
|
||||
sections=[TextSection(link=url, text=post_text)],
|
||||
sections=[Section(link=url, text=post_text)],
|
||||
title=title,
|
||||
source=DocumentSource.XENFORO,
|
||||
semantic_identifier=title,
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
|
||||
import requests
|
||||
|
||||
@@ -18,8 +17,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.interfaces import SlimConnector
|
||||
from onyx.connectors.models import BasicExpertInfo
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import SlimDocument
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.file_processing.html_utils import parse_html_page_basic
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.utils.retry_wrapper import retry_builder
|
||||
@@ -169,8 +168,8 @@ def _article_to_document(
|
||||
return new_author_mapping, Document(
|
||||
id=f"article:{article['id']}",
|
||||
sections=[
|
||||
TextSection(
|
||||
link=cast(str, article.get("html_url")),
|
||||
Section(
|
||||
link=article.get("html_url"),
|
||||
text=parse_html_page_basic(article["body"]),
|
||||
)
|
||||
],
|
||||
@@ -269,7 +268,7 @@ def _ticket_to_document(
|
||||
|
||||
return new_author_mapping, Document(
|
||||
id=f"zendesk_ticket_{ticket['id']}",
|
||||
sections=[TextSection(link=ticket_display_url, text=full_text)],
|
||||
sections=[Section(link=ticket_display_url, text=full_text)],
|
||||
source=DocumentSource.ZENDESK,
|
||||
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
|
||||
doc_updated_at=update_time,
|
||||
|
||||
@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
|
||||
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
|
||||
from onyx.connectors.models import ConnectorMissingCredentialError
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.zulip.schemas import GetMessagesResponse
|
||||
from onyx.connectors.zulip.schemas import Message
|
||||
from onyx.connectors.zulip.utils import build_search_narrow
|
||||
@@ -161,7 +161,7 @@ class ZulipConnector(LoadConnector, PollConnector):
|
||||
return Document(
|
||||
id=f"{message.stream_id}__{message.id}",
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
link=self._message_to_narrow_link(message),
|
||||
text=text,
|
||||
)
|
||||
|
||||
@@ -10,7 +10,6 @@ from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
@@ -32,6 +31,7 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import RerankingModel
|
||||
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
|
||||
@@ -1,63 +0,0 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.db.models import ChunkStats
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
|
||||
|
||||
def update_chunk_boost_components__no_commit(
|
||||
chunk_data: list[UpdatableChunkData],
|
||||
db_session: Session,
|
||||
) -> None:
|
||||
"""Updates the chunk_boost_components for chunks in the database.
|
||||
|
||||
Args:
|
||||
chunk_data: List of dicts containing chunk_id, document_id, and boost_score
|
||||
db_session: SQLAlchemy database session
|
||||
"""
|
||||
if not chunk_data:
|
||||
return
|
||||
|
||||
for data in chunk_data:
|
||||
chunk_in_doc_id = int(data.chunk_id)
|
||||
if chunk_in_doc_id < 0:
|
||||
raise ValueError(f"Chunk ID is empty for chunk {data}")
|
||||
|
||||
chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}"
|
||||
chunk_stats = (
|
||||
db_session.query(ChunkStats)
|
||||
.filter(
|
||||
ChunkStats.id == chunk_document_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
score = data.boost_score
|
||||
|
||||
if chunk_stats:
|
||||
chunk_stats.information_content_boost = score
|
||||
chunk_stats.last_modified = datetime.now(timezone.utc)
|
||||
db_session.add(chunk_stats)
|
||||
else:
|
||||
# do not save new chunks with a neutral boost score
|
||||
if score == 1.0:
|
||||
continue
|
||||
# Create new record
|
||||
chunk_stats = ChunkStats(
|
||||
document_id=data.document_id,
|
||||
chunk_in_doc_id=chunk_in_doc_id,
|
||||
information_content_boost=score,
|
||||
)
|
||||
db_session.add(chunk_stats)
|
||||
|
||||
|
||||
def delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This deletes just chunk stats in postgres."""
|
||||
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
|
||||
|
||||
db_session.execute(stmt)
|
||||
@@ -23,7 +23,6 @@ from sqlalchemy.sql.expression import null
|
||||
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
|
||||
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.enums import AccessType
|
||||
@@ -563,18 +562,6 @@ def delete_documents_complete__no_commit(
|
||||
db_session: Session, document_ids: list[str]
|
||||
) -> None:
|
||||
"""This completely deletes the documents from the db, including all foreign key relationships"""
|
||||
|
||||
# Start by deleting the chunk stats for the documents
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_chunk_stats_by_connector_credential_pair__no_commit(
|
||||
db_session=db_session,
|
||||
document_ids=document_ids,
|
||||
)
|
||||
|
||||
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
|
||||
delete_document_feedback_for_documents__no_commit(
|
||||
document_ids=document_ids, db_session=db_session
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import Tool as ToolModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
@@ -188,17 +187,6 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||
)
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
@@ -258,39 +246,3 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
|
||||
new_default.is_default_provider = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_default_vision_provider(
|
||||
provider_id: int, vision_model: str | None, db_session: Session
|
||||
) -> None:
|
||||
new_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
if not new_default:
|
||||
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
|
||||
|
||||
# Validate that the specified vision model supports image input
|
||||
model_to_validate = vision_model or new_default.default_model_name
|
||||
if model_to_validate:
|
||||
if not model_supports_image_input(model_to_validate, new_default.provider):
|
||||
raise ValueError(
|
||||
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
|
||||
)
|
||||
|
||||
existing_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||
)
|
||||
)
|
||||
if existing_default:
|
||||
existing_default.is_default_vision_provider = None
|
||||
# required to ensure that the below does not cause a unique constraint violation
|
||||
db_session.flush()
|
||||
|
||||
new_default.is_default_vision_provider = True
|
||||
new_default.default_vision_model = vision_model
|
||||
db_session.commit()
|
||||
|
||||
@@ -591,55 +591,6 @@ class Document(Base):
|
||||
)
|
||||
|
||||
|
||||
class ChunkStats(Base):
|
||||
__tablename__ = "chunk_stats"
|
||||
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
|
||||
|
||||
# this should correspond to the ID of the document
|
||||
# (as is passed around in Onyx)
|
||||
id: Mapped[str] = mapped_column(
|
||||
NullFilteredString,
|
||||
primary_key=True,
|
||||
default=lambda context: (
|
||||
f"{context.get_current_parameters()['document_id']}"
|
||||
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
|
||||
),
|
||||
index=True,
|
||||
)
|
||||
|
||||
# Reference to parent document
|
||||
document_id: Mapped[str] = mapped_column(
|
||||
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
|
||||
)
|
||||
|
||||
chunk_in_doc_id: Mapped[int] = mapped_column(
|
||||
Integer,
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
information_content_boost: Mapped[float | None] = mapped_column(
|
||||
Float, nullable=True
|
||||
)
|
||||
|
||||
last_modified: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, index=True, default=func.now()
|
||||
)
|
||||
last_synced: Mapped[datetime.datetime | None] = mapped_column(
|
||||
DateTime(timezone=True), nullable=True, index=True
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_chunk_sync_status",
|
||||
last_modified,
|
||||
last_synced,
|
||||
),
|
||||
UniqueConstraint(
|
||||
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class Tag(Base):
|
||||
__tablename__ = "tag"
|
||||
|
||||
@@ -1538,8 +1489,6 @@ class LLMProvider(Base):
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
@@ -2360,17 +2309,6 @@ class UserTenantMapping(Base):
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
class AvailableTenant(Base):
|
||||
__tablename__ = "available_tenant"
|
||||
"""
|
||||
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
|
||||
"""
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
|
||||
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
|
||||
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
__tablename__ = "tenant_anonymous_user_path"
|
||||
|
||||
@@ -67,9 +67,6 @@ def read_lobj(
|
||||
use_tempfile: bool = False,
|
||||
) -> IO:
|
||||
pg_conn = get_pg_conn_from_session(db_session)
|
||||
# Ensure we're using binary mode by default for large objects
|
||||
if mode is None:
|
||||
mode = "rb"
|
||||
large_object = (
|
||||
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
|
||||
)
|
||||
@@ -84,7 +81,6 @@ def read_lobj(
|
||||
temp_file.seek(0)
|
||||
return temp_file
|
||||
else:
|
||||
# Ensure we're getting raw bytes without text decoding
|
||||
return BytesIO(large_object.read())
|
||||
|
||||
|
||||
|
||||
@@ -101,7 +101,6 @@ class VespaDocumentFields:
|
||||
document_sets: set[str] | None = None
|
||||
boost: float | None = None
|
||||
hidden: bool | None = None
|
||||
aggregated_chunk_boost_factor: float | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -80,11 +80,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
indexing: summary | attribute
|
||||
rank: filter
|
||||
}
|
||||
# Field to indicate whether a short chunk is a low content chunk
|
||||
field aggregated_chunk_boost_factor type float {
|
||||
indexing: attribute
|
||||
}
|
||||
|
||||
# Needs to have a separate Attribute list for efficient filtering
|
||||
field metadata_list type array<string> {
|
||||
indexing: summary | attribute
|
||||
@@ -147,11 +142,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
|
||||
}
|
||||
|
||||
function inline aggregated_chunk_boost() {
|
||||
# Aggregated boost factor, currently only used for information content classification
|
||||
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
|
||||
}
|
||||
|
||||
# Document score decays from 1 to 0.75 as age of last updated time increases
|
||||
function inline recency_bias() {
|
||||
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
|
||||
@@ -209,8 +199,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
* document_boost
|
||||
# Decay factor based on time document was last updated
|
||||
* recency_bias
|
||||
# Boost based on aggregated boost calculation
|
||||
* aggregated_chunk_boost
|
||||
}
|
||||
rerank-count: 1000
|
||||
}
|
||||
@@ -222,7 +210,6 @@ schema DANSWER_CHUNK_NAME {
|
||||
closeness(field, embeddings)
|
||||
document_boost
|
||||
recency_bias
|
||||
aggregated_chunk_boost
|
||||
closest(embeddings)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,7 +22,6 @@ from onyx.document_index.vespa.shared_utils.utils import (
|
||||
replace_invalid_doc_id_characters,
|
||||
)
|
||||
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
|
||||
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
|
||||
from onyx.document_index.vespa_constants import BLURB
|
||||
from onyx.document_index.vespa_constants import BOOST
|
||||
from onyx.document_index.vespa_constants import CHUNK_ID
|
||||
@@ -202,7 +201,6 @@ def _index_vespa_chunk(
|
||||
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
|
||||
IMAGE_FILE_NAME: chunk.image_file_name,
|
||||
BOOST: chunk.boost,
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
|
||||
}
|
||||
|
||||
if multitenant:
|
||||
|
||||
@@ -72,7 +72,6 @@ METADATA = "metadata"
|
||||
METADATA_LIST = "metadata_list"
|
||||
METADATA_SUFFIX = "metadata_suffix"
|
||||
BOOST = "boost"
|
||||
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
|
||||
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
|
||||
PRIMARY_OWNERS = "primary_owners"
|
||||
SECONDARY_OWNERS = "secondary_owners"
|
||||
@@ -98,7 +97,6 @@ YQL_BASE = (
|
||||
f"{SECTION_CONTINUATION}, "
|
||||
f"{IMAGE_FILE_NAME}, "
|
||||
f"{BOOST}, "
|
||||
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
|
||||
f"{HIDDEN}, "
|
||||
f"{DOC_UPDATED_AT}, "
|
||||
f"{PRIMARY_OWNERS}, "
|
||||
|
||||
@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
|
||||
image_data: The raw image bytes
|
||||
context_name: Name or title of the image for context
|
||||
system_prompt: System prompt to use for the LLM
|
||||
user_prompt_template: User prompt to use (without title)
|
||||
user_prompt_template: Template for the user prompt, should contain {title} placeholder
|
||||
|
||||
Returns:
|
||||
The image summary text, or None if summarization failed or is disabled
|
||||
@@ -70,10 +70,7 @@ def summarize_image_with_error_handling(
|
||||
if llm is None:
|
||||
return None
|
||||
|
||||
# Prepend the image filename to the user prompt
|
||||
user_prompt = (
|
||||
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
|
||||
)
|
||||
user_prompt = user_prompt_template.format(title=context_name)
|
||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||
|
||||
|
||||
|
||||
@@ -2,9 +2,12 @@ from typing import Tuple
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
|
||||
from onyx.configs.constants import FileOrigin
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -15,12 +18,12 @@ def store_image_and_create_section(
|
||||
image_data: bytes,
|
||||
file_name: str,
|
||||
display_name: str,
|
||||
link: str | None = None,
|
||||
media_type: str = "application/octet-stream",
|
||||
media_type: str = "image/unknown",
|
||||
llm: LLM | None = None,
|
||||
file_origin: FileOrigin = FileOrigin.OTHER,
|
||||
) -> Tuple[ImageSection, str | None]:
|
||||
) -> Tuple[Section, str | None]:
|
||||
"""
|
||||
Stores an image in PGFileStore and creates an ImageSection object without summarization.
|
||||
Stores an image in PGFileStore and creates a Section object with optional summarization.
|
||||
|
||||
Args:
|
||||
db_session: Database session
|
||||
@@ -28,11 +31,12 @@ def store_image_and_create_section(
|
||||
file_name: Base identifier for the file
|
||||
display_name: Human-readable name for the image
|
||||
media_type: MIME type of the image
|
||||
llm: Optional LLM with vision capabilities for summarization
|
||||
file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
|
||||
|
||||
Returns:
|
||||
Tuple containing:
|
||||
- ImageSection object with image reference
|
||||
- Section object with image reference and optional summary text
|
||||
- The file_name in PGFileStore or None if storage failed
|
||||
"""
|
||||
# Storage logic
|
||||
@@ -49,10 +53,18 @@ def store_image_and_create_section(
|
||||
stored_file_name = pgfilestore.file_name
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to store image: {e}")
|
||||
raise e
|
||||
if not CONTINUE_ON_CONNECTOR_FAILURE:
|
||||
raise
|
||||
return Section(text=""), None
|
||||
|
||||
# Summarization logic
|
||||
summary_text = ""
|
||||
if llm:
|
||||
summary_text = (
|
||||
summarize_image_with_error_handling(llm, image_data, display_name) or ""
|
||||
)
|
||||
|
||||
# Create an ImageSection with empty text (will be filled by LLM later in the pipeline)
|
||||
return (
|
||||
ImageSection(image_file_name=stored_file_name, link=link),
|
||||
Section(text=summary_text, image_file_name=stored_file_name),
|
||||
stored_file_name,
|
||||
)
|
||||
|
||||
@@ -9,8 +9,7 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_metadata_keys_to_ignore,
|
||||
)
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.natural_language_processing.utils import BaseTokenizer
|
||||
@@ -65,7 +64,7 @@ def _get_metadata_suffix_for_document_index(
|
||||
|
||||
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
|
||||
"""
|
||||
Combines multiple DocAwareChunks into one large chunk (for "multipass" mode),
|
||||
Combines multiple DocAwareChunks into one large chunk (for “multipass” mode),
|
||||
appending the content and adjusting source_links accordingly.
|
||||
"""
|
||||
merged_chunk = DocAwareChunk(
|
||||
@@ -99,7 +98,7 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
|
||||
|
||||
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Generates larger "grouped" chunks by combining sets of smaller chunks.
|
||||
Generates larger “grouped” chunks by combining sets of smaller chunks.
|
||||
"""
|
||||
large_chunks = []
|
||||
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
|
||||
@@ -186,7 +185,7 @@ class Chunker:
|
||||
|
||||
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
|
||||
"""
|
||||
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
For “multipass” mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
|
||||
"""
|
||||
if self.mini_chunk_splitter and chunk_text.strip():
|
||||
return self.mini_chunk_splitter.split_text(chunk_text)
|
||||
@@ -195,7 +194,7 @@ class Chunker:
|
||||
# ADDED: extra param image_url to store in the chunk
|
||||
def _create_chunk(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
document: Document,
|
||||
chunks_list: list[DocAwareChunk],
|
||||
text: str,
|
||||
links: dict[int, str],
|
||||
@@ -226,29 +225,7 @@ class Chunker:
|
||||
|
||||
def _chunk_document(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
content_token_limit: int,
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Legacy method for backward compatibility.
|
||||
Calls _chunk_document_with_sections with document.sections.
|
||||
"""
|
||||
return self._chunk_document_with_sections(
|
||||
document,
|
||||
document.processed_sections,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
content_token_limit,
|
||||
)
|
||||
|
||||
def _chunk_document_with_sections(
|
||||
self,
|
||||
document: IndexingDocument,
|
||||
sections: list[Section],
|
||||
document: Document,
|
||||
title_prefix: str,
|
||||
metadata_suffix_semantic: str,
|
||||
metadata_suffix_keyword: str,
|
||||
@@ -256,16 +233,17 @@ class Chunker:
|
||||
) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Loops through sections of the document, converting them into one or more chunks.
|
||||
Works with processed sections that are base Section objects.
|
||||
If a section has an image_link, we treat it as a dedicated chunk.
|
||||
"""
|
||||
|
||||
chunks: list[DocAwareChunk] = []
|
||||
link_offsets: dict[int, str] = {}
|
||||
chunk_text = ""
|
||||
|
||||
for section_idx, section in enumerate(sections):
|
||||
# Get section text and other attributes
|
||||
section_text = clean_text(section.text or "")
|
||||
for section_idx, section in enumerate(document.sections):
|
||||
section_text = clean_text(section.text)
|
||||
section_link_text = section.link or ""
|
||||
# ADDED: if the Section has an image link
|
||||
image_url = section.image_file_name
|
||||
|
||||
# If there is no useful content, skip
|
||||
@@ -276,7 +254,7 @@ class Chunker:
|
||||
)
|
||||
continue
|
||||
|
||||
# CASE 1: If this section has an image, force a separate chunk
|
||||
# CASE 1: If this is an image section, force a separate chunk
|
||||
if image_url:
|
||||
# First, if we have any partially built text chunk, finalize it
|
||||
if chunk_text.strip():
|
||||
@@ -293,13 +271,15 @@ class Chunker:
|
||||
chunk_text = ""
|
||||
link_offsets = {}
|
||||
|
||||
# Create a chunk specifically for this image section
|
||||
# (Using the text summary that was generated during processing)
|
||||
# Create a chunk specifically for this image
|
||||
# (If the section has text describing the image, use that as content)
|
||||
self._create_chunk(
|
||||
document,
|
||||
chunks,
|
||||
section_text,
|
||||
links={0: section_link_text} if section_link_text else {},
|
||||
links={0: section_link_text}
|
||||
if section_link_text
|
||||
else {}, # No text offsets needed for images
|
||||
image_file_name=image_url,
|
||||
title_prefix=title_prefix,
|
||||
metadata_suffix_semantic=metadata_suffix_semantic,
|
||||
@@ -404,9 +384,7 @@ class Chunker:
|
||||
)
|
||||
return chunks
|
||||
|
||||
def _handle_single_document(
|
||||
self, document: IndexingDocument
|
||||
) -> list[DocAwareChunk]:
|
||||
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
|
||||
# Specifically for reproducing an issue with gmail
|
||||
if document.source == DocumentSource.GMAIL:
|
||||
logger.debug(f"Chunking {document.semantic_identifier}")
|
||||
@@ -442,31 +420,26 @@ class Chunker:
|
||||
title_prefix = ""
|
||||
metadata_suffix_semantic = ""
|
||||
|
||||
# Use processed_sections if available (IndexingDocument), otherwise use original sections
|
||||
sections_to_chunk = document.processed_sections
|
||||
|
||||
normal_chunks = self._chunk_document_with_sections(
|
||||
# Chunk the document
|
||||
normal_chunks = self._chunk_document(
|
||||
document,
|
||||
sections_to_chunk,
|
||||
title_prefix,
|
||||
metadata_suffix_semantic,
|
||||
metadata_suffix_keyword,
|
||||
content_token_limit,
|
||||
)
|
||||
|
||||
# Optional "multipass" large chunk creation
|
||||
# Optional “multipass” large chunk creation
|
||||
if self.enable_multipass and self.enable_large_chunks:
|
||||
large_chunks = generate_large_chunks(normal_chunks)
|
||||
normal_chunks.extend(large_chunks)
|
||||
|
||||
return normal_chunks
|
||||
|
||||
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
|
||||
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
|
||||
"""
|
||||
Takes in a list of documents and chunks them into smaller chunks for indexing
|
||||
while persisting the document metadata.
|
||||
|
||||
Works with both standard Document objects and IndexingDocument objects with processed_sections.
|
||||
"""
|
||||
final_chunks: list[DocAwareChunk] = []
|
||||
for document in documents:
|
||||
|
||||
@@ -10,20 +10,13 @@ from onyx.access.access import get_access_for_documents
|
||||
from onyx.access.models import DocumentAccess
|
||||
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
|
||||
from onyx.configs.constants import DEFAULT_BOOST
|
||||
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
|
||||
from onyx.configs.model_configs import USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
|
||||
get_experts_stores_representations,
|
||||
)
|
||||
from onyx.connectors.models import ConnectorFailure
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import DocumentFailure
|
||||
from onyx.connectors.models import ImageSection
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import IndexingDocument
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.db.chunk import update_chunk_boost_components__no_commit
|
||||
from onyx.db.document import fetch_chunk_counts_for_documents
|
||||
from onyx.db.document import get_documents_by_ids
|
||||
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
|
||||
@@ -34,10 +27,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
|
||||
from onyx.db.document import upsert_document_by_connector_credential_pair
|
||||
from onyx.db.document import upsert_documents
|
||||
from onyx.db.document_set import fetch_document_sets_for_documents
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import Document as DBDocument
|
||||
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
|
||||
from onyx.db.pg_file_store import read_lobj
|
||||
from onyx.db.search_settings import get_current_search_settings
|
||||
from onyx.db.tag import create_or_add_document_tag
|
||||
from onyx.db.tag import create_or_add_document_tag_list
|
||||
@@ -47,26 +37,15 @@ from onyx.document_index.document_index_utils import (
|
||||
from onyx.document_index.interfaces import DocumentIndex
|
||||
from onyx.document_index.interfaces import DocumentMetadata
|
||||
from onyx.document_index.interfaces import IndexBatchParams
|
||||
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
|
||||
from onyx.indexing.chunker import Chunker
|
||||
from onyx.indexing.embedder import embed_chunks_with_failure_handling
|
||||
from onyx.indexing.embedder import IndexingEmbedder
|
||||
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
|
||||
from onyx.indexing.models import DocAwareChunk
|
||||
from onyx.indexing.models import DocMetadataAwareIndexChunk
|
||||
from onyx.indexing.models import IndexChunk
|
||||
from onyx.indexing.models import UpdatableChunkData
|
||||
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
|
||||
from onyx.llm.factory import get_default_llm_with_vision
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import (
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
|
||||
)
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -74,7 +53,6 @@ logger = setup_logger()
|
||||
class DocumentBatchPrepareContext(BaseModel):
|
||||
updatable_docs: list[Document]
|
||||
id_to_db_doc_map: dict[str, DBDocument]
|
||||
indexable_docs: list[IndexingDocument] = []
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
|
||||
@@ -147,72 +125,6 @@ def _upsert_documents_in_db(
|
||||
)
|
||||
|
||||
|
||||
def _get_aggregated_chunk_boost_factor(
|
||||
chunks: list[IndexChunk],
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
) -> list[float]:
|
||||
"""Calculates the aggregated boost factor for a chunk based on its content."""
|
||||
|
||||
short_chunk_content_dict = {
|
||||
chunk_num: chunk.content
|
||||
for chunk_num, chunk in enumerate(chunks)
|
||||
if len(chunk.content.split())
|
||||
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
}
|
||||
short_chunk_contents = list(short_chunk_content_dict.values())
|
||||
short_chunk_keys = list(short_chunk_content_dict.keys())
|
||||
|
||||
try:
|
||||
predictions = information_content_classification_model.predict(
|
||||
short_chunk_contents
|
||||
)
|
||||
# Create a mapping of chunk positions to their scores
|
||||
score_map = {
|
||||
short_chunk_keys[i]: prediction.content_boost_factor
|
||||
for i, prediction in enumerate(predictions)
|
||||
}
|
||||
# Default to 1.0 for longer chunks, use predicted score for short chunks
|
||||
chunk_content_scores = [score_map.get(i, 1.0) for i in range(len(chunks))]
|
||||
|
||||
return chunk_content_scores
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error predicting content classification for chunks: {e}. Falling back to individual examples."
|
||||
)
|
||||
|
||||
chunks_with_scores: list[IndexChunk] = []
|
||||
chunk_content_scores = []
|
||||
|
||||
for chunk in chunks:
|
||||
if (
|
||||
len(chunk.content.split())
|
||||
> INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
|
||||
):
|
||||
chunk_content_scores.append(1.0)
|
||||
chunks_with_scores.append(chunk)
|
||||
continue
|
||||
|
||||
try:
|
||||
chunk_content_scores.append(
|
||||
information_content_classification_model.predict([chunk.content])[
|
||||
0
|
||||
].content_boost_factor
|
||||
)
|
||||
chunks_with_scores.append(chunk)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Error predicting content classification for chunk: {e}."
|
||||
)
|
||||
|
||||
raise Exception(
|
||||
f"Failed to predict content classification for chunk {chunk.chunk_id} "
|
||||
f"from document {chunk.source_document.id}"
|
||||
) from e
|
||||
|
||||
return chunk_content_scores
|
||||
|
||||
|
||||
def get_doc_ids_to_update(
|
||||
documents: list[Document], db_docs: list[DBDocument]
|
||||
) -> list[Document]:
|
||||
@@ -242,7 +154,6 @@ def index_doc_batch_with_handler(
|
||||
*,
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
document_batch: list[Document],
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
@@ -254,7 +165,6 @@ def index_doc_batch_with_handler(
|
||||
index_pipeline_result = index_doc_batch(
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
document_batch=document_batch,
|
||||
index_attempt_metadata=index_attempt_metadata,
|
||||
@@ -355,12 +265,7 @@ def index_doc_batch_prepare(
|
||||
def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
documents: list[Document] = []
|
||||
for document in document_batch:
|
||||
empty_contents = not any(
|
||||
isinstance(section, TextSection)
|
||||
and section.text is not None
|
||||
and section.text.strip()
|
||||
for section in document.sections
|
||||
)
|
||||
empty_contents = not any(section.text.strip() for section in document.sections)
|
||||
if (
|
||||
(not document.title or not document.title.strip())
|
||||
and not document.semantic_identifier.strip()
|
||||
@@ -383,12 +288,7 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
)
|
||||
continue
|
||||
|
||||
section_chars = sum(
|
||||
len(section.text)
|
||||
if isinstance(section, TextSection) and section.text is not None
|
||||
else 0
|
||||
for section in document.sections
|
||||
)
|
||||
section_chars = sum(len(section.text) for section in document.sections)
|
||||
if (
|
||||
MAX_DOCUMENT_CHARS
|
||||
and len(document.title or document.semantic_identifier) + section_chars
|
||||
@@ -408,128 +308,12 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
|
||||
return documents
|
||||
|
||||
|
||||
def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
|
||||
"""
|
||||
Process all sections in documents by:
|
||||
1. Converting both TextSection and ImageSection objects to base Section objects
|
||||
2. Processing ImageSections to generate text summaries using a vision-capable LLM
|
||||
3. Returning IndexingDocument objects with both original and processed sections
|
||||
|
||||
Args:
|
||||
documents: List of documents with TextSection | ImageSection objects
|
||||
|
||||
Returns:
|
||||
List of IndexingDocument objects with processed_sections as list[Section]
|
||||
"""
|
||||
# Check if image extraction and analysis is enabled before trying to get a vision LLM
|
||||
if not get_image_extraction_and_analysis_enabled():
|
||||
llm = None
|
||||
else:
|
||||
# Only get the vision LLM if image processing is enabled
|
||||
llm = get_default_llm_with_vision()
|
||||
|
||||
if not llm:
|
||||
logger.warning(
|
||||
"No vision-capable LLM available. Image sections will not be processed."
|
||||
)
|
||||
|
||||
# Even without LLM, we still convert to IndexingDocument with base Sections
|
||||
return [
|
||||
IndexingDocument(
|
||||
**document.dict(),
|
||||
processed_sections=[
|
||||
Section(
|
||||
text=section.text if isinstance(section, TextSection) else None,
|
||||
link=section.link,
|
||||
image_file_name=section.image_file_name
|
||||
if isinstance(section, ImageSection)
|
||||
else None,
|
||||
)
|
||||
for section in document.sections
|
||||
],
|
||||
)
|
||||
for document in documents
|
||||
]
|
||||
|
||||
indexed_documents: list[IndexingDocument] = []
|
||||
|
||||
for document in documents:
|
||||
processed_sections: list[Section] = []
|
||||
|
||||
for section in document.sections:
|
||||
# For ImageSection, process and create base Section with both text and image_file_name
|
||||
if isinstance(section, ImageSection):
|
||||
# Default section with image path preserved
|
||||
processed_section = Section(
|
||||
link=section.link,
|
||||
image_file_name=section.image_file_name,
|
||||
text=None, # Will be populated if summarization succeeds
|
||||
)
|
||||
|
||||
# Try to get image summary
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
pgfilestore = get_pgfilestore_by_file_name(
|
||||
file_name=section.image_file_name, db_session=db_session
|
||||
)
|
||||
|
||||
if not pgfilestore:
|
||||
logger.warning(
|
||||
f"Image file {section.image_file_name} not found in PGFileStore"
|
||||
)
|
||||
|
||||
processed_section.text = "[Image could not be processed]"
|
||||
else:
|
||||
# Get the image data
|
||||
image_data_io = read_lobj(
|
||||
pgfilestore.lobj_oid, db_session, mode="rb"
|
||||
)
|
||||
pgfilestore_data = image_data_io.read()
|
||||
summary = summarize_image_with_error_handling(
|
||||
llm=llm,
|
||||
image_data=pgfilestore_data,
|
||||
context_name=pgfilestore.display_name or "Image",
|
||||
)
|
||||
|
||||
if summary:
|
||||
processed_section.text = summary
|
||||
else:
|
||||
processed_section.text = (
|
||||
"[Image could not be summarized]"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing image section: {e}")
|
||||
processed_section.text = "[Error processing image]"
|
||||
|
||||
processed_sections.append(processed_section)
|
||||
|
||||
# For TextSection, create a base Section with text and link
|
||||
elif isinstance(section, TextSection):
|
||||
processed_section = Section(
|
||||
text=section.text, link=section.link, image_file_name=None
|
||||
)
|
||||
processed_sections.append(processed_section)
|
||||
|
||||
# If it's already a base Section (unlikely), just append it
|
||||
else:
|
||||
processed_sections.append(section)
|
||||
|
||||
# Create IndexingDocument with original sections and processed_sections
|
||||
indexed_document = IndexingDocument(
|
||||
**document.dict(), processed_sections=processed_sections
|
||||
)
|
||||
indexed_documents.append(indexed_document)
|
||||
|
||||
return indexed_documents
|
||||
|
||||
|
||||
@log_function_time(debug_only=True)
|
||||
def index_doc_batch(
|
||||
*,
|
||||
document_batch: list[Document],
|
||||
chunker: Chunker,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
index_attempt_metadata: IndexAttemptMetadata,
|
||||
db_session: Session,
|
||||
@@ -578,23 +362,19 @@ def index_doc_batch(
|
||||
failures=[],
|
||||
)
|
||||
|
||||
# Convert documents to IndexingDocument objects with processed section
|
||||
# logger.debug("Processing image sections")
|
||||
ctx.indexable_docs = process_image_sections(ctx.updatable_docs)
|
||||
|
||||
doc_descriptors = [
|
||||
{
|
||||
"doc_id": doc.id,
|
||||
"doc_length": doc.get_total_char_length(),
|
||||
}
|
||||
for doc in ctx.indexable_docs
|
||||
for doc in ctx.updatable_docs
|
||||
]
|
||||
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
|
||||
|
||||
logger.debug("Starting chunking")
|
||||
# NOTE: no special handling for failures here, since the chunker is not
|
||||
# a common source of failure for the indexing pipeline
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
|
||||
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
|
||||
|
||||
logger.debug("Starting embedding")
|
||||
chunks_with_embeddings, embedding_failures = (
|
||||
@@ -606,23 +386,7 @@ def index_doc_batch(
|
||||
else ([], [])
|
||||
)
|
||||
|
||||
chunk_content_scores = (
|
||||
_get_aggregated_chunk_boost_factor(
|
||||
chunks_with_embeddings, information_content_classification_model
|
||||
)
|
||||
if USE_INFORMATION_CONTENT_CLASSIFICATION
|
||||
else [1.0] * len(chunks_with_embeddings)
|
||||
)
|
||||
|
||||
updatable_ids = [doc.id for doc in ctx.updatable_docs]
|
||||
updatable_chunk_data = [
|
||||
UpdatableChunkData(
|
||||
chunk_id=chunk.chunk_id,
|
||||
document_id=chunk.source_document.id,
|
||||
boost_score=score,
|
||||
)
|
||||
for chunk, score in zip(chunks_with_embeddings, chunk_content_scores)
|
||||
]
|
||||
|
||||
# Acquires a lock on the documents so that no other process can modify them
|
||||
# NOTE: don't need to acquire till here, since this is when the actual race condition
|
||||
@@ -675,9 +439,8 @@ def index_doc_batch(
|
||||
else DEFAULT_BOOST
|
||||
),
|
||||
tenant_id=tenant_id,
|
||||
aggregated_chunk_boost_factor=chunk_content_scores[chunk_num],
|
||||
)
|
||||
for chunk_num, chunk in enumerate(chunks_with_embeddings)
|
||||
for chunk in chunks_with_embeddings
|
||||
]
|
||||
|
||||
logger.debug(
|
||||
@@ -762,11 +525,6 @@ def index_doc_batch(
|
||||
db_session=db_session,
|
||||
)
|
||||
|
||||
# save the chunk boost components to postgres
|
||||
update_chunk_boost_components__no_commit(
|
||||
chunk_data=updatable_chunk_data, db_session=db_session
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
|
||||
result = IndexingPipelineResult(
|
||||
@@ -782,7 +540,6 @@ def index_doc_batch(
|
||||
def build_indexing_pipeline(
|
||||
*,
|
||||
embedder: IndexingEmbedder,
|
||||
information_content_classification_model: InformationContentClassificationModel,
|
||||
document_index: DocumentIndex,
|
||||
db_session: Session,
|
||||
tenant_id: str,
|
||||
@@ -806,7 +563,6 @@ def build_indexing_pipeline(
|
||||
index_doc_batch_with_handler,
|
||||
chunker=chunker,
|
||||
embedder=embedder,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=document_index,
|
||||
ignore_time_skip=ignore_time_skip,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -83,16 +83,13 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
document_sets: all document sets the source document for this chunk is a part
|
||||
of. This is used for filtering / personas.
|
||||
boost: influences the ranking of this chunk at query time. Positive -> ranked higher,
|
||||
negative -> ranked lower. Not included in aggregated boost calculation
|
||||
for legacy reasons.
|
||||
aggregated_chunk_boost_factor: represents the aggregated chunk-level boost (currently: information content)
|
||||
negative -> ranked lower.
|
||||
"""
|
||||
|
||||
tenant_id: str
|
||||
access: "DocumentAccess"
|
||||
document_sets: set[str]
|
||||
boost: int
|
||||
aggregated_chunk_boost_factor: float
|
||||
|
||||
@classmethod
|
||||
def from_index_chunk(
|
||||
@@ -101,7 +98,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access: "DocumentAccess",
|
||||
document_sets: set[str],
|
||||
boost: int,
|
||||
aggregated_chunk_boost_factor: float,
|
||||
tenant_id: str,
|
||||
) -> "DocMetadataAwareIndexChunk":
|
||||
index_chunk_data = index_chunk.model_dump()
|
||||
@@ -110,7 +106,6 @@ class DocMetadataAwareIndexChunk(IndexChunk):
|
||||
access=access,
|
||||
document_sets=document_sets,
|
||||
boost=boost,
|
||||
aggregated_chunk_boost_factor=aggregated_chunk_boost_factor,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
@@ -184,9 +179,3 @@ class IndexingSetting(EmbeddingModelDetail):
|
||||
class MultipassConfig(BaseModel):
|
||||
multipass_indexing: bool
|
||||
enable_large_chunks: bool
|
||||
|
||||
|
||||
class UpdatableChunkData(BaseModel):
|
||||
chunk_id: int
|
||||
document_id: str
|
||||
boost_score: float
|
||||
|
||||
@@ -5,9 +5,7 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_default_vision_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.models import Persona
|
||||
@@ -16,7 +14,6 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
@@ -97,61 +94,40 @@ def get_default_llm_with_vision(
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM | None:
|
||||
"""Get an LLM that supports image input, with the following priority:
|
||||
1. Use the designated default vision provider if it exists and supports image input
|
||||
2. Fall back to the first LLM provider that supports image input
|
||||
|
||||
Returns None if no providers exist or if no provider supports images.
|
||||
"""
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
|
||||
"""Helper to create an LLM if the provider supports image input."""
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
model=model,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Try the default vision provider first
|
||||
default_provider = fetch_default_vision_provider(db_session)
|
||||
if (
|
||||
default_provider
|
||||
and default_provider.default_vision_model
|
||||
and model_supports_image_input(
|
||||
default_provider.default_vision_model, default_provider.provider
|
||||
)
|
||||
):
|
||||
return create_vision_llm(
|
||||
default_provider, default_provider.default_vision_model
|
||||
)
|
||||
|
||||
# Fall back to searching all providers
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
|
||||
if not providers:
|
||||
if not llm_providers:
|
||||
return None
|
||||
|
||||
# Find the first provider that supports image input
|
||||
for provider in providers:
|
||||
if provider.default_vision_model and model_supports_image_input(
|
||||
provider.default_vision_model, provider.provider
|
||||
):
|
||||
return create_vision_llm(
|
||||
FullLLMProvider.from_model(provider), provider.default_vision_model
|
||||
for provider in llm_providers:
|
||||
model_name = provider.default_model_name
|
||||
fast_model_name = (
|
||||
provider.fast_default_model_name or provider.default_model_name
|
||||
)
|
||||
|
||||
if not model_name or not fast_model_name:
|
||||
continue
|
||||
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
return None
|
||||
raise ValueError("No LLM provider found that supports image input")
|
||||
|
||||
|
||||
def get_default_llms(
|
||||
|
||||
@@ -29,8 +29,6 @@ from onyx.natural_language_processing.exceptions import (
|
||||
from onyx.natural_language_processing.utils import get_tokenizer
|
||||
from onyx.natural_language_processing.utils import tokenizer_trim_content
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_HOST
|
||||
from shared_configs.configs import INDEXING_MODEL_SERVER_PORT
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbeddingProvider
|
||||
@@ -38,11 +36,9 @@ from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
from shared_configs.model_server_models import ConnectorClassificationRequest
|
||||
from shared_configs.model_server_models import ConnectorClassificationResponse
|
||||
from shared_configs.model_server_models import ContentClassificationPrediction
|
||||
from shared_configs.model_server_models import Embedding
|
||||
from shared_configs.model_server_models import EmbedRequest
|
||||
from shared_configs.model_server_models import EmbedResponse
|
||||
from shared_configs.model_server_models import InformationContentClassificationResponses
|
||||
from shared_configs.model_server_models import IntentRequest
|
||||
from shared_configs.model_server_models import IntentResponse
|
||||
from shared_configs.model_server_models import RerankRequest
|
||||
@@ -381,31 +377,6 @@ class QueryAnalysisModel:
|
||||
return response_model.is_keyword, response_model.keywords
|
||||
|
||||
|
||||
class InformationContentClassificationModel:
|
||||
def __init__(
|
||||
self,
|
||||
model_server_host: str = INDEXING_MODEL_SERVER_HOST,
|
||||
model_server_port: int = INDEXING_MODEL_SERVER_PORT,
|
||||
) -> None:
|
||||
model_server_url = build_model_server_url(model_server_host, model_server_port)
|
||||
self.content_server_endpoint = (
|
||||
model_server_url + "/custom/content-classification"
|
||||
)
|
||||
|
||||
def predict(
|
||||
self,
|
||||
queries: list[str],
|
||||
) -> list[ContentClassificationPrediction]:
|
||||
response = requests.post(self.content_server_endpoint, json=queries)
|
||||
response.raise_for_status()
|
||||
|
||||
model_responses = InformationContentClassificationResponses(
|
||||
information_content_classifications=response.json()
|
||||
)
|
||||
|
||||
return model_responses.information_content_classifications
|
||||
|
||||
|
||||
class ConnectorClassificationModel:
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Used for creating embeddings of images for vector search
|
||||
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
||||
You are an assistant for summarizing images for retrieval.
|
||||
Summarize the content of the following image and be as precise as possible.
|
||||
The summary will be embedded and used to retrieve the original image.
|
||||
@@ -7,13 +7,14 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
|
||||
"""
|
||||
|
||||
# Prompt for generating image descriptions with filename context
|
||||
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
|
||||
IMAGE_SUMMARIZATION_USER_PROMPT = """
|
||||
The image has the file name '{title}'.
|
||||
Describe precisely and concisely what the image shows.
|
||||
"""
|
||||
|
||||
|
||||
# Used for analyzing images in response to user queries at search time
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
||||
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
||||
"You are an AI assistant specialized in describing images.\n"
|
||||
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
|
||||
"Focus on aspects of the image that are relevant to the user's question.\n"
|
||||
|
||||
@@ -160,20 +160,6 @@ class RedisPool:
|
||||
def get_replica_client(self, tenant_id: str) -> Redis:
|
||||
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
|
||||
|
||||
def get_raw_client(self) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with direct access to the primary connection pool,
|
||||
without tenant prefixing.
|
||||
"""
|
||||
return redis.Redis(connection_pool=self._pool)
|
||||
|
||||
def get_raw_replica_client(self) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with direct access to the replica connection pool,
|
||||
without tenant prefixing.
|
||||
"""
|
||||
return redis.Redis(connection_pool=self._replica_pool)
|
||||
|
||||
@staticmethod
|
||||
def create_pool(
|
||||
host: str = REDIS_HOST,
|
||||
@@ -238,15 +224,6 @@ def get_redis_client(
|
||||
# This argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with tenant-specific key prefixing.
|
||||
|
||||
This ensures proper data isolation between tenants by automatically
|
||||
prefixing all Redis keys with the tenant ID.
|
||||
|
||||
Use this when working with tenant-specific data that should be
|
||||
isolated from other tenants.
|
||||
"""
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -258,15 +235,6 @@ def get_redis_replica_client(
|
||||
# this argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client with tenant-specific key prefixing.
|
||||
|
||||
Similar to get_redis_client(), but connects to a read replica when available.
|
||||
This ensures proper data isolation between tenants by automatically
|
||||
prefixing all Redis keys with the tenant ID.
|
||||
|
||||
Use this for read-heavy operations on tenant-specific data.
|
||||
"""
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -274,57 +242,13 @@ def get_redis_replica_client(
|
||||
|
||||
|
||||
def get_shared_redis_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis client with a shared namespace prefix.
|
||||
|
||||
Unlike tenant-specific clients, this uses a common prefix for all keys,
|
||||
creating a shared namespace accessible across all tenants.
|
||||
|
||||
Use this for data that should be shared across the application and
|
||||
isn't specific to any individual tenant.
|
||||
"""
|
||||
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
def get_shared_redis_replica_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client with a shared namespace prefix.
|
||||
|
||||
Similar to get_shared_redis_client(), but connects to a read replica when available.
|
||||
Uses a common prefix for all keys, creating a shared namespace.
|
||||
|
||||
Use this for read-heavy operations on data that should be shared
|
||||
across the application.
|
||||
"""
|
||||
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
def get_raw_redis_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis client that doesn't apply tenant prefixing to keys.
|
||||
|
||||
Use this only when you need to access Redis directly without tenant isolation
|
||||
or any key prefixing. Typically needed for integrating with external systems
|
||||
or libraries that have inflexible key requirements.
|
||||
|
||||
Warning: Be careful with this client as it bypasses tenant isolation.
|
||||
"""
|
||||
return redis_pool.get_raw_client()
|
||||
|
||||
|
||||
def get_raw_redis_replica_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client that doesn't apply tenant prefixing to keys.
|
||||
|
||||
Similar to get_raw_redis_client(), but connects to a read replica when available.
|
||||
Use this for read-heavy operations that need direct Redis access without
|
||||
tenant isolation or key prefixing.
|
||||
|
||||
Warning: Be careful with this client as it bypasses tenant isolation.
|
||||
"""
|
||||
return redis_pool.get_raw_replica_client()
|
||||
|
||||
|
||||
SSL_CERT_REQS_MAP = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
|
||||
@@ -15,7 +15,7 @@ from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
|
||||
from onyx.connectors.models import Document
|
||||
from onyx.connectors.models import IndexAttemptMetadata
|
||||
from onyx.connectors.models import InputType
|
||||
from onyx.connectors.models import TextSection
|
||||
from onyx.connectors.models import Section
|
||||
from onyx.db.connector import check_connectors_exist
|
||||
from onyx.db.connector import create_connector
|
||||
from onyx.db.connector_credential_pair import add_credential_to_connector
|
||||
@@ -55,7 +55,7 @@ def _create_indexable_chunks(
|
||||
# The section is not really used past this point since we have already done the other processing
|
||||
# for the chunking and embedding.
|
||||
sections=[
|
||||
TextSection(
|
||||
Section(
|
||||
text=preprocessed_doc["content"],
|
||||
link=preprocessed_doc["url"],
|
||||
image_file_name=None,
|
||||
@@ -98,7 +98,6 @@ def _create_indexable_chunks(
|
||||
boost=DEFAULT_BOOST,
|
||||
large_chunk_id=None,
|
||||
image_file_name=None,
|
||||
aggregated_chunk_boost_factor=1.0,
|
||||
)
|
||||
|
||||
chunks.append(chunk)
|
||||
|
||||
@@ -14,7 +14,6 @@ from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -22,13 +21,11 @@ from onyx.llm.factory import get_llm
|
||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -189,62 +186,6 @@ def set_provider_as_default(
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str
|
||||
| None = Query(None, description="The default vision model to use"),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/vision-providers")
|
||||
def get_vision_capable_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
vision_providers = []
|
||||
|
||||
logger.info("Fetching vision-capable providers")
|
||||
|
||||
for provider in providers:
|
||||
vision_models = []
|
||||
|
||||
# Check model names in priority order
|
||||
model_names_to_check = []
|
||||
if provider.model_names:
|
||||
model_names_to_check = provider.model_names
|
||||
elif provider.display_model_names:
|
||||
model_names_to_check = provider.display_model_names
|
||||
elif provider.default_model_name:
|
||||
model_names_to_check = [provider.default_model_name]
|
||||
|
||||
# Check each model for vision capability
|
||||
for model_name in model_names_to_check:
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
vision_models.append(model_name)
|
||||
logger.debug(f"Vision model found: {provider.provider}/{model_name}")
|
||||
|
||||
# Only include providers with at least one vision-capable model
|
||||
if vision_models:
|
||||
provider_dict = FullLLMProvider.from_model(provider).model_dump()
|
||||
provider_dict["vision_models"] = vision_models
|
||||
logger.info(
|
||||
f"Vision provider: {provider.provider} with models: {vision_models}"
|
||||
)
|
||||
vision_providers.append(VisionProviderResponse(**provider_dict))
|
||||
|
||||
logger.info(f"Found {len(vision_providers)} vision-capable providers")
|
||||
return vision_providers
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
|
||||
@@ -34,8 +34,6 @@ class LLMProviderDescriptor(BaseModel):
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
display_model_names: list[str] | None
|
||||
|
||||
@classmethod
|
||||
@@ -48,10 +46,11 @@ class LLMProviderDescriptor(BaseModel):
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
model_names=llm_provider_model.model_names
|
||||
or fetch_models_for_provider(llm_provider_model.provider),
|
||||
model_names=(
|
||||
llm_provider_model.model_names
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
)
|
||||
|
||||
@@ -69,7 +68,6 @@ class LLMProvider(BaseModel):
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
display_model_names: list[str] | None = None
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -81,7 +79,6 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
class FullLLMProvider(LLMProvider):
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_names: list[str]
|
||||
|
||||
@classmethod
|
||||
@@ -97,8 +94,6 @@ class FullLLMProvider(LLMProvider):
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
model_names=(
|
||||
llm_provider_model.model_names
|
||||
@@ -109,9 +104,3 @@ class FullLLMProvider(LLMProvider):
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
)
|
||||
|
||||
|
||||
class VisionProviderResponse(FullLLMProvider):
|
||||
"""Response model for vision providers endpoint, including vision-specific fields."""
|
||||
|
||||
vision_models: list[str]
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -33,12 +31,9 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from onyx.configs.constants import AuthType
|
||||
@@ -55,7 +50,6 @@ from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
@@ -483,7 +477,7 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
|
||||
return UserRoleResponse(role=user.role)
|
||||
|
||||
|
||||
def get_current_auth_token_expiration_jwt(
|
||||
def get_current_token_expiration_jwt(
|
||||
user: User | None, request: Request
|
||||
) -> datetime | None:
|
||||
if user is None:
|
||||
@@ -512,48 +506,6 @@ def get_current_auth_token_expiration_jwt(
|
||||
return None
|
||||
|
||||
|
||||
def get_current_auth_token_creation_redis(
|
||||
user: User | None, request: Request
|
||||
) -> datetime | None:
|
||||
"""Calculate the token creation time from Redis TTL information.
|
||||
|
||||
This function retrieves the authentication token from cookies,
|
||||
checks its TTL in Redis, and calculates when the token was created.
|
||||
Despite the function name, it returns the token creation time, not the expiration time.
|
||||
"""
|
||||
if user is None:
|
||||
return None
|
||||
try:
|
||||
# Get the token from the request
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
# Get the Redis client
|
||||
redis = get_raw_redis_client()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
|
||||
# Get the TTL of the token
|
||||
ttl = cast(int, redis.ttl(redis_key))
|
||||
if ttl <= 0:
|
||||
logger.error("Token has expired or doesn't exist in Redis")
|
||||
return None
|
||||
|
||||
# Calculate the creation time based on TTL and session expiry
|
||||
# Current time minus (total session length minus remaining TTL)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
token_creation_time = current_time - timedelta(
|
||||
seconds=(SESSION_EXPIRE_TIME_SECONDS - ttl)
|
||||
)
|
||||
|
||||
return token_creation_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving token expiration from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_current_token_creation(
|
||||
user: User | None, db_session: Session
|
||||
) -> datetime | None:
|
||||
@@ -581,7 +533,6 @@ def get_current_token_creation(
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserInfo:
|
||||
@@ -607,9 +558,7 @@ def verify_user_logged_in(
|
||||
)
|
||||
|
||||
token_created_at = (
|
||||
get_current_auth_token_creation_redis(user, request)
|
||||
if AUTH_BACKEND == AuthBackend.REDIS
|
||||
else get_current_token_creation(user, db_session)
|
||||
None if MULTI_TENANT else get_current_token_creation(user, db_session)
|
||||
)
|
||||
|
||||
team_name = fetch_ee_implementation_or_noop(
|
||||
|
||||
@@ -19,9 +19,6 @@ from onyx.db.search_settings import get_secondary_search_settings
|
||||
from onyx.document_index.factory import get_default_document_index
|
||||
from onyx.indexing.embedder import DefaultIndexingEmbedder
|
||||
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
|
||||
from onyx.natural_language_processing.search_nlp_models import (
|
||||
InformationContentClassificationModel,
|
||||
)
|
||||
from onyx.server.onyx_api.models import DocMinimalInfo
|
||||
from onyx.server.onyx_api.models import IngestionDocument
|
||||
from onyx.server.onyx_api.models import IngestionResult
|
||||
@@ -105,11 +102,8 @@ def upsert_ingestion_doc(
|
||||
search_settings=search_settings
|
||||
)
|
||||
|
||||
information_content_classification_model = InformationContentClassificationModel()
|
||||
|
||||
indexing_pipeline = build_indexing_pipeline(
|
||||
embedder=index_embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=curr_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
@@ -144,7 +138,6 @@ def upsert_ingestion_doc(
|
||||
|
||||
sec_ind_pipeline = build_indexing_pipeline(
|
||||
embedder=new_index_embedding_model,
|
||||
information_content_classification_model=information_content_classification_model,
|
||||
document_index=sec_doc_index,
|
||||
ignore_time_skip=True,
|
||||
db_session=db_session,
|
||||
|
||||
@@ -25,7 +25,7 @@ google-auth-oauthlib==1.0.0
|
||||
httpcore==1.0.5
|
||||
httpx[http2]==0.27.0
|
||||
httpx-oauth==0.15.1
|
||||
huggingface-hub==0.29.0
|
||||
huggingface-hub==0.20.1
|
||||
inflection==0.5.1
|
||||
jira==3.5.1
|
||||
jsonref==1.1.0
|
||||
@@ -71,7 +71,6 @@ requests==2.32.2
|
||||
requests-oauthlib==1.3.1
|
||||
retry==0.9.2 # This pulls in py which is in CVE-2022-42969, must remove py from image
|
||||
rfc3986==1.5.0
|
||||
setfit==1.1.1
|
||||
simple-salesforce==1.12.6
|
||||
slack-sdk==3.20.2
|
||||
SQLAlchemy[mypy]==2.0.15
|
||||
@@ -79,7 +78,7 @@ starlette==0.36.3
|
||||
supervisor==4.2.5
|
||||
tiktoken==0.7.0
|
||||
timeago==1.0.16
|
||||
transformers==4.49.0
|
||||
transformers==4.39.2
|
||||
unstructured==0.15.1
|
||||
unstructured-client==0.25.4
|
||||
uvicorn==0.21.1
|
||||
|
||||
@@ -14,7 +14,7 @@ pytest-asyncio==0.22.0
|
||||
pytest==7.4.4
|
||||
reorder-python-imports==3.9.0
|
||||
ruff==0.0.286
|
||||
sentence-transformers==3.4.1
|
||||
sentence-transformers==2.6.1
|
||||
trafilatura==1.12.2
|
||||
types-beautifulsoup4==4.12.0.3
|
||||
types-html5lib==1.1.11.13
|
||||
|
||||
@@ -7,10 +7,9 @@ openai==1.61.0
|
||||
pydantic==2.8.2
|
||||
retry==0.9.2
|
||||
safetensors==0.4.2
|
||||
sentence-transformers==3.4.1
|
||||
setfit==1.1.1
|
||||
sentence-transformers==2.6.1
|
||||
torch==2.2.0
|
||||
transformers==4.49.0
|
||||
transformers==4.39.2
|
||||
uvicorn==0.21.1
|
||||
voyageai==0.2.3
|
||||
litellm==1.61.16
|
||||
|
||||
@@ -161,21 +161,17 @@ overview_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/overview",
|
||||
title=overview_title,
|
||||
content=overview,
|
||||
title_embedding=list(model.encode(f"search_document: {overview_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {overview_title}\n{overview}")
|
||||
),
|
||||
title_embedding=model.encode(f"search_document: {overview_title}"),
|
||||
content_embedding=model.encode(f"search_document: {overview_title}\n{overview}"),
|
||||
)
|
||||
|
||||
enterprise_search_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/enterprise_search",
|
||||
title=enterprise_search_title,
|
||||
content=enterprise_search_1,
|
||||
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
|
||||
)
|
||||
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_1}"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -183,11 +179,9 @@ enterprise_search_doc_2 = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/enterprise_search",
|
||||
title=enterprise_search_title,
|
||||
content=enterprise_search_2,
|
||||
title_embedding=list(model.encode(f"search_document: {enterprise_search_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
|
||||
)
|
||||
title_embedding=model.encode(f"search_document: {enterprise_search_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {enterprise_search_title}\n{enterprise_search_2}"
|
||||
),
|
||||
chunk_ind=1,
|
||||
)
|
||||
@@ -196,9 +190,9 @@ ai_platform_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/ai_platform",
|
||||
title=ai_platform_title,
|
||||
content=ai_platform,
|
||||
title_embedding=list(model.encode(f"search_document: {ai_platform_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {ai_platform_title}\n{ai_platform}")
|
||||
title_embedding=model.encode(f"search_document: {ai_platform_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {ai_platform_title}\n{ai_platform}"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -206,9 +200,9 @@ customer_support_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/support",
|
||||
title=customer_support_title,
|
||||
content=customer_support,
|
||||
title_embedding=list(model.encode(f"search_document: {customer_support_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {customer_support_title}\n{customer_support}")
|
||||
title_embedding=model.encode(f"search_document: {customer_support_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {customer_support_title}\n{customer_support}"
|
||||
),
|
||||
)
|
||||
|
||||
@@ -216,17 +210,17 @@ sales_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/sales",
|
||||
title=sales_title,
|
||||
content=sales,
|
||||
title_embedding=list(model.encode(f"search_document: {sales_title}")),
|
||||
content_embedding=list(model.encode(f"search_document: {sales_title}\n{sales}")),
|
||||
title_embedding=model.encode(f"search_document: {sales_title}"),
|
||||
content_embedding=model.encode(f"search_document: {sales_title}\n{sales}"),
|
||||
)
|
||||
|
||||
operations_doc = SeedPresaveDocument(
|
||||
url="https://docs.onyx.app/more/use_cases/operations",
|
||||
title=operations_title,
|
||||
content=operations,
|
||||
title_embedding=list(model.encode(f"search_document: {operations_title}")),
|
||||
content_embedding=list(
|
||||
model.encode(f"search_document: {operations_title}\n{operations}")
|
||||
title_embedding=model.encode(f"search_document: {operations_title}"),
|
||||
content_embedding=model.encode(
|
||||
f"search_document: {operations_title}\n{operations}"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@@ -99,7 +99,6 @@ def generate_dummy_chunk(
|
||||
),
|
||||
document_sets={document_set for document_set in document_set_names},
|
||||
boost=random.randint(-1, 1),
|
||||
aggregated_chunk_boost_factor=random.random(),
|
||||
tenant_id=POSTGRES_DEFAULT_SCHEMA,
|
||||
)
|
||||
|
||||
|
||||
@@ -23,11 +23,9 @@ INDEXING_MODEL_SERVER_PORT = int(
|
||||
# Onyx custom Deep Learning Models
|
||||
CONNECTOR_CLASSIFIER_MODEL_REPO = "Danswer/filter-extraction-model"
|
||||
CONNECTOR_CLASSIFIER_MODEL_TAG = "1.0.0"
|
||||
INTENT_MODEL_VERSION = "onyx-dot-app/hybrid-intent-token-classifier"
|
||||
# INTENT_MODEL_TAG = "v1.0.3"
|
||||
INTENT_MODEL_TAG: str | None = None
|
||||
INFORMATION_CONTENT_MODEL_VERSION = "onyx-dot-app/information-content-model"
|
||||
INFORMATION_CONTENT_MODEL_TAG: str | None = None
|
||||
INTENT_MODEL_VERSION = "danswer/hybrid-intent-token-classifier"
|
||||
INTENT_MODEL_TAG = "v1.0.3"
|
||||
|
||||
|
||||
# Bi-Encoder, other details
|
||||
DOC_EMBEDDING_CONTEXT_SIZE = 512
|
||||
@@ -279,20 +277,3 @@ SUPPORTED_EMBEDDING_MODELS = [
|
||||
index_name="danswer_chunk_intfloat_multilingual_e5_small",
|
||||
),
|
||||
]
|
||||
# Maximum (least severe) downgrade factor for chunks above the cutoff
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX") or 1.0
|
||||
)
|
||||
# Minimum (most severe) downgrade factor for short chunks below the cutoff if no content
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN") or 0.7
|
||||
)
|
||||
# Temperature for the information content classification model
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE = float(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE") or 4.0
|
||||
)
|
||||
# Cutoff below which we start using the information content classification model
|
||||
# (cutoff length number itself is still considered 'short'))
|
||||
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH = int(
|
||||
os.environ.get("INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH") or 10
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from shared_configs.enums import EmbeddingProvider
|
||||
from shared_configs.enums import EmbedTextType
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
|
||||
Embedding = list[float]
|
||||
|
||||
|
||||
@@ -74,20 +73,7 @@ class IntentResponse(BaseModel):
|
||||
keywords: list[str]
|
||||
|
||||
|
||||
class InformationContentClassificationRequests(BaseModel):
|
||||
queries: list[str]
|
||||
|
||||
|
||||
class SupportedEmbeddingModel(BaseModel):
|
||||
name: str
|
||||
dim: int
|
||||
index_name: str
|
||||
|
||||
|
||||
class ContentClassificationPrediction(BaseModel):
|
||||
predicted_label: int
|
||||
content_boost_factor: float
|
||||
|
||||
|
||||
class InformationContentClassificationResponses(BaseModel):
|
||||
information_content_classifications: list[ContentClassificationPrediction]
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user