Compare commits

..

3 Commits

Author SHA1 Message Date
Weves
9b169350a9 Switch to monotonic 2025-03-07 19:04:47 -08:00
Weves
c1dbb073d0 Small tweaks 2025-03-07 15:53:22 -08:00
Weves
39bfc6ae16 Add basic memory logging 2025-03-07 15:46:14 -08:00
189 changed files with 1789 additions and 4950 deletions

View File

@@ -48,8 +48,6 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:

View File

@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
# Get database connection
connection = op.get_bind()
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:

View File

@@ -1,45 +0,0 @@
"""add_default_vision_provider_to_llm_provider
Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_default_vision_provider",
sa.Boolean(),
nullable=True,
server_default=sa.false(),
),
)
op.add_column(
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
)
# Add unique constraint for is_default_vision_provider
op.create_unique_constraint(
"uq_llm_provider_is_default_vision_provider",
"llm_provider",
["is_default_vision_provider"],
)
def downgrade() -> None:
op.drop_constraint(
"uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique"
)
op.drop_column("llm_provider", "default_vision_model")
op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -1,33 +0,0 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create new_available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop new_available_tenant table
op.drop_table("available_tenant")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -1,14 +1,10 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -1,98 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,24 +1,269 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,96 +0,0 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -63,72 +60,42 @@ async def get_or_provision_tenant(
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
return tenant_id
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
if not tenant_id:
raise HTTPException(
status_code=500,
detail="Failed to provision tenant. Please try again later.",
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -142,25 +109,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@@ -191,74 +187,20 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -411,155 +353,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
"""
Get an available pre-provisioned tenant from the NewAvailableTenant table.
Returns the tenant_id if one is available, None otherwise.
Uses row-level locking to prevent race conditions when multiple processes
try to get an available tenant simultaneously.
"""
if not MULTI_TENANT:
return None
with get_session_with_shared_schema() as db_session:
try:
db_session.begin()
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
available_tenant = (
db_session.query(AvailableTenant)
.order_by(AvailableTenant.date_created)
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
.first()
)
if available_tenant:
tenant_id = available_tenant.tenant_id
# Remove the tenant from the available tenants table
db_session.delete(available_tenant)
db_session.commit()
logger.info(f"Using pre-provisioned tenant {tenant_id}")
return tenant_id
else:
db_session.rollback()
return None
except Exception:
logger.exception("Error getting available tenant")
db_session.rollback()
return None
async def setup_tenant(tenant_id: str) -> None:
"""
Set up a tenant with all necessary configurations.
This is a centralized function that handles all tenant setup logic.
"""
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
except Exception as e:
logger.exception(f"Failed to set up tenant {tenant_id}")
raise e
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
"""
Assign a tenant to a user and perform necessary operations.
Uses transaction handling to ensure atomicity and includes retry logic
for control plane notifications.
"""
# First, add the user to the tenant in a transaction
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
"""Get the current Alembic version for a tenant."""
from alembic.runtime.migration import MigrationContext
from sqlalchemy import text
engine = get_sqlalchemy_engine()
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
return current_rev or "head"

View File

@@ -1,67 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,39 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -1,90 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,56 +1,27 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -67,39 +38,13 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
raise
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
@@ -131,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -62,60 +62,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -146,17 +92,31 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -195,6 +155,7 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -278,51 +239,22 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -637,13 +569,6 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -153,8 +153,7 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,6 +1,5 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -895,7 +894,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accessible_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1096,12 +1095,6 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1133,14 +1126,9 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1152,7 +1140,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -112,6 +112,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,6 +92,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -5,53 +5,40 @@ from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE, maxBytes=MEMORY_LOG_MAX_BYTES, backupCount=MEMORY_LOG_BACKUP_COUNT
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()

View File

@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@@ -1,199 +0,0 @@
"""
Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import uuid
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
@shared_task(
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
queue=OnyxCeleryQueues.MONITORING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_available_tenants(self: Task) -> None:
"""
Check if we have enough pre-provisioned tenants available.
If not, trigger the pre-provisioning of new tenants.
"""
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
if not MULTI_TENANT:
task_logger.info(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
if not lock_check.acquire(blocking=False):
task_logger.info(
"Skipping check_available_tenants task because it is already running"
)
return
try:
# Get the current count of available tenants
with get_session_with_shared_schema() as db_session:
available_tenants_count = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
target_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
tenants_to_provision = max(
0, target_available_tenants - available_tenants_count
)
task_logger.info(
f"Available tenants: {available_tenants_count}, "
f"Target: {target_available_tenants}, "
f"To provision: {tenants_to_provision}"
)
# Trigger pre-provisioning tasks for each tenant needed
for _ in range(tenants_to_provision):
from celery import current_app
current_app.send_task(
OnyxCeleryTask.PRE_PROVISION_TENANT,
priority=OnyxCeleryPriority.LOW,
)
except Exception:
task_logger.exception("Error in check_available_tenants task")
finally:
lock_check.release()
@shared_task(
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
ignore_result=True,
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def pre_provision_tenant(self: Task) -> None:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
r = get_redis_client()
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
)
return
tenant_id: str | None = None
try:
# Generate a new tenant ID
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
# Create the schema for the new tenant
schema_created = create_schema_if_not_exists(tenant_id)
if schema_created:
task_logger.debug(f"Created schema for tenant: {tenant_id}")
else:
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
# Set up the tenant with all necessary configurations
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
asyncio.run(setup_tenant(tenant_id))
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
task_logger.debug(
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
)
# Store the pre-provisioned tenant in the database
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
with get_session_with_shared_schema() as db_session:
# Use a transaction to ensure atomicity
db_session.begin()
try:
new_tenant = AvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
except Exception:
db_session.rollback()
task_logger.error(
f"Failed to store pre-provisioned tenant: {tenant_id}",
exc_info=True,
)
raise
except Exception:
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
# If we have a tenant_id, attempt to rollback any partially completed provisioning
if tenant_id:
task_logger.info(
f"Rolling back failed tenant provisioning for: {tenant_id}"
)
try:
from ee.onyx.server.tenants.provisioning import (
rollback_tenant_provisioning,
)
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
lock_provision.release()

View File

@@ -28,7 +28,6 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -155,12 +154,14 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
for section in cleaned_doc.sections:
if section.link is not None:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
@@ -478,11 +479,7 @@ def _run_indexing(
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(

View File

@@ -1,13 +1,10 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = Union[
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
]
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -8,9 +8,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
#####
# App Configs
@@ -646,24 +643,3 @@ MOCK_LLM_RESPONSE = (
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
)
# The user prompt for image summarization - the image filename will be automatically prepended
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_USER_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
)
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)

View File

@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
@@ -322,8 +321,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -386,7 +383,6 @@ class OnyxCeleryTask:
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -403,9 +399,6 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -4,7 +4,6 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
from typing import cast
import requests
from pyairtable import Api as AirtableApi
@@ -17,8 +16,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
@@ -269,7 +267,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[TextSection], dict[str, str | list[str]]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -307,7 +305,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
TextSection(
Section(
link=link,
text=(
f"{field_name}:\n"
@@ -342,7 +340,7 @@ class AirtableConnector(LoadConnector):
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[TextSection] = []
sections: list[Section] = []
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
@@ -386,7 +384,7 @@ class AirtableConnector(LoadConnector):
return Document(
id=f"airtable__{record_id}",
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,

View File

@@ -10,7 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,7 +82,7 @@ class AsanaConnector(LoadConnector, PollConnector):
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[TextSection(link=task.link, text=task.text)],
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -221,7 +221,7 @@ def _get_forums(
def _translate_forum_to_doc(af: AxeroForum) -> Document:
doc = Document(
id=af.doc_id,
sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
sections=[Section(link=af.link, text=reply) for reply in af.responses],
source=DocumentSource.AXERO,
semantic_identifier=af.title,
doc_updated_at=af.last_update,
@@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document:
doc = Document(
id="AXERO_" + str(content["ContentID"]),
sections=[TextSection(link=content["ContentURL"], text=page_text)],
sections=[Section(link=content["ContentURL"], text=page_text)],
source=DocumentSource.AXERO,
semantic_identifier=content["ContentTitle"],
doc_updated_at=time_str_to_utc(content["DateUpdated"]),

View File

@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -208,7 +208,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
@@ -341,14 +341,7 @@ if __name__ == "__main__":
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
if isinstance(section, TextSection) and section.text is not None:
print(f" - Text: {section.text[:100]}...")
elif (
hasattr(section, "image_file_name") and section.image_file_name
):
print(f" - Image: {section.image_file_name}")
else:
print("Error: Unknown section type")
print(f" - Text: {section.text[:100]}...")
print("---")
break

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -81,7 +81,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="book__" + str(book.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Book: " + title,
title=title,
@@ -110,7 +110,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="chapter__" + str(chapter.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Chapter: " + title,
title=title,
@@ -134,7 +134,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="shelf:" + str(shelf.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Shelf: " + title,
title=title,
@@ -167,7 +167,7 @@ class BookstackConnector(LoadConnector, PollConnector):
time.sleep(0.1)
return Document(
id="page:" + page_id,
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Page: " + str(title),
title=str(title),

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.retry_wrapper import retry_builder
@@ -62,11 +62,11 @@ class ClickupConnector(LoadConnector, PollConnector):
return response.json()
def _get_task_comments(self, task_id: str) -> list[TextSection]:
def _get_task_comments(self, task_id: str) -> list[Section]:
url_endpoint = f"/task/{task_id}/comment"
response = self._make_request(url_endpoint)
comments = [
TextSection(
Section(
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
text=comment_dict["comment_text"],
)
@@ -133,7 +133,7 @@ class ClickupConnector(LoadConnector, PollConnector):
],
title=task["name"],
sections=[
TextSection(
Section(
link=task["url"],
text=(
task["markdown_description"]

View File

@@ -33,9 +33,9 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -85,6 +85,7 @@ class ConfluenceConnector(
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
):
def __init__(
self,
@@ -115,6 +116,9 @@ class ConfluenceConnector(
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
@@ -241,16 +245,12 @@ class ConfluenceConnector(
)
# Create the main section for the page content
sections: list[TextSection | ImageSection] = [
TextSection(text=page_content, link=page_url)
]
sections = [Section(text=page_content, link=page_url)]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(
TextSection(text=comment_text, link=f"{page_url}#comments")
)
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
# Process attachments
if "children" in page and "attachment" in page["children"]:
@@ -263,27 +263,21 @@ class ConfluenceConnector(
result = process_attachment(
self.confluence_client,
attachment,
page_id,
page_title,
self.image_analysis_llm,
)
if result and result.text:
if result.text:
# Create a section for the attachment text
attachment_section = TextSection(
attachment_section = Section(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
)
sections.append(attachment_section)
elif result and result.file_name:
# Create an ImageSection for image attachments
image_section = ImageSection(
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(image_section)
else:
sections.append(attachment_section)
elif result.error:
logger.warning(
f"Error processing attachment '{attachment.get('title')}':",
f"{result.error if result else 'Unknown error'}",
f"Error processing attachment '{attachment.get('title')}': {result.error}"
)
# Extract metadata
@@ -354,7 +348,7 @@ class ConfluenceConnector(
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
page.get("body", {}).get("storage", {}).get("value", "")
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
@@ -362,7 +356,7 @@ class ConfluenceConnector(
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue
@@ -372,26 +366,23 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
page_context=confluence_xml,
llm=self.image_analysis_llm,
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
TextSection(
Section(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
@@ -471,7 +462,7 @@ class ConfluenceConnector(
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue

View File

@@ -1,3 +1,4 @@
import io
import json
import time
from collections.abc import Callable
@@ -18,11 +19,17 @@ from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
@@ -36,6 +35,7 @@ from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -53,16 +53,17 @@ class TokenResponse(BaseModel):
def validate_attachment_filetype(
attachment: dict[str, Any],
attachment: dict[str, Any], llm: LLM | None = None
) -> bool:
"""
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return is_valid_image_type(media_type)
return llm is not None and is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
@@ -83,103 +84,55 @@ class AttachmentProcessingResult(BaseModel):
error: str | None = None
def _make_attachment_link(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
download_link = ""
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
return download_link
return None
return resp.content
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image, stores it for later analysis. Returns a structured result.
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
if not validate_attachment_filetype(attachment, llm):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
attachment_link = _make_attachment_link(
confluence_client, attachment, parent_content_id
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
logger.info(
f"Downloading attachment: "
f"title={attachment['title']} "
f"length={attachment_size} "
f"link={attachment_link}"
)
# Download the attachment
resp: requests.Response = confluence_client._session.get(attachment_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
)
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
text=None, file_name=None, error="Failed to download attachment"
)
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
if media_type.startswith("image/"):
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
return _process_image_attachment(
confluence_client, attachment, raw_bytes, media_type
confluence_client, attachment, page_context, llm, raw_bytes, media_type
)
# Process document attachments
@@ -212,10 +165,12 @@ def process_attachment(
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it without generating a summary."""
"""Process an image attachment by saving it and generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
@@ -225,14 +180,15 @@ def _process_image_attachment(
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
msg = f"Image summarization failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
@@ -293,12 +249,13 @@ def _process_text_attachment(
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts content or stores image for later processing
2. Extracts or summarizes content
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment["metadata"]["mediaType"]
@@ -309,7 +266,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -4,7 +4,6 @@ from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from discord import Client
from discord.channel import TextChannel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -34,7 +32,7 @@ _SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[TextSection],
sections: list[Section],
) -> Document:
"""
Convert a discord message to a document
@@ -80,7 +78,7 @@ def _convert_message_to_document(
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
title=title,
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
metadata=metadata,
)
@@ -125,8 +123,8 @@ async def _fetch_documents_from_channel(
if channel_message.type != MessageType.default:
continue
sections: list[TextSection] = [
TextSection(
sections: list[Section] = [
Section(
text=channel_message.content,
link=channel_message.jump_url,
)
@@ -144,7 +142,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)
@@ -162,7 +160,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)

View File

@@ -3,7 +3,6 @@ import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -114,7 +112,7 @@ class DiscourseConnector(PollConnector):
responders.append(BasicExpertInfo(display_name=responder_name))
sections.append(
TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"]))
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
@@ -131,7 +129,7 @@ class DiscourseConnector(PollConnector):
doc = Document(
id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.DISCOURSE,
semantic_identifier=topic["title"],
doc_updated_at=time_str_to_utc(topic["last_posted_at"]),

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
@@ -158,7 +158,7 @@ class Document360Connector(LoadConnector, PollConnector):
document = Document(
id=article_details["id"],
sections=[TextSection(link=doc_link, text=doc_text)],
sections=[Section(link=doc_link, text=doc_text)],
source=DocumentSource.DOCUMENT360,
semantic_identifier=article_details["title"],
doc_updated_at=updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -108,7 +108,7 @@ class DropboxConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"doc:{entry.id}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
doc_updated_at=modified_time,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -111,7 +111,7 @@ def _process_egnyte_file(
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[TextSection(text=file_content_raw.strip(), link=web_url)],
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,

View File

@@ -16,8 +16,8 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -26,6 +26,7 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -58,44 +59,32 @@ def _read_files_and_metadata(
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
idx: int = 0,
) -> tuple[ImageSection, str | None]:
) -> tuple[Section, str | None]:
"""
Creates an ImageSection for an image file or embedded image.
Stores the image in PGFileStore but does not generate a summary.
Args:
image_data: Raw image bytes
db_session: Database session
parent_file_name: Name of the parent file (for embedded images)
display_name: Display name for the image
idx: Index for embedded images
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Returns:
Tuple of (ImageSection, stored_file_name or None)
tuple: (Section object, file_name in PGFileStore or None if storage failed)
"""
# Create a unique identifier for the image
file_name = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
return section, stored_file_name
except Exception as e:
logger.error(f"Failed to store image {display_name}: {e}")
raise e
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
def _process_file(
@@ -104,16 +93,12 @@ def _process_file(
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
For images, creates ImageSection objects without summarization.
For documents with embedded images, extracts and stores the images.
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
"""
if metadata is None:
metadata = {}
# Get file extension and determine file type
extension = get_file_ext(file_name)
# Fetch the DB record so we know the ID for internal URL
@@ -129,6 +114,8 @@ def _process_file(
return []
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
# Timestamps
@@ -171,7 +158,6 @@ def _process_file(
"title",
"connector_type",
"pdf_password",
"mime_type",
]
}
@@ -184,45 +170,33 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
if extension in LoadConnector.IMAGE_EXTENSIONS:
# Read the image data
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
image_data = file.read()
if not image_data:
logger.warning(f"Empty image file: {file_name}")
return []
# Create an ImageSection for the image
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=title,
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
return [
Document(
id=doc_id,
sections=[section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
except Exception as e:
logger.error(f"Failed to process image file {file_name}: {e}")
return []
# 2) Otherwise: text-based approach. Possibly with embedded images.
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
@@ -230,29 +204,24 @@ def _process_file(
)
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
sections = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
sections.append(Section(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:
image_section, _ = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=f"{title} - image {idx}",
idx=idx,
)
sections.append(image_section)
except Exception as e:
logger.warning(
f"Failed to process embedded image {idx} in {file_name}: {e}"
)
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
return [
Document(
id=doc_id,
@@ -268,10 +237,10 @@ def _process_file(
]
class LocalFileConnector(LoadConnector):
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
"""
Connector that reads files from Postgres and yields Documents, including
embedded image extraction without summarization.
optional embedded image extraction.
"""
def __init__(
@@ -283,6 +252,9 @@ class LocalFileConnector(LoadConnector):
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -314,6 +286,7 @@ class LocalFileConnector(LoadConnector):
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
)
documents.extend(new_docs)

View File

@@ -1,7 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
import requests
@@ -15,8 +14,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -47,7 +45,7 @@ _FIREFLIES_API_QUERY = """
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
sections: List[Section] = []
current_speaker_name = None
current_link = ""
current_text = ""
@@ -59,7 +57,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -73,7 +71,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -96,7 +94,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
return Document(
id=fireflies_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},

View File

@@ -14,7 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -133,7 +133,7 @@ def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
TextSection(
Section(
link=link,
text=text,
)

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -183,7 +183,7 @@ def _convert_page_to_document(
return Document(
id=f"gitbook-{space_id}-{page_id}",
sections=[
TextSection(
Section(
link=page.get("urls", {}).get("app", ""),
text=_extract_text_from_document(page_content),
)
@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -27,7 +27,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -87,9 +87,7 @@ def _batch_github_objects(
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -111,7 +109,7 @@ def _fetch_issue_comments(issue: Issue) -> str:
def _convert_issue_to_document(issue: Issue) -> Document:
return Document(
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
sections=[Section(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware

View File

@@ -21,7 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -56,7 +56,7 @@ def get_author(author: Any) -> BasicExpertInfo:
def _convert_merge_request_to_document(mr: Any) -> Document:
doc = Document(
id=mr.web_url,
sections=[TextSection(link=mr.web_url, text=mr.description or "")],
sections=[Section(link=mr.web_url, text=mr.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=mr.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -72,7 +72,7 @@ def _convert_merge_request_to_document(mr: Any) -> Document:
def _convert_issue_to_document(issue: Any) -> Document:
doc = Document(
id=issue.web_url,
sections=[TextSection(link=issue.web_url, text=issue.description or "")],
sections=[Section(link=issue.web_url, text=issue.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -99,7 +99,7 @@ def _convert_code_to_document(
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
doc = Document(
id=file["id"],
sections=[TextSection(link=file_url, text=file_content)],
sections=[Section(link=file_url, text=file_content)],
source=DocumentSource.GITLAB,
semantic_identifier=file["name"],
doc_updated_at=datetime.now().replace(

View File

@@ -1,6 +1,5 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
@@ -29,9 +28,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -117,7 +115,7 @@ def _get_message_body(payload: dict[str, Any]) -> str:
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
@@ -144,7 +142,7 @@ def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str,
message_body_text: str = _get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
@@ -194,7 +192,7 @@ def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -243,7 +243,7 @@ class GongConnector(LoadConnector, PollConnector):
Document(
id=call_id,
sections=[
TextSection(link=call_metadata["url"], text=transcript_text)
Section(link=call_metadata["url"], text=transcript_text)
],
source=DocumentSource.GONG,
# Should not ever be Untitled as a call cannot be made without a Title

View File

@@ -43,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -66,6 +68,7 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -74,6 +77,7 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
@@ -112,7 +116,9 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
def __init__(
self,
include_shared_drives: bool = False,
@@ -145,6 +151,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
and not include_my_drives
@@ -530,6 +539,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
)
# Fetch files in batches

View File

@@ -1,50 +1,40 @@
import io
from datetime import datetime
from typing import cast
from datetime import timezone
from tempfile import NamedTemporaryFile
from googleapiclient.http import MediaIoBaseDownload # type: ignore
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from onyx.connectors.google_drive.models import GDriveMimeType
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Mapping of Google Drive mime types to export formats
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
# Define Google MIME types mapping
GOOGLE_MIME_TYPES = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
@@ -76,137 +66,259 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
image_analysis_llm: LLM | None = None,
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
try:
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
# Process embedded images in the PDF
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
# ---------------------------
# Word, PowerPoint, PDF files
elif mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
if get_unstructured_api_key():
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
if mime_type == GDriveMimeType.WORD_DOC.value:
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
file_name=embedded_id,
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
return sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == GDriveMimeType.PDF.value:
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
logger.exception(f"Error extracting sections from file: {e}")
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# skip shortcuts or folders
@@ -215,50 +327,44 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
sections: list[Section] = []
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
)
# If we don't have sections yet, use the basic extraction method
# If not a doc, or if we failed above, do our 'basic' approach
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
# If we still don't have any sections, skip this file
if not sections:
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
# Create the document
return Document(
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
metadata={
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
},
doc_updated_at=datetime.fromisoformat(
file.get("modifiedTime", "").replace("Z", "+00:00")
),
semantic_identifier=file["name"],
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
additional_info=file.get("id"),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
class CurrentHeading(BaseModel):
@@ -37,7 +37,7 @@ def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[TextSection]:
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
@@ -45,7 +45,7 @@ def get_document_sections(
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[TextSection] = []
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
@@ -70,7 +70,7 @@ def get_document_sections(
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
@@ -96,7 +96,7 @@ def get_document_sections(
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.db.engine import get_sqlalchemy_engine
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
@@ -118,7 +118,7 @@ class GoogleSitesConnector(LoadConnector):
source=DocumentSource.GOOGLE_SITES,
semantic_identifier=title,
sections=[
TextSection(
Section(
link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
if path
else "",

View File

@@ -15,7 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -120,7 +120,7 @@ class GuruConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=card["id"],
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.GURU,
semantic_identifier=title,
doc_updated_at=latest_time,

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
HUBSPOT_BASE_URL = "https://app.hubspot.com/contacts/"
@@ -108,7 +108,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=ticket.id,
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.HUBSPOT,
semantic_identifier=title,
# Is already in tzutc, just replacing the timezone format

View File

@@ -24,8 +24,6 @@ CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpo
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -21,8 +21,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -238,30 +237,22 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
documents: list[Document] = []
for edge in edges:
node = edge["node"]
# Create sections for description and comments
sections = [
TextSection(
link=node["url"],
text=node["description"] or "",
)
]
# Add comment sections
for comment in node["comments"]["nodes"]:
sections.append(
TextSection(
link=node["url"],
text=comment["body"] or "",
)
)
# Cast the sections list to the expected type
typed_sections = cast(list[TextSection | ImageSection], sections)
documents.append(
Document(
id=node["id"],
sections=typed_sections,
sections=[
Section(
link=node["url"],
text=node["description"] or "",
)
]
+ [
Section(
link=node["url"],
text=comment["body"] or "",
)
for comment in node["comments"]["nodes"]
],
source=DocumentSource.LINEAR,
semantic_identifier=f"[{node['identifier']}] {node['title']}",
title=node["title"],

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
from onyx.utils.logger import setup_logger
@@ -162,7 +162,7 @@ class LoopioConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=str(entry["id"]),
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.LOOPIO,
semantic_identifier=questions[0],
doc_updated_at=latest_time,

View File

@@ -6,7 +6,6 @@ import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -62,14 +60,14 @@ def get_doc_from_page(
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
sections = [
TextSection(
Section(
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
text=section.title + section.content,
)
for section in sections_extracted.sections
]
sections.append(
TextSection(
Section(
link=page.full_url(),
text=sections_extracted.header,
)
@@ -81,7 +79,7 @@ def get_doc_from_page(
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
page.latest_revision.timestamp
),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -28,25 +27,9 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
"""Base section class with common attributes"""
link: str | None = None
text: str | None = None
image_file_name: str | None = None
class TextSection(Section):
"""Section containing text content"""
text: str
link: str | None = None
class ImageSection(Section):
"""Section containing an image reference"""
image_file_name: str
link: str | None = None
image_file_name: str | None = None
class BasicExpertInfo(BaseModel):
@@ -116,7 +99,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[TextSection | ImageSection]
sections: list[Section]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, str | list[str]]
@@ -166,11 +149,19 @@ class DocumentBase(BaseModel):
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""
id: str
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
@@ -193,32 +184,6 @@ class Document(DocumentBase):
)
class IndexingDocument(Document):
"""Document with processed sections for indexing"""
processed_sections: list[Section] = []
def get_total_char_length(self) -> int:
"""Get the total character length of the document including processed sections"""
title_len = len(self.title or self.semantic_identifier)
# Use processed_sections if available, otherwise fall back to original sections
if self.processed_sections:
section_len = sum(
len(section.text) if section.text is not None else 0
for section in self.processed_sections
)
else:
section_len = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in self.sections
)
return title_len + section_len
class SlimDocument(BaseModel):
id: str
perm_sync_data: Any | None = None
@@ -239,15 +204,6 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,3 +1,4 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -25,13 +26,12 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -475,7 +475,7 @@ class NotionConnector(LoadConnector, PollConnector):
Document(
id=page.id,
sections=[
TextSection(
Section(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -23,8 +23,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
@@ -145,7 +145,7 @@ def fetch_jira_issues_batch(
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -110,7 +110,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=feature["id"],
sections=[
TextSection(
Section(
link=feature["links"]["html"],
text=self._parse_description_html(feature["description"]),
)
@@ -133,7 +133,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=component["id"],
sections=[
TextSection(
Section(
link=component["links"]["html"],
text=self._parse_description_html(component["description"]),
)
@@ -159,7 +159,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=product["id"],
sections=[
TextSection(
Section(
link=product["links"]["html"],
text=self._parse_description_html(product["description"]),
)
@@ -189,7 +189,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=objective["id"],
sections=[
TextSection(
Section(
link=objective["links"]["html"],
text=self._parse_description_html(objective["description"]),
)

View File

@@ -13,7 +13,6 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
@@ -49,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
domain = "test" if credentials.get("is_sandbox") else None
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
domain=domain,
)
return None
@@ -219,8 +216,7 @@ if __name__ == "__main__":
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")

View File

@@ -1,12 +1,10 @@
import re
from typing import cast
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
@@ -116,8 +114,8 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> TextSection:
return TextSection(
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
@@ -177,7 +175,7 @@ def convert_sf_object_to_doc(
doc = Document(
id=onyx_salesforce_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -55,7 +55,7 @@ def _convert_driveitem_to_document(
doc = Document(
id=driveitem.id,
sections=[TextSection(link=driveitem.web_url, text=file_text)],
sections=[Section(link=driveitem.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem.name,
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),

View File

@@ -19,8 +19,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -212,7 +212,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
doc_batch.append(
Document(
id=post_id, # can't be url as this changes with the post title
sections=[TextSection(link=page_url, text=content_text)],
sections=[Section(link=page_url, text=content_text)],
source=DocumentSource.SLAB,
semantic_identifier=post["title"],
metadata={},

View File

@@ -34,8 +34,8 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import EntityFailure
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.slack.utils import expert_info_from_slack_id
from onyx.connectors.slack.utils import get_message_link
from onyx.connectors.slack.utils import make_paginated_slack_api_call_w_retries
@@ -211,7 +211,7 @@ def thread_to_doc(
return Document(
id=_build_doc_id(channel_id=channel_id, thread_ts=thread[0]["ts"]),
sections=[
TextSection(
Section(
link=get_message_link(event=m, client=client, channel_id=channel_id),
text=slack_cleaner.index_clean(cast(str, m["text"])),
)
@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -165,7 +165,7 @@ def _convert_thread_to_document(
doc = Document(
id=post_id,
sections=[TextSection(link=web_url, text=thread_text)],
sections=[Section(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
title="", # teams threads don't really have a "title"

View File

@@ -0,0 +1,45 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
class VisionEnabledConnector:
"""
Mixin for connectors that need vision capabilities.
This mixin provides a standard way to initialize a vision-capable LLM
for image analysis during indexing.
Usage:
class MyConnector(LoadConnector, VisionEnabledConnector):
def __init__(self, ...):
super().__init__(...)
self.initialize_vision_llm()
"""
def initialize_vision_llm(self) -> None:
"""
Initialize a vision-capable LLM if enabled by configuration.
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:
logger.warning(
"No LLM with vision found; image summarization will be disabled"
)
except Exception as e:
logger.warning(
f"Failed to initialize vision LLM due to an error: {str(e)}. "
"Image summarization will be disabled."
)
self.image_analysis_llm = None

View File

@@ -32,7 +32,7 @@ from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.html_utils import web_html_cleanup
from onyx.utils.logger import setup_logger
@@ -341,7 +341,7 @@ class WebConnector(LoadConnector):
doc_batch.append(
Document(
id=initial_url,
sections=[TextSection(link=initial_url, text=page_text)],
sections=[Section(link=initial_url, text=page_text)],
source=DocumentSource.WEB,
semantic_identifier=initial_url.split("/")[-1],
metadata=metadata,
@@ -443,7 +443,7 @@ class WebConnector(LoadConnector):
Document(
id=initial_url,
sections=[
TextSection(link=initial_url, text=parsed_html.cleaned_text)
Section(link=initial_url, text=parsed_html.cleaned_text)
],
source=DocumentSource.WEB,
semantic_identifier=parsed_html.title or initial_url,

View File

@@ -28,7 +28,7 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -104,7 +104,7 @@ def scrape_page_posts(
# id. We may want to de-dupe this stuff inside the indexing service.
document = Document(
id=f"{DocumentSource.XENFORO.value}_{title}_{page_index}_{formatted_time}",
sections=[TextSection(link=url, text=post_text)],
sections=[Section(link=url, text=post_text)],
title=title,
source=DocumentSource.XENFORO,
semantic_identifier=title,

View File

@@ -1,6 +1,5 @@
from collections.abc import Iterator
from typing import Any
from typing import cast
import requests
@@ -18,8 +17,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.retry_wrapper import retry_builder
@@ -169,8 +168,8 @@ def _article_to_document(
return new_author_mapping, Document(
id=f"article:{article['id']}",
sections=[
TextSection(
link=cast(str, article.get("html_url")),
Section(
link=article.get("html_url"),
text=parse_html_page_basic(article["body"]),
)
],
@@ -269,7 +268,7 @@ def _ticket_to_document(
return new_author_mapping, Document(
id=f"zendesk_ticket_{ticket['id']}",
sections=[TextSection(link=ticket_display_url, text=full_text)],
sections=[Section(link=ticket_display_url, text=full_text)],
source=DocumentSource.ZENDESK,
semantic_identifier=f"Ticket #{ticket['id']}: {subject or 'No Subject'}",
doc_updated_at=update_time,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.zulip.schemas import GetMessagesResponse
from onyx.connectors.zulip.schemas import Message
from onyx.connectors.zulip.utils import build_search_narrow
@@ -161,7 +161,7 @@ class ZulipConnector(LoadConnector, PollConnector):
return Document(
id=f"{message.stream_id}__{message.id}",
sections=[
TextSection(
Section(
link=self._message_to_narrow_link(message),
text=text,
)

View File

@@ -10,7 +10,6 @@ from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
@@ -32,6 +31,7 @@ from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import RerankingModel
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall

View File

@@ -13,7 +13,6 @@ from onyx.db.models import SearchSettings
from onyx.db.models import Tool as ToolModel
from onyx.db.models import User
from onyx.db.models import User__UserGroup
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
from onyx.server.manage.llm.models import FullLLMProvider
@@ -188,17 +187,6 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
return FullLLMProvider.from_model(provider_model)
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if not provider_model:
return None
return FullLLMProvider.from_model(provider_model)
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
provider_model = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
@@ -258,39 +246,3 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
new_default.is_default_provider = True
db_session.commit()
def update_default_vision_provider(
provider_id: int, vision_model: str | None, db_session: Session
) -> None:
new_default = db_session.scalar(
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
)
if not new_default:
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
# Validate that the specified vision model supports image input
model_to_validate = vision_model or new_default.default_model_name
if model_to_validate:
if not model_supports_image_input(model_to_validate, new_default.provider):
raise ValueError(
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
)
else:
raise ValueError(
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
)
existing_default = db_session.scalar(
select(LLMProviderModel).where(
LLMProviderModel.is_default_vision_provider == True # noqa: E712
)
)
if existing_default:
existing_default.is_default_vision_provider = None
# required to ensure that the below does not cause a unique constraint violation
db_session.flush()
new_default.is_default_vision_provider = True
new_default.default_vision_model = vision_model
db_session.commit()

View File

@@ -1489,10 +1489,6 @@ class LLMProvider(Base):
# should only be set for a single provider
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
is_default_vision_provider: Mapped[bool | None] = mapped_column(
Boolean, unique=True
)
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
# EE only
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
groups: Mapped[list["UserGroup"]] = relationship(
@@ -2299,31 +2295,21 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = ({"schema": "public"},)
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:
return value.lower() if value else value
class AvailableTenant(Base):
__tablename__ = "available_tenant"
"""
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
"""
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
# This is a mapping from tenant IDs to anonymous user paths
class TenantAnonymousUserPath(Base):
__tablename__ = "tenant_anonymous_user_path"

View File

@@ -67,9 +67,6 @@ def read_lobj(
use_tempfile: bool = False,
) -> IO:
pg_conn = get_pg_conn_from_session(db_session)
# Ensure we're using binary mode by default for large objects
if mode is None:
mode = "rb"
large_object = (
pg_conn.lobject(lobj_oid, mode=mode) if mode else pg_conn.lobject(lobj_oid)
)
@@ -84,7 +81,6 @@ def read_lobj(
temp_file.seek(0)
return temp_file
else:
# Ensure we're getting raw bytes without text decoding
return BytesIO(large_object.read())

View File

@@ -1,5 +1,6 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -148,10 +149,11 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from PIL import Image
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
image_data: The raw image bytes
context_name: Name or title of the image for context
system_prompt: System prompt to use for the LLM
user_prompt_template: User prompt to use (without title)
user_prompt_template: Template for the user prompt, should contain {title} placeholder
Returns:
The image summary text, or None if summarization failed or is disabled
@@ -70,10 +70,7 @@ def summarize_image_with_error_handling(
if llm is None:
return None
# Prepend the image filename to the user prompt
user_prompt = (
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
)
user_prompt = user_prompt_template.format(title=context_name)
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)

View File

@@ -2,9 +2,12 @@ from typing import Tuple
from sqlalchemy.orm import Session
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import FileOrigin
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -15,12 +18,12 @@ def store_image_and_create_section(
image_data: bytes,
file_name: str,
display_name: str,
link: str | None = None,
media_type: str = "application/octet-stream",
media_type: str = "image/unknown",
llm: LLM | None = None,
file_origin: FileOrigin = FileOrigin.OTHER,
) -> Tuple[ImageSection, str | None]:
) -> Tuple[Section, str | None]:
"""
Stores an image in PGFileStore and creates an ImageSection object without summarization.
Stores an image in PGFileStore and creates a Section object with optional summarization.
Args:
db_session: Database session
@@ -28,11 +31,12 @@ def store_image_and_create_section(
file_name: Base identifier for the file
display_name: Human-readable name for the image
media_type: MIME type of the image
llm: Optional LLM with vision capabilities for summarization
file_origin: Origin of the file (e.g., CONFLUENCE, GOOGLE_DRIVE, etc.)
Returns:
Tuple containing:
- ImageSection object with image reference
- Section object with image reference and optional summary text
- The file_name in PGFileStore or None if storage failed
"""
# Storage logic
@@ -49,10 +53,18 @@ def store_image_and_create_section(
stored_file_name = pgfilestore.file_name
except Exception as e:
logger.error(f"Failed to store image: {e}")
raise e
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return Section(text=""), None
# Summarization logic
summary_text = ""
if llm:
summary_text = (
summarize_image_with_error_handling(llm, image_data, display_name) or ""
)
# Create an ImageSection with empty text (will be filled by LLM later in the pipeline)
return (
ImageSection(image_file_name=stored_file_name, link=link),
Section(text=summary_text, image_file_name=stored_file_name),
stored_file_name,
)

View File

@@ -9,8 +9,7 @@ from onyx.configs.model_configs import DOC_EMBEDDING_CONTEXT_SIZE
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_metadata_keys_to_ignore,
)
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import Document
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.natural_language_processing.utils import BaseTokenizer
@@ -65,7 +64,7 @@ def _get_metadata_suffix_for_document_index(
def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwareChunk:
"""
Combines multiple DocAwareChunks into one large chunk (for "multipass" mode),
Combines multiple DocAwareChunks into one large chunk (for multipass mode),
appending the content and adjusting source_links accordingly.
"""
merged_chunk = DocAwareChunk(
@@ -99,7 +98,7 @@ def _combine_chunks(chunks: list[DocAwareChunk], large_chunk_id: int) -> DocAwar
def generate_large_chunks(chunks: list[DocAwareChunk]) -> list[DocAwareChunk]:
"""
Generates larger "grouped" chunks by combining sets of smaller chunks.
Generates larger grouped chunks by combining sets of smaller chunks.
"""
large_chunks = []
for idx, i in enumerate(range(0, len(chunks), LARGE_CHUNK_RATIO)):
@@ -186,7 +185,7 @@ class Chunker:
def _get_mini_chunk_texts(self, chunk_text: str) -> list[str] | None:
"""
For "multipass" mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
For multipass mode: additional sub-chunks (mini-chunks) for use in certain embeddings.
"""
if self.mini_chunk_splitter and chunk_text.strip():
return self.mini_chunk_splitter.split_text(chunk_text)
@@ -195,7 +194,7 @@ class Chunker:
# ADDED: extra param image_url to store in the chunk
def _create_chunk(
self,
document: IndexingDocument,
document: Document,
chunks_list: list[DocAwareChunk],
text: str,
links: dict[int, str],
@@ -226,29 +225,7 @@ class Chunker:
def _chunk_document(
self,
document: IndexingDocument,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
content_token_limit: int,
) -> list[DocAwareChunk]:
"""
Legacy method for backward compatibility.
Calls _chunk_document_with_sections with document.sections.
"""
return self._chunk_document_with_sections(
document,
document.processed_sections,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
def _chunk_document_with_sections(
self,
document: IndexingDocument,
sections: list[Section],
document: Document,
title_prefix: str,
metadata_suffix_semantic: str,
metadata_suffix_keyword: str,
@@ -256,16 +233,17 @@ class Chunker:
) -> list[DocAwareChunk]:
"""
Loops through sections of the document, converting them into one or more chunks.
Works with processed sections that are base Section objects.
If a section has an image_link, we treat it as a dedicated chunk.
"""
chunks: list[DocAwareChunk] = []
link_offsets: dict[int, str] = {}
chunk_text = ""
for section_idx, section in enumerate(sections):
# Get section text and other attributes
section_text = clean_text(section.text or "")
for section_idx, section in enumerate(document.sections):
section_text = clean_text(section.text)
section_link_text = section.link or ""
# ADDED: if the Section has an image link
image_url = section.image_file_name
# If there is no useful content, skip
@@ -276,7 +254,7 @@ class Chunker:
)
continue
# CASE 1: If this section has an image, force a separate chunk
# CASE 1: If this is an image section, force a separate chunk
if image_url:
# First, if we have any partially built text chunk, finalize it
if chunk_text.strip():
@@ -293,13 +271,15 @@ class Chunker:
chunk_text = ""
link_offsets = {}
# Create a chunk specifically for this image section
# (Using the text summary that was generated during processing)
# Create a chunk specifically for this image
# (If the section has text describing the image, use that as content)
self._create_chunk(
document,
chunks,
section_text,
links={0: section_link_text} if section_link_text else {},
links={0: section_link_text}
if section_link_text
else {}, # No text offsets needed for images
image_file_name=image_url,
title_prefix=title_prefix,
metadata_suffix_semantic=metadata_suffix_semantic,
@@ -404,9 +384,7 @@ class Chunker:
)
return chunks
def _handle_single_document(
self, document: IndexingDocument
) -> list[DocAwareChunk]:
def _handle_single_document(self, document: Document) -> list[DocAwareChunk]:
# Specifically for reproducing an issue with gmail
if document.source == DocumentSource.GMAIL:
logger.debug(f"Chunking {document.semantic_identifier}")
@@ -442,31 +420,26 @@ class Chunker:
title_prefix = ""
metadata_suffix_semantic = ""
# Use processed_sections if available (IndexingDocument), otherwise use original sections
sections_to_chunk = document.processed_sections
normal_chunks = self._chunk_document_with_sections(
# Chunk the document
normal_chunks = self._chunk_document(
document,
sections_to_chunk,
title_prefix,
metadata_suffix_semantic,
metadata_suffix_keyword,
content_token_limit,
)
# Optional "multipass" large chunk creation
# Optional multipass large chunk creation
if self.enable_multipass and self.enable_large_chunks:
large_chunks = generate_large_chunks(normal_chunks)
normal_chunks.extend(large_chunks)
return normal_chunks
def chunk(self, documents: list[IndexingDocument]) -> list[DocAwareChunk]:
def chunk(self, documents: list[Document]) -> list[DocAwareChunk]:
"""
Takes in a list of documents and chunks them into smaller chunks for indexing
while persisting the document metadata.
Works with both standard Document objects and IndexingDocument objects with processed_sections.
"""
final_chunks: list[DocAwareChunk] = []
for document in documents:

View File

@@ -10,18 +10,13 @@ from onyx.access.access import get_access_for_documents
from onyx.access.models import DocumentAccess
from onyx.configs.app_configs import MAX_DOCUMENT_CHARS
from onyx.configs.constants import DEFAULT_BOOST
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.connectors.cross_connector_utils.miscellaneous_utils import (
get_experts_stores_representations,
)
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import DocumentFailure
from onyx.connectors.models import ImageSection
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import IndexingDocument
from onyx.connectors.models import Section
from onyx.connectors.models import TextSection
from onyx.db.document import fetch_chunk_counts_for_documents
from onyx.db.document import get_documents_by_ids
from onyx.db.document import mark_document_as_indexed_for_cc_pair__no_commit
@@ -32,10 +27,7 @@ from onyx.db.document import update_docs_updated_at__no_commit
from onyx.db.document import upsert_document_by_connector_credential_pair
from onyx.db.document import upsert_documents
from onyx.db.document_set import fetch_document_sets_for_documents
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import Document as DBDocument
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.db.pg_file_store import read_lobj
from onyx.db.search_settings import get_current_search_settings
from onyx.db.tag import create_or_add_document_tag
from onyx.db.tag import create_or_add_document_tag_list
@@ -45,7 +37,6 @@ from onyx.document_index.document_index_utils import (
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.interfaces import DocumentMetadata
from onyx.document_index.interfaces import IndexBatchParams
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.indexing.chunker import Chunker
from onyx.indexing.embedder import embed_chunks_with_failure_handling
from onyx.indexing.embedder import IndexingEmbedder
@@ -53,7 +44,6 @@ from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.indexing.models import DocAwareChunk
from onyx.indexing.models import DocMetadataAwareIndexChunk
from onyx.indexing.vector_db_insertion import write_chunks_to_vector_db_with_backoff
from onyx.llm.factory import get_default_llm_with_vision
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
@@ -63,7 +53,6 @@ logger = setup_logger()
class DocumentBatchPrepareContext(BaseModel):
updatable_docs: list[Document]
id_to_db_doc_map: dict[str, DBDocument]
indexable_docs: list[IndexingDocument] = []
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -276,12 +265,7 @@ def index_doc_batch_prepare(
def filter_documents(document_batch: list[Document]) -> list[Document]:
documents: list[Document] = []
for document in document_batch:
empty_contents = not any(
isinstance(section, TextSection)
and section.text is not None
and section.text.strip()
for section in document.sections
)
empty_contents = not any(section.text.strip() for section in document.sections)
if (
(not document.title or not document.title.strip())
and not document.semantic_identifier.strip()
@@ -304,12 +288,7 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
)
continue
section_chars = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in document.sections
)
section_chars = sum(len(section.text) for section in document.sections)
if (
MAX_DOCUMENT_CHARS
and len(document.title or document.semantic_identifier) + section_chars
@@ -329,121 +308,6 @@ def filter_documents(document_batch: list[Document]) -> list[Document]:
return documents
def process_image_sections(documents: list[Document]) -> list[IndexingDocument]:
"""
Process all sections in documents by:
1. Converting both TextSection and ImageSection objects to base Section objects
2. Processing ImageSections to generate text summaries using a vision-capable LLM
3. Returning IndexingDocument objects with both original and processed sections
Args:
documents: List of documents with TextSection | ImageSection objects
Returns:
List of IndexingDocument objects with processed_sections as list[Section]
"""
# Check if image extraction and analysis is enabled before trying to get a vision LLM
if not get_image_extraction_and_analysis_enabled():
llm = None
else:
# Only get the vision LLM if image processing is enabled
llm = get_default_llm_with_vision()
if not llm:
logger.warning(
"No vision-capable LLM available. Image sections will not be processed."
)
# Even without LLM, we still convert to IndexingDocument with base Sections
return [
IndexingDocument(
**document.dict(),
processed_sections=[
Section(
text=section.text if isinstance(section, TextSection) else None,
link=section.link,
image_file_name=section.image_file_name
if isinstance(section, ImageSection)
else None,
)
for section in document.sections
],
)
for document in documents
]
indexed_documents: list[IndexingDocument] = []
for document in documents:
processed_sections: list[Section] = []
for section in document.sections:
# For ImageSection, process and create base Section with both text and image_file_name
if isinstance(section, ImageSection):
# Default section with image path preserved
processed_section = Section(
link=section.link,
image_file_name=section.image_file_name,
text=None, # Will be populated if summarization succeeds
)
# Try to get image summary
try:
with get_session_with_current_tenant() as db_session:
pgfilestore = get_pgfilestore_by_file_name(
file_name=section.image_file_name, db_session=db_session
)
if not pgfilestore:
logger.warning(
f"Image file {section.image_file_name} not found in PGFileStore"
)
processed_section.text = "[Image could not be processed]"
else:
# Get the image data
image_data_io = read_lobj(
pgfilestore.lobj_oid, db_session, mode="rb"
)
pgfilestore_data = image_data_io.read()
summary = summarize_image_with_error_handling(
llm=llm,
image_data=pgfilestore_data,
context_name=pgfilestore.display_name or "Image",
)
if summary:
processed_section.text = summary
else:
processed_section.text = (
"[Image could not be summarized]"
)
except Exception as e:
logger.error(f"Error processing image section: {e}")
processed_section.text = "[Error processing image]"
processed_sections.append(processed_section)
# For TextSection, create a base Section with text and link
elif isinstance(section, TextSection):
processed_section = Section(
text=section.text, link=section.link, image_file_name=None
)
processed_sections.append(processed_section)
# If it's already a base Section (unlikely), just append it
else:
processed_sections.append(section)
# Create IndexingDocument with original sections and processed_sections
indexed_document = IndexingDocument(
**document.dict(), processed_sections=processed_sections
)
indexed_documents.append(indexed_document)
return indexed_documents
@log_function_time(debug_only=True)
def index_doc_batch(
*,
@@ -498,23 +362,19 @@ def index_doc_batch(
failures=[],
)
# Convert documents to IndexingDocument objects with processed section
# logger.debug("Processing image sections")
ctx.indexable_docs = process_image_sections(ctx.updatable_docs)
doc_descriptors = [
{
"doc_id": doc.id,
"doc_length": doc.get_total_char_length(),
}
for doc in ctx.indexable_docs
for doc in ctx.updatable_docs
]
logger.debug(f"Starting indexing process for documents: {doc_descriptors}")
logger.debug("Starting chunking")
# NOTE: no special handling for failures here, since the chunker is not
# a common source of failure for the indexing pipeline
chunks: list[DocAwareChunk] = chunker.chunk(ctx.indexable_docs)
chunks: list[DocAwareChunk] = chunker.chunk(ctx.updatable_docs)
logger.debug("Starting embedding")
chunks_with_embeddings, embedding_failures = (

View File

@@ -5,9 +5,7 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.llm import fetch_default_provider
from onyx.db.llm import fetch_default_vision_provider
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_provider
from onyx.db.models import Persona
@@ -16,7 +14,6 @@ from onyx.llm.exceptions import GenAIDisabledException
from onyx.llm.interfaces import LLM
from onyx.llm.override_models import LLMOverride
from onyx.llm.utils import model_supports_image_input
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.utils.headers import build_llm_extra_headers
from onyx.utils.logger import setup_logger
from onyx.utils.long_term_log import LongTermLogger
@@ -97,61 +94,40 @@ def get_default_llm_with_vision(
additional_headers: dict[str, str] | None = None,
long_term_logger: LongTermLogger | None = None,
) -> LLM | None:
"""Get an LLM that supports image input, with the following priority:
1. Use the designated default vision provider if it exists and supports image input
2. Fall back to the first LLM provider that supports image input
Returns None if no providers exist or if no provider supports images.
"""
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
"""Helper to create an LLM if the provider supports image input."""
return get_llm(
provider=provider.provider,
model=model,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
with get_session_context_manager() as db_session:
llm_providers = fetch_existing_llm_providers(db_session)
with get_session_with_current_tenant() as db_session:
# Try the default vision provider first
default_provider = fetch_default_vision_provider(db_session)
if (
default_provider
and default_provider.default_vision_model
and model_supports_image_input(
default_provider.default_vision_model, default_provider.provider
)
):
return create_vision_llm(
default_provider, default_provider.default_vision_model
)
# Fall back to searching all providers
providers = fetch_existing_llm_providers(db_session)
if not providers:
if not llm_providers:
return None
# Find the first provider that supports image input
for provider in providers:
if provider.default_vision_model and model_supports_image_input(
provider.default_vision_model, provider.provider
):
return create_vision_llm(
FullLLMProvider.from_model(provider), provider.default_vision_model
for provider in llm_providers:
model_name = provider.default_model_name
fast_model_name = (
provider.fast_default_model_name or provider.default_model_name
)
if not model_name or not fast_model_name:
continue
if model_supports_image_input(model_name, provider.provider):
return get_llm(
provider=provider.provider,
model=model_name,
deployment_name=provider.deployment_name,
api_key=provider.api_key,
api_base=provider.api_base,
api_version=provider.api_version,
custom_config=provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
long_term_logger=long_term_logger,
)
return None
raise ValueError("No LLM provider found that supports image input")
def get_default_llms(

View File

@@ -234,8 +234,6 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()

View File

@@ -1,5 +1,5 @@
# Used for creating embeddings of images for vector search
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
You are an assistant for summarizing images for retrieval.
Summarize the content of the following image and be as precise as possible.
The summary will be embedded and used to retrieve the original image.
@@ -7,13 +7,14 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
"""
# Prompt for generating image descriptions with filename context
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
IMAGE_SUMMARIZATION_USER_PROMPT = """
The image has the file name '{title}'.
Describe precisely and concisely what the image shows.
"""
# Used for analyzing images in response to user queries at search time
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
"You are an AI assistant specialized in describing images.\n"
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
"Focus on aspects of the image that are relevant to the user's question.\n"

View File

@@ -160,20 +160,6 @@ class RedisPool:
def get_replica_client(self, tenant_id: str) -> Redis:
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
def get_raw_client(self) -> Redis:
"""
Returns a Redis client with direct access to the primary connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._pool)
def get_raw_replica_client(self) -> Redis:
"""
Returns a Redis client with direct access to the replica connection pool,
without tenant prefixing.
"""
return redis.Redis(connection_pool=self._replica_pool)
@staticmethod
def create_pool(
host: str = REDIS_HOST,
@@ -238,15 +224,6 @@ def get_redis_client(
# This argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis client with tenant-specific key prefixing.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this when working with tenant-specific data that should be
isolated from other tenants.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@@ -258,15 +235,6 @@ def get_redis_replica_client(
# this argument will be deprecated in the future
tenant_id: str | None = None,
) -> Redis:
"""
Returns a Redis replica client with tenant-specific key prefixing.
Similar to get_redis_client(), but connects to a read replica when available.
This ensures proper data isolation between tenants by automatically
prefixing all Redis keys with the tenant ID.
Use this for read-heavy operations on tenant-specific data.
"""
if tenant_id is None:
tenant_id = get_current_tenant_id()
@@ -274,57 +242,13 @@ def get_redis_replica_client(
def get_shared_redis_client() -> Redis:
"""
Returns a Redis client with a shared namespace prefix.
Unlike tenant-specific clients, this uses a common prefix for all keys,
creating a shared namespace accessible across all tenants.
Use this for data that should be shared across the application and
isn't specific to any individual tenant.
"""
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
def get_shared_redis_replica_client() -> Redis:
"""
Returns a Redis replica client with a shared namespace prefix.
Similar to get_shared_redis_client(), but connects to a read replica when available.
Uses a common prefix for all keys, creating a shared namespace.
Use this for read-heavy operations on data that should be shared
across the application.
"""
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
def get_raw_redis_client() -> Redis:
"""
Returns a Redis client that doesn't apply tenant prefixing to keys.
Use this only when you need to access Redis directly without tenant isolation
or any key prefixing. Typically needed for integrating with external systems
or libraries that have inflexible key requirements.
Warning: Be careful with this client as it bypasses tenant isolation.
"""
return redis_pool.get_raw_client()
def get_raw_redis_replica_client() -> Redis:
"""
Returns a Redis replica client that doesn't apply tenant prefixing to keys.
Similar to get_raw_redis_client(), but connects to a read replica when available.
Use this for read-heavy operations that need direct Redis access without
tenant isolation or key prefixing.
Warning: Be careful with this client as it bypasses tenant isolation.
"""
return redis_pool.get_raw_replica_client()
SSL_CERT_REQS_MAP = {
"none": ssl.CERT_NONE,
"optional": ssl.CERT_OPTIONAL,

View File

@@ -15,7 +15,7 @@ from onyx.configs.model_configs import DEFAULT_DOCUMENT_ENCODER_MODEL
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import InputType
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.db.connector import check_connectors_exist
from onyx.db.connector import create_connector
from onyx.db.connector_credential_pair import add_credential_to_connector
@@ -55,7 +55,7 @@ def _create_indexable_chunks(
# The section is not really used past this point since we have already done the other processing
# for the chunking and embedding.
sections=[
TextSection(
Section(
text=preprocessed_doc["content"],
link=preprocessed_doc["url"],
image_file_name=None,

View File

@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -112,7 +112,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accessible_user
or depends_fn == current_chat_accesssible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
user: User = Depends(current_chat_accessible_user),
user: User = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs_for_user(

View File

@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -390,7 +390,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@@ -7,14 +7,13 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
from onyx.db.llm import fetch_provider
from onyx.db.llm import remove_llm_provider
from onyx.db.llm import update_default_provider
from onyx.db.llm import update_default_vision_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import User
from onyx.llm.factory import get_default_llms
@@ -22,13 +21,11 @@ from onyx.llm.factory import get_llm
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
from onyx.llm.utils import litellm_exception_to_error_msg
from onyx.llm.utils import model_supports_image_input
from onyx.llm.utils import test_llm
from onyx.server.manage.llm.models import FullLLMProvider
from onyx.server.manage.llm.models import LLMProviderDescriptor
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
from onyx.server.manage.llm.models import TestLLMRequest
from onyx.server.manage.llm.models import VisionProviderResponse
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
@@ -189,68 +186,12 @@ def set_provider_as_default(
update_default_provider(provider_id=provider_id, db_session=db_session)
@admin_router.post("/provider/{provider_id}/default-vision")
def set_provider_as_default_vision(
provider_id: int,
vision_model: str
| None = Query(None, description="The default vision model to use"),
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_default_vision_provider(
provider_id=provider_id, vision_model=vision_model, db_session=db_session
)
@admin_router.get("/vision-providers")
def get_vision_capable_providers(
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[VisionProviderResponse]:
"""Return a list of LLM providers and their models that support image input"""
providers = fetch_existing_llm_providers(db_session)
vision_providers = []
logger.info("Fetching vision-capable providers")
for provider in providers:
vision_models = []
# Check model names in priority order
model_names_to_check = []
if provider.model_names:
model_names_to_check = provider.model_names
elif provider.display_model_names:
model_names_to_check = provider.display_model_names
elif provider.default_model_name:
model_names_to_check = [provider.default_model_name]
# Check each model for vision capability
for model_name in model_names_to_check:
if model_supports_image_input(model_name, provider.provider):
vision_models.append(model_name)
logger.debug(f"Vision model found: {provider.provider}/{model_name}")
# Only include providers with at least one vision-capable model
if vision_models:
provider_dict = FullLLMProvider.from_model(provider).model_dump()
provider_dict["vision_models"] = vision_models
logger.info(
f"Vision provider: {provider.provider} with models: {vision_models}"
)
vision_providers.append(VisionProviderResponse(**provider_dict))
logger.info(f"Found {len(vision_providers)} vision-capable providers")
return vision_providers
"""Endpoints for all"""
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accessible_user),
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

View File

@@ -34,8 +34,6 @@ class LLMProviderDescriptor(BaseModel):
default_model_name: str
fast_default_model_name: str | None
is_default_provider: bool | None
is_default_vision_provider: bool | None
default_vision_model: str | None
display_model_names: list[str] | None
@classmethod
@@ -48,10 +46,11 @@ class LLMProviderDescriptor(BaseModel):
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
model_names=llm_provider_model.model_names
or fetch_models_for_provider(llm_provider_model.provider),
model_names=(
llm_provider_model.model_names
or fetch_models_for_provider(llm_provider_model.provider)
or [llm_provider_model.default_model_name]
),
display_model_names=llm_provider_model.display_model_names,
)
@@ -69,7 +68,6 @@ class LLMProvider(BaseModel):
groups: list[int] = Field(default_factory=list)
display_model_names: list[str] | None = None
deployment_name: str | None = None
default_vision_model: str | None = None
class LLMProviderUpsertRequest(LLMProvider):
@@ -81,7 +79,6 @@ class LLMProviderUpsertRequest(LLMProvider):
class FullLLMProvider(LLMProvider):
id: int
is_default_provider: bool | None = None
is_default_vision_provider: bool | None = None
model_names: list[str]
@classmethod
@@ -97,8 +94,6 @@ class FullLLMProvider(LLMProvider):
default_model_name=llm_provider_model.default_model_name,
fast_default_model_name=llm_provider_model.fast_default_model_name,
is_default_provider=llm_provider_model.is_default_provider,
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
default_vision_model=llm_provider_model.default_vision_model,
display_model_names=llm_provider_model.display_model_names,
model_names=(
llm_provider_model.model_names
@@ -109,9 +104,3 @@ class FullLLMProvider(LLMProvider):
groups=[group.id for group in llm_provider_model.groups],
deployment_name=llm_provider_model.deployment_name,
)
class VisionProviderResponse(FullLLMProvider):
"""Response model for vision providers endpoint, including vision-specific fields."""
vision_models: list[str]

View File

@@ -53,16 +53,6 @@ class UserPreferences(BaseModel):
temperature_override_enabled: bool | None = None
class TenantSnapshot(BaseModel):
tenant_id: str
number_of_users: int
class TenantInfo(BaseModel):
invitation: TenantSnapshot | None = None
new_tenant: TenantSnapshot | None = None
class UserInfo(BaseModel):
id: str
email: str
@@ -75,10 +65,9 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
team_name: str | None = None
organization_name: str | None = None
is_anonymous_user: bool | None = None
password_configured: bool | None = None
tenant_info: TenantInfo | None = None
@classmethod
def from_model(
@@ -87,9 +76,8 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
team_name: str | None = None,
organization_name: str | None = None,
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -111,7 +99,7 @@ class UserInfo(BaseModel):
temperature_override_enabled=user.temperature_override_enabled,
)
),
team_name=team_name,
organization_name=organization_name,
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
# where they previously had this set + used OIDC, and now they switched to
# basic auth are now constantly getting redirected back to the login page
@@ -121,7 +109,6 @@ class UserInfo(BaseModel):
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
is_anonymous_user=is_anonymous_user,
tenant_info=tenant_info,
)

Some files were not shown because too many files have changed in this diff Show More