Compare commits

..

3 Commits

Author SHA1 Message Date
Weves
9b169350a9 Switch to monotonic 2025-03-07 19:04:47 -08:00
Weves
c1dbb073d0 Small tweaks 2025-03-07 15:53:22 -08:00
Weves
39bfc6ae16 Add basic memory logging 2025-03-07 15:46:14 -08:00
213 changed files with 1825 additions and 5843 deletions

View File

@@ -48,8 +48,6 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:

View File

@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -31,8 +31,7 @@ RUN python -c "from transformers import AutoTokenizer; \
AutoTokenizer.from_pretrained('distilbert-base-uncased'); \
AutoTokenizer.from_pretrained('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from huggingface_hub import snapshot_download; \
snapshot_download(repo_id='onyx-dot-app/hybrid-intent-token-classifier'); \
snapshot_download(repo_id='onyx-dot-app/information-content-model'); \
snapshot_download(repo_id='danswer/hybrid-intent-token-classifier', revision='v1.0.3'); \
snapshot_download('nomic-ai/nomic-embed-text-v1'); \
snapshot_download('mixedbread-ai/mxbai-rerank-xsmall-v1'); \
from sentence_transformers import SentenceTransformer; \

View File

@@ -1,51 +0,0 @@
"""add chunk stats table
Revision ID: 3781a5eb12cb
Revises: df46c75b714e
Create Date: 2025-03-10 10:02:30.586666
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "3781a5eb12cb"
down_revision = "df46c75b714e"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"chunk_stats",
sa.Column("id", sa.String(), primary_key=True, index=True),
sa.Column(
"document_id",
sa.String(),
sa.ForeignKey("document.id"),
nullable=False,
index=True,
),
sa.Column("chunk_in_doc_id", sa.Integer(), nullable=False),
sa.Column("information_content_boost", sa.Float(), nullable=True),
sa.Column(
"last_modified",
sa.DateTime(timezone=True),
nullable=False,
index=True,
server_default=sa.func.now(),
),
sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True, index=True),
sa.UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
op.create_index(
"ix_chunk_sync_status", "chunk_stats", ["last_modified", "last_synced"]
)
def downgrade() -> None:
op.drop_index("ix_chunk_sync_status", table_name="chunk_stats")
op.drop_table("chunk_stats")

View File

@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
# Get database connection
connection = op.get_bind()
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:

View File

@@ -1,36 +0,0 @@
"""add_default_vision_provider_to_llm_provider
Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_default_vision_provider",
sa.Boolean(),
nullable=True,
server_default=sa.false(),
),
)
op.add_column(
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
)
def downgrade() -> None:
op.drop_column("llm_provider", "default_vision_model")
op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -1,33 +0,0 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create new_available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop new_available_tenant table
op.drop_table("available_tenant")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -1,14 +1,10 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -1,98 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,24 +1,269 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,96 +0,0 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -63,72 +60,42 @@ async def get_or_provision_tenant(
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
return tenant_id
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
if not tenant_id:
raise HTTPException(
status_code=500,
detail="Failed to provision tenant. Please try again later.",
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -142,25 +109,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@@ -191,74 +187,20 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -411,155 +353,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
"""
Get an available pre-provisioned tenant from the NewAvailableTenant table.
Returns the tenant_id if one is available, None otherwise.
Uses row-level locking to prevent race conditions when multiple processes
try to get an available tenant simultaneously.
"""
if not MULTI_TENANT:
return None
with get_session_with_shared_schema() as db_session:
try:
db_session.begin()
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
available_tenant = (
db_session.query(AvailableTenant)
.order_by(AvailableTenant.date_created)
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
.first()
)
if available_tenant:
tenant_id = available_tenant.tenant_id
# Remove the tenant from the available tenants table
db_session.delete(available_tenant)
db_session.commit()
logger.info(f"Using pre-provisioned tenant {tenant_id}")
return tenant_id
else:
db_session.rollback()
return None
except Exception:
logger.exception("Error getting available tenant")
db_session.rollback()
return None
async def setup_tenant(tenant_id: str) -> None:
"""
Set up a tenant with all necessary configurations.
This is a centralized function that handles all tenant setup logic.
"""
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
except Exception as e:
logger.exception(f"Failed to set up tenant {tenant_id}")
raise e
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
"""
Assign a tenant to a user and perform necessary operations.
Uses transaction handling to ensure atomicity and includes retry logic
for control plane notifications.
"""
# First, add the user to the tenant in a transaction
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
"""Get the current Alembic version for a tenant."""
from alembic.runtime.migration import MigrationContext
from sqlalchemy import text
engine = get_sqlalchemy_engine()
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
return current_rev or "head"

View File

@@ -1,67 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,39 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -1,90 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,56 +1,27 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -67,39 +38,13 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
raise
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
@@ -131,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -3,7 +3,6 @@ from shared_configs.enums import EmbedTextType
MODEL_WARM_UP_STRING = "hi " * 512
INFORMATION_CONTENT_MODEL_WARM_UP_STRING = "hi " * 16
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"

View File

@@ -1,14 +1,11 @@
import numpy as np
import torch
import torch.nn.functional as F
from fastapi import APIRouter
from huggingface_hub import snapshot_download # type: ignore
from setfit import SetFitModel # type: ignore[import]
from transformers import AutoTokenizer # type: ignore
from transformers import BatchEncoding # type: ignore
from transformers import PreTrainedTokenizer # type: ignore
from model_server.constants import INFORMATION_CONTENT_MODEL_WARM_UP_STRING
from model_server.constants import MODEL_WARM_UP_STRING
from model_server.onyx_torch_model import ConnectorClassifier
from model_server.onyx_torch_model import HybridClassifier
@@ -16,22 +13,11 @@ from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_REPO
from shared_configs.configs import CONNECTOR_CLASSIFIER_MODEL_TAG
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH,
)
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
from shared_configs.configs import INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
from shared_configs.configs import (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE,
)
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import INFORMATION_CONTENT_MODEL_TAG
from shared_configs.configs import INFORMATION_CONTENT_MODEL_VERSION
from shared_configs.configs import INTENT_MODEL_TAG
from shared_configs.configs import INTENT_MODEL_VERSION
from shared_configs.model_server_models import ConnectorClassificationRequest
from shared_configs.model_server_models import ConnectorClassificationResponse
from shared_configs.model_server_models import ContentClassificationPrediction
from shared_configs.model_server_models import IntentRequest
from shared_configs.model_server_models import IntentResponse
@@ -45,10 +31,6 @@ _CONNECTOR_CLASSIFIER_MODEL: ConnectorClassifier | None = None
_INTENT_TOKENIZER: AutoTokenizer | None = None
_INTENT_MODEL: HybridClassifier | None = None
_INFORMATION_CONTENT_MODEL: SetFitModel | None = None
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX: str = "" # spec to model version!
def get_connector_classifier_tokenizer() -> AutoTokenizer:
global _CONNECTOR_CLASSIFIER_TOKENIZER
@@ -103,7 +85,7 @@ def get_intent_model_tokenizer() -> AutoTokenizer:
def get_local_intent_model(
model_name_or_path: str = INTENT_MODEL_VERSION,
tag: str | None = INTENT_MODEL_TAG,
tag: str = INTENT_MODEL_TAG,
) -> HybridClassifier:
global _INTENT_MODEL
if _INTENT_MODEL is None:
@@ -120,9 +102,7 @@ def get_local_intent_model(
try:
# Attempt to download the model snapshot
logger.notice(f"Downloading model snapshot for {model_name_or_path}")
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
local_path = snapshot_download(repo_id=model_name_or_path, revision=tag)
_INTENT_MODEL = HybridClassifier.from_pretrained(local_path)
except Exception as e:
logger.error(
@@ -132,44 +112,6 @@ def get_local_intent_model(
return _INTENT_MODEL
def get_local_information_content_model(
model_name_or_path: str = INFORMATION_CONTENT_MODEL_VERSION,
tag: str | None = INFORMATION_CONTENT_MODEL_TAG,
) -> SetFitModel:
global _INFORMATION_CONTENT_MODEL
if _INFORMATION_CONTENT_MODEL is None:
try:
# Calculate where the cache should be, then load from local if available
logger.notice(
f"Loading content information model from local cache: {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=True
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
logger.notice(
f"Loaded content information model from local cache: {local_path}"
)
except Exception as e:
logger.warning(f"Failed to load content information model directly: {e}")
try:
# Attempt to download the model snapshot
logger.notice(
f"Downloading content information model snapshot for {model_name_or_path}"
)
local_path = snapshot_download(
repo_id=model_name_or_path, revision=tag, local_files_only=False
)
_INFORMATION_CONTENT_MODEL = SetFitModel.from_pretrained(local_path)
except Exception as e:
logger.error(
f"Failed to load content information model even after attempted snapshot download: {e}"
)
raise
return _INFORMATION_CONTENT_MODEL
def tokenize_connector_classification_query(
connectors: list[str],
query: str,
@@ -253,13 +195,6 @@ def warm_up_intent_model() -> None:
)
def warm_up_information_content_model() -> None:
logger.notice("Warming up Content Model") # TODO: add version if needed
information_content_model = get_local_information_content_model()
information_content_model(INFORMATION_CONTENT_MODEL_WARM_UP_STRING)
@simple_log_function_time()
def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
intent_model = get_local_intent_model()
@@ -283,117 +218,6 @@ def run_inference(tokens: BatchEncoding) -> tuple[list[float], list[float]]:
return intent_probabilities.tolist(), token_positive_probs
@simple_log_function_time()
def run_content_classification_inference(
text_inputs: list[str],
) -> list[ContentClassificationPrediction]:
"""
Assign a score to the segments in question. The model stored in get_local_information_content_model()
creates the 'model score' based on its training, and the scores are then converted to a 0.0-1.0 scale.
In the code outside of the model/inference model servers that score will be converted into the actual
boost factor.
"""
def _prob_to_score(prob: float) -> float:
"""
Conversion of base score to 0.0 - 1.0 score. Note that the min/max values depend on the model!
"""
_MIN_BASE_SCORE = 0.25
_MAX_BASE_SCORE = 0.75
if prob < _MIN_BASE_SCORE:
raw_score = 0.0
elif prob < _MAX_BASE_SCORE:
raw_score = (prob - _MIN_BASE_SCORE) / (_MAX_BASE_SCORE - _MIN_BASE_SCORE)
else:
raw_score = 1.0
return (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
+ (
INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MAX
- INDEXING_INFORMATION_CONTENT_CLASSIFICATION_MIN
)
* raw_score
)
_BATCH_SIZE = 32
content_model = get_local_information_content_model()
# Process inputs in batches
all_output_classes: list[int] = []
all_base_output_probabilities: list[float] = []
for i in range(0, len(text_inputs), _BATCH_SIZE):
batch = text_inputs[i : i + _BATCH_SIZE]
batch_with_prefix = []
batch_indices = []
# Pre-allocate results for this batch
batch_output_classes: list[np.ndarray] = [np.array(1)] * len(batch)
batch_probabilities: list[np.ndarray] = [np.array(1.0)] * len(batch)
# Pre-process batch to handle long input exceptions
for j, text in enumerate(batch):
if len(text) == 0:
# if no input, treat as non-informative from the model's perspective
batch_output_classes[j] = np.array(0)
batch_probabilities[j] = np.array(0.0)
logger.warning("Input for Content Information Model is empty")
elif (
len(text.split())
<= INDEXING_INFORMATION_CONTENT_CLASSIFICATION_CUTOFF_LENGTH
):
# if input is short, use the model
batch_with_prefix.append(
_INFORMATION_CONTENT_MODEL_PROMPT_PREFIX + text
)
batch_indices.append(j)
else:
# if longer than cutoff, treat as informative (stay with default), but issue warning
logger.warning("Input for Content Information Model too long")
if batch_with_prefix: # Only run model if we have valid inputs
# Get predictions for the batch
model_output_classes = content_model(batch_with_prefix)
model_output_probabilities = content_model.predict_proba(batch_with_prefix)
# Place results in the correct positions
for idx, batch_idx in enumerate(batch_indices):
batch_output_classes[batch_idx] = model_output_classes[idx].numpy()
batch_probabilities[batch_idx] = model_output_probabilities[idx][
1
].numpy() # x[1] is prob of the positive class
all_output_classes.extend([int(x) for x in batch_output_classes])
all_base_output_probabilities.extend([float(x) for x in batch_probabilities])
logits = [
np.log(p / (1 - p)) if p != 0.0 and p != 1.0 else (100 if p == 1.0 else -100)
for p in all_base_output_probabilities
]
scaled_logits = [
logit / INDEXING_INFORMATION_CONTENT_CLASSIFICATION_TEMPERATURE
for logit in logits
]
output_probabilities_with_temp = [
np.exp(scaled_logit) / (1 + np.exp(scaled_logit))
for scaled_logit in scaled_logits
]
prediction_scores = [
_prob_to_score(p_temp) for p_temp in output_probabilities_with_temp
]
content_classification_predictions = [
ContentClassificationPrediction(
predicted_label=predicted_label, content_boost_factor=output_score
)
for predicted_label, output_score in zip(all_output_classes, prediction_scores)
]
return content_classification_predictions
def map_keywords(
input_ids: torch.Tensor, tokenizer: AutoTokenizer, is_keyword: list[bool]
) -> list[str]:
@@ -538,10 +362,3 @@ async def process_analysis_request(
is_keyword, keywords = run_analysis(intent_request)
return IntentResponse(is_keyword=is_keyword, keywords=keywords)
@router.post("/content-classification")
async def process_content_classification_request(
content_classification_requests: list[str],
) -> list[ContentClassificationPrediction]:
return run_content_classification_inference(content_classification_requests)

View File

@@ -62,60 +62,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -146,17 +92,31 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -195,6 +155,7 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -278,51 +239,22 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -637,13 +569,6 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -13,7 +13,6 @@ from sentry_sdk.integrations.starlette import StarletteIntegration
from transformers import logging as transformer_logging # type:ignore
from model_server.custom_models import router as custom_models_router
from model_server.custom_models import warm_up_information_content_model
from model_server.custom_models import warm_up_intent_model
from model_server.encoders import router as encoders_router
from model_server.management_endpoints import router as management_router
@@ -75,15 +74,9 @@ async def lifespan(app: FastAPI) -> AsyncGenerator:
logger.notice(f"Torch Threads: {torch.get_num_threads()}")
if not INDEXING_ONLY:
logger.notice(
"The intent model should run on the model server. The information content model should not run here."
)
warm_up_intent_model()
else:
logger.notice(
"The content information model should run on the indexing model server. The intent model should not run here."
)
warm_up_information_content_model()
logger.notice("This model server should only run document indexing.")
yield

View File

@@ -153,8 +153,7 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,6 +1,5 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -895,7 +894,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accessible_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1096,12 +1095,6 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1133,14 +1126,9 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1152,7 +1140,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -112,6 +112,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,6 +92,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -5,53 +5,40 @@ from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()

View File

@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@@ -1,199 +0,0 @@
"""
Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import uuid
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
@shared_task(
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
queue=OnyxCeleryQueues.MONITORING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_available_tenants(self: Task) -> None:
"""
Check if we have enough pre-provisioned tenants available.
If not, trigger the pre-provisioning of new tenants.
"""
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
if not MULTI_TENANT:
task_logger.info(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
if not lock_check.acquire(blocking=False):
task_logger.info(
"Skipping check_available_tenants task because it is already running"
)
return
try:
# Get the current count of available tenants
with get_session_with_shared_schema() as db_session:
available_tenants_count = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
target_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
tenants_to_provision = max(
0, target_available_tenants - available_tenants_count
)
task_logger.info(
f"Available tenants: {available_tenants_count}, "
f"Target: {target_available_tenants}, "
f"To provision: {tenants_to_provision}"
)
# Trigger pre-provisioning tasks for each tenant needed
for _ in range(tenants_to_provision):
from celery import current_app
current_app.send_task(
OnyxCeleryTask.PRE_PROVISION_TENANT,
priority=OnyxCeleryPriority.LOW,
)
except Exception:
task_logger.exception("Error in check_available_tenants task")
finally:
lock_check.release()
@shared_task(
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
ignore_result=True,
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def pre_provision_tenant(self: Task) -> None:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
r = get_redis_client()
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
)
return
tenant_id: str | None = None
try:
# Generate a new tenant ID
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
# Create the schema for the new tenant
schema_created = create_schema_if_not_exists(tenant_id)
if schema_created:
task_logger.debug(f"Created schema for tenant: {tenant_id}")
else:
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
# Set up the tenant with all necessary configurations
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
asyncio.run(setup_tenant(tenant_id))
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
task_logger.debug(
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
)
# Store the pre-provisioned tenant in the database
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
with get_session_with_shared_schema() as db_session:
# Use a transaction to ensure atomicity
db_session.begin()
try:
new_tenant = AvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
except Exception:
db_session.rollback()
task_logger.error(
f"Failed to store pre-provisioned tenant: {tenant_id}",
exc_info=True,
)
raise
except Exception:
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
# If we have a tenant_id, attempt to rollback any partially completed provisioning
if tenant_id:
task_logger.info(
f"Rolling back failed tenant provisioning for: {tenant_id}"
)
try:
from ee.onyx.server.tenants.provisioning import (
rollback_tenant_provisioning,
)
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
lock_provision.release()

View File

@@ -563,7 +563,6 @@ def vespa_metadata_sync_task(self: Task, document_id: str, *, tenant_id: str) ->
access=doc_access,
boost=doc.boost,
hidden=doc.hidden,
# aggregated_boost_factor=doc.aggregated_boost_factor,
)
# update Vespa. OK if doc doesn't exist. Raises exception otherwise.

View File

@@ -28,7 +28,6 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -53,9 +52,6 @@ from onyx.httpx.httpx_pool import HttpxPool
from onyx.indexing.embedder import DefaultIndexingEmbedder
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.indexing_pipeline import build_indexing_pipeline
from onyx.natural_language_processing.search_nlp_models import (
InformationContentClassificationModel,
)
from onyx.utils.logger import setup_logger
from onyx.utils.logger import TaskAttemptSingleton
from onyx.utils.telemetry import create_milestone_and_report
@@ -158,12 +154,14 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
for section in cleaned_doc.sections:
if section.link is not None:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
@@ -351,8 +349,6 @@ def _run_indexing(
callback=callback,
)
information_content_classification_model = InformationContentClassificationModel()
document_index = get_default_document_index(
index_attempt_start.search_settings,
None,
@@ -361,7 +357,6 @@ def _run_indexing(
indexing_pipeline = build_indexing_pipeline(
embedder=embedding_model,
information_content_classification_model=information_content_classification_model,
document_index=document_index,
ignore_time_skip=(
ctx.from_beginning
@@ -484,11 +479,7 @@ def _run_indexing(
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(

View File

@@ -1,13 +1,10 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = Union[
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
]
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -8,9 +8,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
#####
# App Configs
@@ -646,24 +643,3 @@ MOCK_LLM_RESPONSE = (
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
)
# The user prompt for image summarization - the image filename will be automatically prepended
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_USER_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
)
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)

View File

@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
@@ -322,8 +321,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -386,7 +383,6 @@ class OnyxCeleryTask:
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -403,9 +399,6 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -132,10 +132,3 @@ if _LITELLM_EXTRA_BODY_RAW:
LITELLM_EXTRA_BODY = json.loads(_LITELLM_EXTRA_BODY_RAW)
except Exception:
pass
# Whether and how to lower scores for short chunks w/o relevant context
# Evaluated via custom ML model
USE_INFORMATION_CONTENT_CLASSIFICATION = (
os.environ.get("USE_INFORMATION_CONTENT_CLASSIFICATION", "false").lower() == "true"
)

View File

@@ -4,7 +4,6 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
from typing import cast
import requests
from pyairtable import Api as AirtableApi
@@ -17,8 +16,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
@@ -269,7 +267,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[TextSection], dict[str, str | list[str]]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -307,7 +305,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
TextSection(
Section(
link=link,
text=(
f"{field_name}:\n"
@@ -342,7 +340,7 @@ class AirtableConnector(LoadConnector):
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[TextSection] = []
sections: list[Section] = []
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
@@ -386,7 +384,7 @@ class AirtableConnector(LoadConnector):
return Document(
id=f"airtable__{record_id}",
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,

View File

@@ -10,7 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,7 +82,7 @@ class AsanaConnector(LoadConnector, PollConnector):
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[TextSection(link=task.link, text=task.text)],
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -221,7 +221,7 @@ def _get_forums(
def _translate_forum_to_doc(af: AxeroForum) -> Document:
doc = Document(
id=af.doc_id,
sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
sections=[Section(link=af.link, text=reply) for reply in af.responses],
source=DocumentSource.AXERO,
semantic_identifier=af.title,
doc_updated_at=af.last_update,
@@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document:
doc = Document(
id="AXERO_" + str(content["ContentID"]),
sections=[TextSection(link=content["ContentURL"], text=page_text)],
sections=[Section(link=content["ContentURL"], text=page_text)],
source=DocumentSource.AXERO,
semantic_identifier=content["ContentTitle"],
doc_updated_at=time_str_to_utc(content["DateUpdated"]),

View File

@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -208,7 +208,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
@@ -341,14 +341,7 @@ if __name__ == "__main__":
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
if isinstance(section, TextSection) and section.text is not None:
print(f" - Text: {section.text[:100]}...")
elif (
hasattr(section, "image_file_name") and section.image_file_name
):
print(f" - Image: {section.image_file_name}")
else:
print("Error: Unknown section type")
print(f" - Text: {section.text[:100]}...")
print("---")
break

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -81,7 +81,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="book__" + str(book.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Book: " + title,
title=title,
@@ -110,7 +110,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="chapter__" + str(chapter.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Chapter: " + title,
title=title,
@@ -134,7 +134,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="shelf:" + str(shelf.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Shelf: " + title,
title=title,
@@ -167,7 +167,7 @@ class BookstackConnector(LoadConnector, PollConnector):
time.sleep(0.1)
return Document(
id="page:" + page_id,
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Page: " + str(title),
title=str(title),

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.retry_wrapper import retry_builder
@@ -62,11 +62,11 @@ class ClickupConnector(LoadConnector, PollConnector):
return response.json()
def _get_task_comments(self, task_id: str) -> list[TextSection]:
def _get_task_comments(self, task_id: str) -> list[Section]:
url_endpoint = f"/task/{task_id}/comment"
response = self._make_request(url_endpoint)
comments = [
TextSection(
Section(
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
text=comment_dict["comment_text"],
)
@@ -133,7 +133,7 @@ class ClickupConnector(LoadConnector, PollConnector):
],
title=task["name"],
sections=[
TextSection(
Section(
link=task["url"],
text=(
task["markdown_description"]

View File

@@ -33,9 +33,9 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -85,6 +85,7 @@ class ConfluenceConnector(
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
):
def __init__(
self,
@@ -115,6 +116,9 @@ class ConfluenceConnector(
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
@@ -241,16 +245,12 @@ class ConfluenceConnector(
)
# Create the main section for the page content
sections: list[TextSection | ImageSection] = [
TextSection(text=page_content, link=page_url)
]
sections = [Section(text=page_content, link=page_url)]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(
TextSection(text=comment_text, link=f"{page_url}#comments")
)
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
# Process attachments
if "children" in page and "attachment" in page["children"]:
@@ -263,27 +263,21 @@ class ConfluenceConnector(
result = process_attachment(
self.confluence_client,
attachment,
page_id,
page_title,
self.image_analysis_llm,
)
if result and result.text:
if result.text:
# Create a section for the attachment text
attachment_section = TextSection(
attachment_section = Section(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
)
sections.append(attachment_section)
elif result and result.file_name:
# Create an ImageSection for image attachments
image_section = ImageSection(
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(image_section)
else:
sections.append(attachment_section)
elif result.error:
logger.warning(
f"Error processing attachment '{attachment.get('title')}':",
f"{result.error if result else 'Unknown error'}",
f"Error processing attachment '{attachment.get('title')}': {result.error}"
)
# Extract metadata
@@ -354,7 +348,7 @@ class ConfluenceConnector(
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
page.get("body", {}).get("storage", {}).get("value", "")
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
@@ -362,7 +356,7 @@ class ConfluenceConnector(
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue
@@ -372,26 +366,23 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
page_context=confluence_xml,
llm=self.image_analysis_llm,
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
TextSection(
Section(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
@@ -471,7 +462,7 @@ class ConfluenceConnector(
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue

View File

@@ -1,3 +1,4 @@
import io
import json
import time
from collections.abc import Callable
@@ -18,11 +19,17 @@ from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
@@ -36,6 +35,7 @@ from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -53,16 +53,17 @@ class TokenResponse(BaseModel):
def validate_attachment_filetype(
attachment: dict[str, Any],
attachment: dict[str, Any], llm: LLM | None = None
) -> bool:
"""
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return is_valid_image_type(media_type)
return llm is not None and is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
@@ -83,103 +84,55 @@ class AttachmentProcessingResult(BaseModel):
error: str | None = None
def _make_attachment_link(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
download_link = ""
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
return download_link
return None
return resp.content
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image, stores it for later analysis. Returns a structured result.
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
if not validate_attachment_filetype(attachment, llm):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
attachment_link = _make_attachment_link(
confluence_client, attachment, parent_content_id
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
logger.info(
f"Downloading attachment: "
f"title={attachment['title']} "
f"length={attachment_size} "
f"link={attachment_link}"
)
# Download the attachment
resp: requests.Response = confluence_client._session.get(attachment_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
)
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
text=None, file_name=None, error="Failed to download attachment"
)
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
if media_type.startswith("image/"):
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
return _process_image_attachment(
confluence_client, attachment, raw_bytes, media_type
confluence_client, attachment, page_context, llm, raw_bytes, media_type
)
# Process document attachments
@@ -212,10 +165,12 @@ def process_attachment(
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it without generating a summary."""
"""Process an image attachment by saving it and generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
@@ -225,14 +180,15 @@ def _process_image_attachment(
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
msg = f"Image summarization failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
@@ -293,15 +249,16 @@ def _process_text_attachment(
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts content or stores image for later processing
2. Extracts or summarizes content
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment.get("metadata", {}).get("mediaType", "")
media_type = attachment["metadata"]["mediaType"]
# Quick check for unsupported types:
if media_type.startswith("video/") or media_type == "application/gliffy+json":
logger.warning(
@@ -309,7 +266,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"
@@ -522,10 +479,6 @@ def attachment_to_file_record(
download_link, absolute=True, not_json_response=True
)
file_type = attachment.get("metadata", {}).get(
"mediaType", "application/octet-stream"
)
# Save image to file store
file_name = f"confluence_attachment_{attachment['id']}"
lobj_oid = create_populate_lobj(BytesIO(image_data), db_session)
@@ -533,7 +486,7 @@ def attachment_to_file_record(
file_name=file_name,
display_name=attachment["title"],
file_origin=FileOrigin.OTHER,
file_type=file_type,
file_type=attachment["metadata"]["mediaType"],
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,

View File

@@ -4,7 +4,6 @@ from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from discord import Client
from discord.channel import TextChannel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -34,7 +32,7 @@ _SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[TextSection],
sections: list[Section],
) -> Document:
"""
Convert a discord message to a document
@@ -80,7 +78,7 @@ def _convert_message_to_document(
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
title=title,
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
metadata=metadata,
)
@@ -125,8 +123,8 @@ async def _fetch_documents_from_channel(
if channel_message.type != MessageType.default:
continue
sections: list[TextSection] = [
TextSection(
sections: list[Section] = [
Section(
text=channel_message.content,
link=channel_message.jump_url,
)
@@ -144,7 +142,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)
@@ -162,7 +160,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)

View File

@@ -3,7 +3,6 @@ import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -114,7 +112,7 @@ class DiscourseConnector(PollConnector):
responders.append(BasicExpertInfo(display_name=responder_name))
sections.append(
TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"]))
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
@@ -131,7 +129,7 @@ class DiscourseConnector(PollConnector):
doc = Document(
id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.DISCOURSE,
semantic_identifier=topic["title"],
doc_updated_at=time_str_to_utc(topic["last_posted_at"]),

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
@@ -158,7 +158,7 @@ class Document360Connector(LoadConnector, PollConnector):
document = Document(
id=article_details["id"],
sections=[TextSection(link=doc_link, text=doc_text)],
sections=[Section(link=doc_link, text=doc_text)],
source=DocumentSource.DOCUMENT360,
semantic_identifier=article_details["title"],
doc_updated_at=updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -108,7 +108,7 @@ class DropboxConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"doc:{entry.id}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
doc_updated_at=modified_time,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -111,7 +111,7 @@ def _process_egnyte_file(
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[TextSection(text=file_content_raw.strip(), link=web_url)],
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,

View File

@@ -16,8 +16,8 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -26,6 +26,7 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -58,44 +59,32 @@ def _read_files_and_metadata(
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
idx: int = 0,
) -> tuple[ImageSection, str | None]:
) -> tuple[Section, str | None]:
"""
Creates an ImageSection for an image file or embedded image.
Stores the image in PGFileStore but does not generate a summary.
Args:
image_data: Raw image bytes
db_session: Database session
parent_file_name: Name of the parent file (for embedded images)
display_name: Display name for the image
idx: Index for embedded images
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Returns:
Tuple of (ImageSection, stored_file_name or None)
tuple: (Section object, file_name in PGFileStore or None if storage failed)
"""
# Create a unique identifier for the image
file_name = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
return section, stored_file_name
except Exception as e:
logger.error(f"Failed to store image {display_name}: {e}")
raise e
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
def _process_file(
@@ -104,16 +93,12 @@ def _process_file(
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
For images, creates ImageSection objects without summarization.
For documents with embedded images, extracts and stores the images.
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
"""
if metadata is None:
metadata = {}
# Get file extension and determine file type
extension = get_file_ext(file_name)
# Fetch the DB record so we know the ID for internal URL
@@ -129,6 +114,8 @@ def _process_file(
return []
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
# Timestamps
@@ -171,7 +158,6 @@ def _process_file(
"title",
"connector_type",
"pdf_password",
"mime_type",
]
}
@@ -184,45 +170,33 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
if extension in LoadConnector.IMAGE_EXTENSIONS:
# Read the image data
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
image_data = file.read()
if not image_data:
logger.warning(f"Empty image file: {file_name}")
return []
# Create an ImageSection for the image
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=title,
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
return [
Document(
id=doc_id,
sections=[section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
except Exception as e:
logger.error(f"Failed to process image file {file_name}: {e}")
return []
# 2) Otherwise: text-based approach. Possibly with embedded images.
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
@@ -230,29 +204,24 @@ def _process_file(
)
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
sections = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
sections.append(Section(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:
image_section, _ = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=f"{title} - image {idx}",
idx=idx,
)
sections.append(image_section)
except Exception as e:
logger.warning(
f"Failed to process embedded image {idx} in {file_name}: {e}"
)
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
return [
Document(
id=doc_id,
@@ -268,10 +237,10 @@ def _process_file(
]
class LocalFileConnector(LoadConnector):
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
"""
Connector that reads files from Postgres and yields Documents, including
embedded image extraction without summarization.
optional embedded image extraction.
"""
def __init__(
@@ -283,6 +252,9 @@ class LocalFileConnector(LoadConnector):
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -314,6 +286,7 @@ class LocalFileConnector(LoadConnector):
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
)
documents.extend(new_docs)

View File

@@ -1,7 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
import requests
@@ -15,8 +14,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -47,7 +45,7 @@ _FIREFLIES_API_QUERY = """
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
sections: List[Section] = []
current_speaker_name = None
current_link = ""
current_text = ""
@@ -59,7 +57,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -73,7 +71,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -96,7 +94,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
return Document(
id=fireflies_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},

View File

@@ -14,7 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -133,7 +133,7 @@ def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
TextSection(
Section(
link=link,
text=text,
)

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -183,7 +183,7 @@ def _convert_page_to_document(
return Document(
id=f"gitbook-{space_id}-{page_id}",
sections=[
TextSection(
Section(
link=page.get("urls", {}).get("app", ""),
text=_extract_text_from_document(page_content),
)
@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -27,7 +27,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -87,9 +87,7 @@ def _batch_github_objects(
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -111,7 +109,7 @@ def _fetch_issue_comments(issue: Issue) -> str:
def _convert_issue_to_document(issue: Issue) -> Document:
return Document(
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
sections=[Section(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware

View File

@@ -21,7 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -56,7 +56,7 @@ def get_author(author: Any) -> BasicExpertInfo:
def _convert_merge_request_to_document(mr: Any) -> Document:
doc = Document(
id=mr.web_url,
sections=[TextSection(link=mr.web_url, text=mr.description or "")],
sections=[Section(link=mr.web_url, text=mr.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=mr.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -72,7 +72,7 @@ def _convert_merge_request_to_document(mr: Any) -> Document:
def _convert_issue_to_document(issue: Any) -> Document:
doc = Document(
id=issue.web_url,
sections=[TextSection(link=issue.web_url, text=issue.description or "")],
sections=[Section(link=issue.web_url, text=issue.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -99,7 +99,7 @@ def _convert_code_to_document(
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
doc = Document(
id=file["id"],
sections=[TextSection(link=file_url, text=file_content)],
sections=[Section(link=file_url, text=file_content)],
source=DocumentSource.GITLAB,
semantic_identifier=file["name"],
doc_updated_at=datetime.now().replace(

View File

@@ -1,6 +1,5 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
@@ -29,9 +28,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -117,7 +115,7 @@ def _get_message_body(payload: dict[str, Any]) -> str:
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
@@ -144,7 +142,7 @@ def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str,
message_body_text: str = _get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
@@ -194,7 +192,7 @@ def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -243,7 +243,7 @@ class GongConnector(LoadConnector, PollConnector):
Document(
id=call_id,
sections=[
TextSection(link=call_metadata["url"], text=transcript_text)
Section(link=call_metadata["url"], text=transcript_text)
],
source=DocumentSource.GONG,
# Should not ever be Untitled as a call cannot be made without a Title

View File

@@ -43,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -66,6 +68,7 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -74,6 +77,7 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
@@ -112,7 +116,9 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
def __init__(
self,
include_shared_drives: bool = False,
@@ -145,6 +151,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
and not include_my_drives
@@ -530,6 +539,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
)
# Fetch files in batches

View File

@@ -1,50 +1,40 @@
import io
from datetime import datetime
from typing import cast
from datetime import timezone
from tempfile import NamedTemporaryFile
from googleapiclient.http import MediaIoBaseDownload # type: ignore
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from onyx.connectors.google_drive.models import GDriveMimeType
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Mapping of Google Drive mime types to export formats
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
# Define Google MIME types mapping
GOOGLE_MIME_TYPES = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
@@ -76,137 +66,259 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
image_analysis_llm: LLM | None = None,
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
try:
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
# Process embedded images in the PDF
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
# ---------------------------
# Word, PowerPoint, PDF files
elif mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
if get_unstructured_api_key():
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
if mime_type == GDriveMimeType.WORD_DOC.value:
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
file_name=embedded_id,
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
return sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == GDriveMimeType.PDF.value:
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
logger.exception(f"Error extracting sections from file: {e}")
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# skip shortcuts or folders
@@ -215,50 +327,44 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
sections: list[Section] = []
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
)
# If we don't have sections yet, use the basic extraction method
# If not a doc, or if we failed above, do our 'basic' approach
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
# If we still don't have any sections, skip this file
if not sections:
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
# Create the document
return Document(
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
metadata={
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
},
doc_updated_at=datetime.fromisoformat(
file.get("modifiedTime", "").replace("Z", "+00:00")
),
semantic_identifier=file["name"],
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
additional_info=file.get("id"),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
class CurrentHeading(BaseModel):
@@ -37,7 +37,7 @@ def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[TextSection]:
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
@@ -45,7 +45,7 @@ def get_document_sections(
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[TextSection] = []
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
@@ -70,7 +70,7 @@ def get_document_sections(
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
@@ -96,7 +96,7 @@ def get_document_sections(
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.db.engine import get_sqlalchemy_engine
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
@@ -118,7 +118,7 @@ class GoogleSitesConnector(LoadConnector):
source=DocumentSource.GOOGLE_SITES,
semantic_identifier=title,
sections=[
TextSection(
Section(
link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
if path
else "",

View File

@@ -15,7 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -120,7 +120,7 @@ class GuruConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=card["id"],
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.GURU,
semantic_identifier=title,
doc_updated_at=latest_time,

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
HUBSPOT_BASE_URL = "https://app.hubspot.com/contacts/"
@@ -108,7 +108,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=ticket.id,
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.HUBSPOT,
semantic_identifier=title,
# Is already in tzutc, just replacing the timezone format

View File

@@ -24,8 +24,6 @@ CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpo
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -21,8 +21,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -238,30 +237,22 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
documents: list[Document] = []
for edge in edges:
node = edge["node"]
# Create sections for description and comments
sections = [
TextSection(
link=node["url"],
text=node["description"] or "",
)
]
# Add comment sections
for comment in node["comments"]["nodes"]:
sections.append(
TextSection(
link=node["url"],
text=comment["body"] or "",
)
)
# Cast the sections list to the expected type
typed_sections = cast(list[TextSection | ImageSection], sections)
documents.append(
Document(
id=node["id"],
sections=typed_sections,
sections=[
Section(
link=node["url"],
text=node["description"] or "",
)
]
+ [
Section(
link=node["url"],
text=comment["body"] or "",
)
for comment in node["comments"]["nodes"]
],
source=DocumentSource.LINEAR,
semantic_identifier=f"[{node['identifier']}] {node['title']}",
title=node["title"],

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
from onyx.utils.logger import setup_logger
@@ -162,7 +162,7 @@ class LoopioConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=str(entry["id"]),
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.LOOPIO,
semantic_identifier=questions[0],
doc_updated_at=latest_time,

View File

@@ -6,7 +6,6 @@ import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -62,14 +60,14 @@ def get_doc_from_page(
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
sections = [
TextSection(
Section(
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
text=section.title + section.content,
)
for section in sections_extracted.sections
]
sections.append(
TextSection(
Section(
link=page.full_url(),
text=sections_extracted.header,
)
@@ -81,7 +79,7 @@ def get_doc_from_page(
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
page.latest_revision.timestamp
),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -28,25 +27,9 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
"""Base section class with common attributes"""
link: str | None = None
text: str | None = None
image_file_name: str | None = None
class TextSection(Section):
"""Section containing text content"""
text: str
link: str | None = None
class ImageSection(Section):
"""Section containing an image reference"""
image_file_name: str
link: str | None = None
image_file_name: str | None = None
class BasicExpertInfo(BaseModel):
@@ -116,7 +99,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[TextSection | ImageSection]
sections: list[Section]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, str | list[str]]
@@ -166,11 +149,19 @@ class DocumentBase(BaseModel):
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""
id: str
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
@@ -193,32 +184,6 @@ class Document(DocumentBase):
)
class IndexingDocument(Document):
"""Document with processed sections for indexing"""
processed_sections: list[Section] = []
def get_total_char_length(self) -> int:
"""Get the total character length of the document including processed sections"""
title_len = len(self.title or self.semantic_identifier)
# Use processed_sections if available, otherwise fall back to original sections
if self.processed_sections:
section_len = sum(
len(section.text) if section.text is not None else 0
for section in self.processed_sections
)
else:
section_len = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in self.sections
)
return title_len + section_len
class SlimDocument(BaseModel):
id: str
perm_sync_data: Any | None = None
@@ -239,15 +204,6 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,3 +1,4 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -25,13 +26,12 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -475,7 +475,7 @@ class NotionConnector(LoadConnector, PollConnector):
Document(
id=page.id,
sections=[
TextSection(
Section(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -23,8 +23,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
@@ -145,7 +145,7 @@ def fetch_jira_issues_batch(
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -110,7 +110,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=feature["id"],
sections=[
TextSection(
Section(
link=feature["links"]["html"],
text=self._parse_description_html(feature["description"]),
)
@@ -133,7 +133,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=component["id"],
sections=[
TextSection(
Section(
link=component["links"]["html"],
text=self._parse_description_html(component["description"]),
)
@@ -159,7 +159,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=product["id"],
sections=[
TextSection(
Section(
link=product["links"]["html"],
text=self._parse_description_html(product["description"]),
)
@@ -189,7 +189,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=objective["id"],
sections=[
TextSection(
Section(
link=objective["links"]["html"],
text=self._parse_description_html(objective["description"]),
)

View File

@@ -13,7 +13,6 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
@@ -49,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
domain = "test" if credentials.get("is_sandbox") else None
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
domain=domain,
)
return None
@@ -219,8 +216,7 @@ if __name__ == "__main__":
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")

View File

@@ -1,12 +1,10 @@
import re
from typing import cast
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
@@ -116,8 +114,8 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> TextSection:
return TextSection(
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
@@ -177,7 +175,7 @@ def convert_sf_object_to_doc(
doc = Document(
id=onyx_salesforce_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -55,7 +55,7 @@ def _convert_driveitem_to_document(
doc = Document(
id=driveitem.id,
sections=[TextSection(link=driveitem.web_url, text=file_text)],
sections=[Section(link=driveitem.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem.name,
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),

View File

@@ -19,8 +19,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -212,7 +212,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
doc_batch.append(
Document(
id=post_id, # can't be url as this changes with the post title
sections=[TextSection(link=page_url, text=content_text)],
sections=[Section(link=page_url, text=content_text)],
source=DocumentSource.SLAB,
semantic_identifier=post["title"],
metadata={},

View File

@@ -34,8 +34,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
@@ -211,7 +211,7 @@ def thread_to_doc(
return Document(
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
TextSection(
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
@@ -220,6 +220,7 @@ def thread_to_doc(
source=DocumentSource.SLACK,
semantic_identifier=doc_sem_id,
doc_updated_at=get_latest_message_time(thread),
title="", # slack docs don't really have a "title"
primary_owners=valid_experts,
metadata={"Channel": channel["name"]},
)
@@ -673,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -705,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -165,7 +165,7 @@ def _convert_thread_to_document(
doc = Document(
id=post_id,
sections=[TextSection(link=web_url, text=thread_text)],
sections=[Section(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
title="", # teams threads don't really have a "title"

View File

@@ -0,0 +1,45 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
class VisionEnabledConnector:
"""
Mixin for connectors that need vision capabilities.
This mixin provides a standard way to initialize a vision-capable LLM
for image analysis during indexing.
Usage:
class MyConnector(LoadConnector, VisionEnabledConnector):
def __init__(self, ...):
super().__init__(...)
self.initialize_vision_llm()
"""
def initialize_vision_llm(self) -> None:
"""
Initialize a vision-capable LLM if enabled by configuration.
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:
logger.warning(
"No LLM with vision found; image summarization will be disabled"
)
except Exception as e:
logger.warning(
f"Failed to initialize vision LLM due to an error: {str(e)}. "
"Image summarization will be disabled."
)
self.image_analysis_llm = None

View File

@@ -32,7 +32,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
@@ -341,7 +341,7 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=initial_url,
sections=[TextSection(link=initial_url, text=page_text)],
sections=[Section(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
@@ -443,7 +443,7 @@ class WebConnector(LoadConnector):
Document(
id=initial_url,
sections=[
TextSection(link=initial_url, text=parsed_html.cleaned_text)
Section(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or initial_url,

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -104,7 +104,7 @@ def scrape_page_posts(
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[TextSection(link=url, text=post_text)],
sections=[Section(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterator
from typing import Any
from typing import cast
import requests
@@ -18,8 +17,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -169,8 +168,8 @@ def _article_to_document(
return new_author_mapping, Document(
id=f"article:{article['id']}",
sections=[
TextSection(
link=cast(str, article.get("html_url")),
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
],
@@ -269,7 +268,7 @@ def _ticket_to_document(
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[TextSection(link=ticket_display_url, text=full_text)],
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.zulip.schemas import GetMessagesResponse
from onyx.connectors.zulip.schemas import Message
from onyx.connectors.zulip.utils import build_search_narrow
@@ -161,7 +161,7 @@ class ZulipConnector(LoadConnector, PollConnector):
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
TextSection(
Section(
link=self._message_to_narrow_link(message),
text=text,
)

View File

@@ -10,7 +10,6 @@ from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@@ -32,6 +31,7 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import RerankingModel
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall

View File

@@ -1,63 +0,0 @@
from datetime import datetime
from datetime import timezone
from sqlalchemy import delete
from sqlalchemy.orm import Session
from onyx.db.models import ChunkStats
from onyx.indexing.models import UpdatableChunkData
def update_chunk_boost_components__no_commit(
chunk_data: list[UpdatableChunkData],
db_session: Session,
) -> None:
"""Updates the chunk_boost_components for chunks in the database.
Args:
chunk_data: List of dicts containing chunk_id, document_id, and boost_score
db_session: SQLAlchemy database session
"""
if not chunk_data:
return
for data in chunk_data:
chunk_in_doc_id = int(data.chunk_id)
if chunk_in_doc_id < 0:
raise ValueError(f"Chunk ID is empty for chunk {data}")
chunk_document_id = f"{data.document_id}" f"__{chunk_in_doc_id}"
chunk_stats = (
db_session.query(ChunkStats)
.filter(
ChunkStats.id == chunk_document_id,
)
.first()
)
score = data.boost_score
if chunk_stats:
chunk_stats.information_content_boost = score
chunk_stats.last_modified = datetime.now(timezone.utc)
db_session.add(chunk_stats)
else:
# do not save new chunks with a neutral boost score
if score == 1.0:
continue
# Create new record
chunk_stats = ChunkStats(
document_id=data.document_id,
chunk_in_doc_id=chunk_in_doc_id,
information_content_boost=score,
)
db_session.add(chunk_stats)
def delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This deletes just chunk stats in postgres."""
stmt = delete(ChunkStats).where(ChunkStats.document_id.in_(document_ids))
db_session.execute(stmt)

View File

@@ -23,7 +23,6 @@ from sqlalchemy.sql.expression import null
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.constants import DocumentSource
from onyx.db.chunk import delete_chunk_stats_by_connector_credential_pair__no_commit
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.engine import get_session_context_manager
from onyx.db.enums import AccessType
@@ -563,18 +562,6 @@ def delete_documents_complete__no_commit(
db_session: Session, document_ids: list[str]
) -> None:
"""This completely deletes the documents from the db, including all foreign key relationships"""
# Start by deleting the chunk stats for the documents
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_chunk_stats_by_connector_credential_pair__no_commit(
db_session=db_session,
document_ids=document_ids,
)
delete_documents_by_connector_credential_pair__no_commit(db_session, document_ids)
delete_document_feedback_for_documents__no_commit(
document_ids=document_ids, db_session=db_session

View File

@@ -13,7 +13,6 @@ from onyx.db.models import SearchSettings
from onyx.db.models import Tool as ToolModel
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
@@ -188,17 +187,6 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
return FullLLMProvider.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
@@ -258,39 +246,3 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default.is_default_provider = True
db_session.commit()
def update_default_vision_provider(
provider_id: int, vision_model: str | None, db_session: Session
) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
# Validate that the specified vision model supports image input
model_to_validate = vision_model or new_default.default_model_name
if model_to_validate:
if not model_supports_image_input(model_to_validate, new_default.provider):
raise ValueError(
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
)
else:
raise ValueError(
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
)
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_vision_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_vision_provider = True
new_default.default_vision_model = vision_model
db_session.commit()

View File

@@ -591,55 +591,6 @@ class Document(Base):
)
class ChunkStats(Base):
__tablename__ = "chunk_stats"
# NOTE: if more sensitive data is added here for display, make sure to add user/group permission
# this should correspond to the ID of the document
# (as is passed around in Onyx)
id: Mapped[str] = mapped_column(
NullFilteredString,
primary_key=True,
default=lambda context: (
f"{context.get_current_parameters()['document_id']}"
f"__{context.get_current_parameters()['chunk_in_doc_id']}"
),
index=True,
)
# Reference to parent document
document_id: Mapped[str] = mapped_column(
NullFilteredString, ForeignKey("document.id"), nullable=False, index=True
)
chunk_in_doc_id: Mapped[int] = mapped_column(
Integer,
nullable=False,
)
information_content_boost: Mapped[float | None] = mapped_column(
Float, nullable=True
)
last_modified: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=False, index=True, default=func.now()
)
last_synced: Mapped[datetime.datetime | None] = mapped_column(
DateTime(timezone=True), nullable=True, index=True
)
__table_args__ = (
Index(
"ix_chunk_sync_status",
last_modified,
last_synced,
),
UniqueConstraint(
"document_id", "chunk_in_doc_id", name="uq_chunk_stats_doc_chunk"
),
)
class Tag(Base):
__tablename__ = "tag"
@@ -1538,8 +1489,6 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(Boolean)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
@@ -2346,31 +2295,21 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = ({"schema": "public"},)
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
class AvailableTenant(Base):
__tablename__ = "available_tenant"
"""
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
"""
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):
__tablename__ = "tenant_anonymous_user_path"

View File

@@ -67,9 +67,6 @@ def read_lobj(
use_tempfile: bool = False,
) -> IO:
pg_conn = get_pg_conn_from_session(db_session)
# Ensure we're using binary mode by default for large objects
if mode is None:
mode = "rb"
large_object = (
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
)
@@ -84,7 +81,6 @@ def read_lobj(
temp_file.seek(0)
return temp_file
else:
# Ensure we're getting raw bytes without text decoding
return BytesIO(large_object.read())

View File

@@ -1,5 +1,6 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -148,10 +149,11 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -101,7 +101,6 @@ class VespaDocumentFields:
document_sets: set[str] | None = None
boost: float | None = None
hidden: bool | None = None
aggregated_chunk_boost_factor: float | None = None
@dataclass

View File

@@ -80,11 +80,6 @@ schema DANSWER_CHUNK_NAME {
indexing: summary | attribute
rank: filter
}
# Field to indicate whether a short chunk is a low content chunk
field aggregated_chunk_boost_factor type float {
indexing: attribute
}
# Needs to have a separate Attribute list for efficient filtering
field metadata_list type array<string> {
indexing: summary | attribute
@@ -147,11 +142,6 @@ schema DANSWER_CHUNK_NAME {
expression: max(if(isNan(attribute(doc_updated_at)) == 1, 7890000, now() - attribute(doc_updated_at)) / 31536000, 0)
}
function inline aggregated_chunk_boost() {
# Aggregated boost factor, currently only used for information content classification
expression: if(isNan(attribute(aggregated_chunk_boost_factor)) == 1, 1.0, attribute(aggregated_chunk_boost_factor))
}
# Document score decays from 1 to 0.75 as age of last updated time increases
function inline recency_bias() {
expression: max(1 / (1 + query(decay_factor) * document_age), 0.75)
@@ -209,8 +199,6 @@ schema DANSWER_CHUNK_NAME {
* document_boost
# Decay factor based on time document was last updated
* recency_bias
# Boost based on aggregated boost calculation
* aggregated_chunk_boost
}
rerank-count: 1000
}
@@ -222,7 +210,6 @@ schema DANSWER_CHUNK_NAME {
closeness(field, embeddings)
document_boost
recency_bias
aggregated_chunk_boost
closest(embeddings)
}
}

View File

@@ -22,7 +22,6 @@ from onyx.document_index.vespa.shared_utils.utils import (
replace_invalid_doc_id_characters,
)
from onyx.document_index.vespa_constants import ACCESS_CONTROL_LIST
from onyx.document_index.vespa_constants import AGGREGATED_CHUNK_BOOST_FACTOR
from onyx.document_index.vespa_constants import BLURB
from onyx.document_index.vespa_constants import BOOST
from onyx.document_index.vespa_constants import CHUNK_ID
@@ -202,7 +201,6 @@ def _index_vespa_chunk(
DOCUMENT_SETS: {document_set: 1 for document_set in chunk.document_sets},
IMAGE_FILE_NAME: chunk.image_file_name,
BOOST: chunk.boost,
AGGREGATED_CHUNK_BOOST_FACTOR: chunk.aggregated_chunk_boost_factor,
}
if multitenant:

View File

@@ -72,7 +72,6 @@ METADATA = "metadata"
METADATA_LIST = "metadata_list"
METADATA_SUFFIX = "metadata_suffix"
BOOST = "boost"
AGGREGATED_CHUNK_BOOST_FACTOR = "aggregated_chunk_boost_factor"
DOC_UPDATED_AT = "doc_updated_at" # Indexed as seconds since epoch
PRIMARY_OWNERS = "primary_owners"
SECONDARY_OWNERS = "secondary_owners"
@@ -98,7 +97,6 @@ YQL_BASE = (
f"{SECTION_CONTINUATION}, "
f"{IMAGE_FILE_NAME}, "
f"{BOOST}, "
f"{AGGREGATED_CHUNK_BOOST_FACTOR}, "
f"{HIDDEN}, "
f"{DOC_UPDATED_AT}, "
f"{PRIMARY_OWNERS}, "

View File

@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from PIL import Image
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
image_data: The raw image bytes
context_name: Name or title of the image for context
system_prompt: System prompt to use for the LLM
user_prompt_template: User prompt to use (without title)
user_prompt_template: Template for the user prompt, should contain {title} placeholder
Returns:
The image summary text, or None if summarization failed or is disabled
@@ -70,10 +70,7 @@ def summarize_image_with_error_handling(
if llm is None:
return None
# Prepend the image filename to the user prompt
user_prompt = (
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
)
user_prompt = user_prompt_template.format(title=context_name)
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)

View File

@@ -2,9 +2,12 @@ from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -15,12 +18,12 @@ def store_image_and_create_section(
image_data: bytes,
file_name: str,
display_name: str,
link: str | None = None,
media_type: str = "application/octet-stream",
media_type: str = "image/unknown",
llm: LLM | None = None,
file_origin: FileOrigin = FileOrigin.OTHER,
) -> Tuple[ImageSection, str | None]:
) -> Tuple[Section, str | None]:
"""
Stores an image in PGFileStore and creates an ImageSection object without summarization.
Stores an image in PGFileStore and creates a Section object with optional summarization.
Args:
db_session: Database session
@@ -28,11 +31,12 @@ def store_image_and_create_section(
file_name: Base identifier for the file
display_name: Human-readable name for the image
media_type: MIME type of the image
llm: Optional LLM with vision capabilities for summarization
file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
Returns:
Tuple containing:
- ImageSection object with image reference
- Section object with image reference and optional summary text
- The file_name in PGFileStore or None if storage failed
"""
# Storage logic
@@ -49,10 +53,18 @@ def store_image_and_create_section(
stored_file_name = pgfilestore.file_name
except Exception as e:
logger.error(f"Failed to store image: {e}")
raise e
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return Section(text=""), None
# Summarization logic
summary_text = ""
if llm:
summary_text = (
summarize_image_with_error_handling(llm, image_data, display_name) or ""
)
# Create an ImageSection with empty text (will be filled by LLM later in the pipeline)
return (
ImageSection(image_file_name=stored_file_name, link=link),
Section(text=summary_text, image_file_name=stored_file_name),
stored_file_name,
)

Some files were not shown because too many files have changed in this diff Show More