Compare commits

..

5 Commits

Author SHA1 Message Date
pablonyx
9087320a06 fix 2025-03-06 14:46:20 -08:00
pablonyx
b0af1458c0 ensure checks pass 2025-03-06 14:46:20 -08:00
pablonyx
bb67a7a122 remove unnecessary logs 2025-03-06 14:46:20 -08:00
pablonyx
e239dc31c1 rename 2025-03-06 14:46:19 -08:00
pablonyx
027128502c add csl 2025-03-06 14:46:19 -08:00
245 changed files with 2019 additions and 5737 deletions

View File

@@ -12,40 +12,29 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: "true"
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -1,7 +1,6 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -48,13 +47,11 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -79,7 +76,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -5,10 +5,7 @@ Revises: f1ca58b2f2ec
Create Date: 2025-01-29 07:48:46.784041
"""
import logging
from typing import cast
from alembic import op
from sqlalchemy.exc import IntegrityError
from sqlalchemy.sql import text
@@ -18,45 +15,21 @@ down_revision = "f1ca58b2f2ec"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
"""Conflicts on lowercasing will result in the uppercased email getting a
unique integer suffix when converted to lowercase."""
# Get database connection
connection = op.get_bind()
# Fetch all user emails that are not already lowercase
user_emails = connection.execute(
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
).fetchall()
for user_id, email in user_emails:
email = cast(str, email)
username, domain = email.rsplit("@", 1)
new_email = f"{username.lower()}@{domain.lower()}"
attempt = 1
while True:
try:
# Try updating the email
connection.execute(
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
{"new_email": new_email, "user_id": user_id},
)
break # Success, exit loop
except IntegrityError:
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
# Email conflict occurred, append `_1`, `_2`, etc., to the username
logger.warning(
f"Conflict while lowercasing email: "
f"old_email={email} "
f"conflicting_email={new_email} "
f"next_email={next_email}"
)
new_email = next_email
attempt += 1
# Update all user emails to lowercase
connection.execute(
text(
"""
UPDATE "user"
SET email = LOWER(email)
WHERE email != LOWER(email)
"""
)
)
def downgrade() -> None:

View File

@@ -1,45 +0,0 @@
"""add_default_vision_provider_to_llm_provider
Revision ID: df46c75b714e
Revises: 3934b1bc7b62
Create Date: 2025-03-11 16:20:19.038945
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "df46c75b714e"
down_revision = "3934b1bc7b62"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.add_column(
"llm_provider",
sa.Column(
"is_default_vision_provider",
sa.Boolean(),
nullable=True,
server_default=sa.false(),
),
)
op.add_column(
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
)
# Add unique constraint for is_default_vision_provider
op.create_unique_constraint(
"uq_llm_provider_is_default_vision_provider",
"llm_provider",
["is_default_vision_provider"],
)
def downgrade() -> None:
op.drop_constraint(
"uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique"
)
op.drop_column("llm_provider", "default_vision_model")
op.drop_column("llm_provider", "is_default_vision_provider")

View File

@@ -1,33 +0,0 @@
"""add new available tenant table
Revision ID: 3b45e0018bf1
Revises: ac842f85f932
Create Date: 2025-03-06 09:55:18.229910
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "3b45e0018bf1"
down_revision = "ac842f85f932"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create new_available_tenant table
op.create_table(
"available_tenant",
sa.Column("tenant_id", sa.String(), nullable=False),
sa.Column("alembic_version", sa.String(), nullable=False),
sa.Column("date_created", sa.DateTime(), nullable=False),
sa.PrimaryKeyConstraint("tenant_id"),
)
def downgrade() -> None:
# Drop new_available_tenant table
op.drop_table("available_tenant")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -27,8 +27,6 @@ def get_empty_chat_messages_entries__paginated(
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],

View File

@@ -1,14 +1,10 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -48,15 +48,10 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -251,8 +246,6 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -1,98 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,24 +1,269 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,96 +0,0 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -28,12 +26,11 @@ from onyx.auth.users import exceptions
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
from onyx.configs.app_configs import DEV_MODE
from onyx.configs.constants import MilestoneRecordType
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.llm import update_default_provider
from onyx.db.llm import upsert_cloud_embedding_provider
from onyx.db.llm import upsert_llm_provider
from onyx.db.models import AvailableTenant
from onyx.db.models import IndexModelStatus
from onyx.db.models import SearchSettings
from onyx.db.models import UserTenantMapping
@@ -63,72 +60,42 @@ async def get_or_provision_tenant(
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
# Early return for non-multi-tenant mode
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
if referral_source and request:
await submit_to_hubspot(email, referral_source, request)
# First, check if the user already has a tenant
tenant_id: str | None = None
try:
tenant_id = get_tenant_id_for_email(email)
return tenant_id
except exceptions.UserNotExists:
# User doesn't exist, so we need to create a new tenant or assign an existing one
pass
try:
# Try to get a pre-provisioned tenant
tenant_id = await get_available_tenant()
if tenant_id:
# If we have a pre-provisioned tenant, assign it to the user
await assign_tenant_to_user(tenant_id, email, referral_source)
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
return tenant_id
else:
# If no pre-provisioned tenant is available, create a new one on-demand
# If tenant does not exist and in Multi tenant mode, provision a new tenant
try:
tenant_id = await create_tenant(email, referral_source)
return tenant_id
except Exception as e:
logger.error(f"Tenant provisioning failed: {e}")
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
except Exception as e:
# If we've encountered an error, log and raise an exception
error_msg = "Failed to provision tenant"
logger.error(error_msg, exc_info=e)
if not tenant_id:
raise HTTPException(
status_code=500,
detail="Failed to provision tenant. Please try again later.",
status_code=401, detail="User does not belong to an organization"
)
return tenant_id
async def create_tenant(email: str, referral_source: str | None = None) -> str:
"""
Create a new tenant on-demand when no pre-provisioned tenants are available.
This is the fallback method when we can't use a pre-provisioned tenant.
"""
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
logger.info(f"Creating new tenant {tenant_id} for user {email}")
try:
# Provision tenant on data plane
await provision_tenant(tenant_id, email)
# Notify control plane if not already done in provision_tenant
if not DEV_MODE and referral_source:
# Notify control plane
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)
except Exception as e:
logger.exception(f"Tenant provisioning failed: {str(e)}")
# Attempt to rollback the tenant provisioning
try:
await rollback_tenant_provisioning(tenant_id)
except Exception:
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
logger.error(f"Tenant provisioning failed: {e}")
await rollback_tenant_provisioning(tenant_id)
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
return tenant_id
@@ -142,25 +109,54 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
token = None
try:
# Create the schema for the tenant
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
# Set up the tenant with all necessary configurations
await setup_tenant(tenant_id)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Assign the tenant to the user
await assign_tenant_to_user(tenant_id, email)
# Await the Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
configure_default_api_keys(db_session)
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
add_users_to_tenant([email], tenant_id)
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception as e:
logger.exception(f"Failed to create tenant {tenant_id}")
raise HTTPException(
status_code=500, detail=f"Failed to create tenant: {str(e)}"
)
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def notify_control_plane(
@@ -191,74 +187,20 @@ async def notify_control_plane(
async def rollback_tenant_provisioning(tenant_id: str) -> None:
"""
Logic to rollback tenant provisioning on data plane.
Handles each step independently to ensure maximum cleanup even if some steps fail.
"""
# Logic to rollback tenant provisioning on data plane
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
# Track if any part of the rollback fails
rollback_errors = []
# 1. Try to drop the tenant's schema
try:
# Drop the tenant's schema to rollback provisioning
drop_schema(tenant_id)
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
# Remove tenant mapping
with Session(get_sqlalchemy_engine()) as db_session:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
except Exception as e:
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 2. Try to remove tenant mapping
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
db_session.query(UserTenantMapping).filter(
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
logger.info(
f"Successfully removed user mappings for tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# 3. If this tenant was in the available tenants table, remove it
try:
with get_session_with_shared_schema() as db_session:
db_session.begin()
try:
available_tenant = (
db_session.query(AvailableTenant)
.filter(AvailableTenant.tenant_id == tenant_id)
.first()
)
if available_tenant:
db_session.delete(available_tenant)
db_session.commit()
logger.info(
f"Removed tenant {tenant_id} from available tenants table"
)
except Exception as e:
db_session.rollback()
raise e
except Exception as e:
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
logger.error(error_msg)
rollback_errors.append(error_msg)
# Log summary of rollback operation
if rollback_errors:
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
else:
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
logger.error(f"Failed to rollback tenant provisioning: {e}")
def configure_default_api_keys(db_session: Session) -> None:
@@ -411,155 +353,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None
async def get_available_tenant() -> str | None:
"""
Get an available pre-provisioned tenant from the NewAvailableTenant table.
Returns the tenant_id if one is available, None otherwise.
Uses row-level locking to prevent race conditions when multiple processes
try to get an available tenant simultaneously.
"""
if not MULTI_TENANT:
return None
with get_session_with_shared_schema() as db_session:
try:
db_session.begin()
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
available_tenant = (
db_session.query(AvailableTenant)
.order_by(AvailableTenant.date_created)
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
.first()
)
if available_tenant:
tenant_id = available_tenant.tenant_id
# Remove the tenant from the available tenants table
db_session.delete(available_tenant)
db_session.commit()
logger.info(f"Using pre-provisioned tenant {tenant_id}")
return tenant_id
else:
db_session.rollback()
return None
except Exception:
logger.exception("Error getting available tenant")
db_session.rollback()
return None
async def setup_tenant(tenant_id: str) -> None:
"""
Set up a tenant with all necessary configurations.
This is a centralized function that handles all tenant setup logic.
"""
token = None
try:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
# Run Alembic migrations
await asyncio.to_thread(run_alembic_migrations, tenant_id)
# Configure the tenant with default settings
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
# Configure default API keys
configure_default_api_keys(db_session)
# Set up Onyx with appropriate settings
current_search_settings = (
db_session.query(SearchSettings)
.filter_by(status=IndexModelStatus.FUTURE)
.first()
)
cohere_enabled = (
current_search_settings is not None
and current_search_settings.provider_type == EmbeddingProvider.COHERE
)
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
except Exception as e:
logger.exception(f"Failed to set up tenant {tenant_id}")
raise e
finally:
if token is not None:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
async def assign_tenant_to_user(
tenant_id: str, email: str, referral_source: str | None = None
) -> None:
"""
Assign a tenant to a user and perform necessary operations.
Uses transaction handling to ensure atomicity and includes retry logic
for control plane notifications.
"""
# First, add the user to the tenant in a transaction
try:
add_users_to_tenant([email], tenant_id)
# Create milestone record in the same transaction context as the tenant assignment
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
create_milestone_and_report(
user=None,
distinct_id=tenant_id,
event_type=MilestoneRecordType.TENANT_CREATED,
properties={
"email": email,
},
db_session=db_session,
)
except Exception:
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
raise Exception("Failed to assign tenant to user")
# Notify control plane with retry logic
if not DEV_MODE:
await notify_control_plane(tenant_id, email, referral_source)

View File

@@ -74,21 +74,3 @@ def drop_schema(tenant_id: str) -> None:
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
{"schema_name": tenant_id},
)
def get_current_alembic_version(tenant_id: str) -> str:
"""Get the current Alembic version for a tenant."""
from alembic.runtime.migration import MigrationContext
from sqlalchemy import text
engine = get_sqlalchemy_engine()
# Set the search path to the tenant's schema
with engine.connect() as connection:
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
# Get the current version from the alembic_version table
context = MigrationContext.configure(connection)
current_rev = context.get_current_revision()
return current_rev or "head"

View File

@@ -1,67 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,39 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -1,90 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,56 +1,27 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -67,39 +38,13 @@ def user_owns_a_tenant(email: str) -> bool:
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
"""
Add users to a tenant with proper transaction handling.
Checks if users already have a tenant mapping to avoid duplicates.
"""
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
# Start a transaction
db_session.begin()
for email in emails:
# Check if the user already has a mapping to this tenant
existing_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
)
.with_for_update()
.first()
)
if not existing_mapping:
# Only add if mapping doesn't exist
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
# Commit the transaction
db_session.commit()
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.rollback()
raise
db_session.commit()
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
@@ -131,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -62,60 +62,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -146,17 +92,31 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -195,6 +155,7 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -278,51 +239,22 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -637,13 +569,6 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -31,7 +31,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,7 +92,6 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,7 +46,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -120,7 +119,6 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -63,7 +62,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -155,9 +153,8 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -281,9 +278,6 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,7 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -142,7 +141,6 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -113,7 +112,6 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -145,7 +144,6 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,7 +50,13 @@ def decide_refinement_need(
)
]
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -21,7 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -97,7 +96,6 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,7 +46,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -69,8 +68,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -182,9 +179,8 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -306,11 +302,7 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -417,7 +409,6 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,6 +13,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -143,6 +144,8 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
assert query_info is not None, "must have query info"
return query_info
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -97,7 +96,6 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,9 +56,8 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0] if callback_container else []
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,7 +25,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -94,7 +93,6 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,9 +44,7 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,17 +15,8 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -34,7 +25,6 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -47,31 +37,6 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -82,6 +47,7 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -105,22 +71,11 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -198,22 +153,10 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -9,23 +9,18 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -55,13 +50,11 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,7 +2,6 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -36,7 +35,6 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,11 +13,6 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,7 +42,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -62,7 +61,6 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -404,7 +402,6 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -508,9 +505,3 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -153,8 +153,7 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,6 +1,5 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -895,7 +894,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accessible_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1096,12 +1095,6 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1133,14 +1126,9 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1152,7 +1140,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -111,7 +111,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -92,6 +92,5 @@ def on_setup_logging(
celery_app.autodiscover_tasks(
[
"onyx.background.celery.tasks.monitoring",
"onyx.background.celery.tasks.tenant_provisioning",
]
)

View File

@@ -1,73 +0,0 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -167,16 +167,6 @@ beat_cloud_tasks: list[dict] = [
"expires": BEAT_EXPIRES_DEFAULT,
},
},
{
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
"schedule": timedelta(minutes=10),
"options": {
"queue": OnyxCeleryQueues.MONITORING,
"priority": OnyxCeleryPriority.HIGH,
"expires": BEAT_EXPIRES_DEFAULT,
},
},
]
# tasks that only run self hosted

View File

@@ -23,7 +23,6 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
@@ -985,9 +984,6 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1028,23 +1024,6 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1191,7 +1170,6 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1239,7 +1217,6 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -1,199 +0,0 @@
"""
Periodic tasks for tenant pre-provisioning.
"""
import asyncio
import datetime
import uuid
from celery import shared_task
from celery import Task
from redis.lock import Lock as RedisLock
from ee.onyx.server.tenants.provisioning import setup_tenant
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
from onyx.background.celery.apps.app_base import task_logger
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
from onyx.configs.constants import OnyxCeleryPriority
from onyx.configs.constants import OnyxCeleryQueues
from onyx.configs.constants import OnyxCeleryTask
from onyx.configs.constants import OnyxRedisLocks
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.models import AvailableTenant
from onyx.redis.redis_pool import get_redis_client
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import TENANT_ID_PREFIX
# Default number of pre-provisioned tenants to maintain
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
# Soft time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
# Hard time limit for tenant pre-provisioning tasks (in seconds)
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
@shared_task(
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
queue=OnyxCeleryQueues.MONITORING,
ignore_result=True,
soft_time_limit=JOB_TIMEOUT,
trail=False,
bind=True,
)
def check_available_tenants(self: Task) -> None:
"""
Check if we have enough pre-provisioned tenants available.
If not, trigger the pre-provisioning of new tenants.
"""
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
if not MULTI_TENANT:
task_logger.info(
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
)
return
r = get_redis_client()
lock_check: RedisLock = r.lock(
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# These tasks should never overlap
if not lock_check.acquire(blocking=False):
task_logger.info(
"Skipping check_available_tenants task because it is already running"
)
return
try:
# Get the current count of available tenants
with get_session_with_shared_schema() as db_session:
available_tenants_count = db_session.query(AvailableTenant).count()
# Get the target number of available tenants
target_available_tenants = getattr(
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
)
# Calculate how many new tenants we need to provision
tenants_to_provision = max(
0, target_available_tenants - available_tenants_count
)
task_logger.info(
f"Available tenants: {available_tenants_count}, "
f"Target: {target_available_tenants}, "
f"To provision: {tenants_to_provision}"
)
# Trigger pre-provisioning tasks for each tenant needed
for _ in range(tenants_to_provision):
from celery import current_app
current_app.send_task(
OnyxCeleryTask.PRE_PROVISION_TENANT,
priority=OnyxCeleryPriority.LOW,
)
except Exception:
task_logger.exception("Error in check_available_tenants task")
finally:
lock_check.release()
@shared_task(
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
ignore_result=True,
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
queue=OnyxCeleryQueues.MONITORING,
bind=True,
)
def pre_provision_tenant(self: Task) -> None:
"""
Pre-provision a new tenant and store it in the NewAvailableTenant table.
This function fully sets up the tenant with all necessary configurations,
so it's ready to be assigned to a user immediately.
"""
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
# rather than inside this function
r = get_redis_client()
lock_provision: RedisLock = r.lock(
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
)
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
if not lock_provision.acquire(blocking=False):
task_logger.debug(
"Skipping pre_provision_tenant task because it is already running"
)
return
tenant_id: str | None = None
try:
# Generate a new tenant ID
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
# Create the schema for the new tenant
schema_created = create_schema_if_not_exists(tenant_id)
if schema_created:
task_logger.debug(f"Created schema for tenant: {tenant_id}")
else:
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
# Set up the tenant with all necessary configurations
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
asyncio.run(setup_tenant(tenant_id))
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
# Get the current Alembic version
alembic_version = get_current_alembic_version(tenant_id)
task_logger.debug(
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
)
# Store the pre-provisioned tenant in the database
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
with get_session_with_shared_schema() as db_session:
# Use a transaction to ensure atomicity
db_session.begin()
try:
new_tenant = AvailableTenant(
tenant_id=tenant_id,
alembic_version=alembic_version,
date_created=datetime.datetime.now(),
)
db_session.add(new_tenant)
db_session.commit()
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
except Exception:
db_session.rollback()
task_logger.error(
f"Failed to store pre-provisioned tenant: {tenant_id}",
exc_info=True,
)
raise
except Exception:
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
# If we have a tenant_id, attempt to rollback any partially completed provisioning
if tenant_id:
task_logger.info(
f"Rolling back failed tenant provisioning for: {tenant_id}"
)
try:
from ee.onyx.server.tenants.provisioning import (
rollback_tenant_provisioning,
)
asyncio.run(rollback_tenant_provisioning(tenant_id))
except Exception:
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
finally:
lock_provision.release()

View File

@@ -28,7 +28,6 @@ from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
from onyx.connectors.models import Document
from onyx.connectors.models import IndexAttemptMetadata
from onyx.connectors.models import TextSection
from onyx.db.connector_credential_pair import get_connector_credential_pair_from_id
from onyx.db.connector_credential_pair import get_last_successful_attempt_time
from onyx.db.connector_credential_pair import update_connector_credential_pair
@@ -155,12 +154,14 @@ def strip_null_characters(doc_batch: list[Document]) -> list[Document]:
)
for section in cleaned_doc.sections:
if section.link is not None:
if section.link and "\x00" in section.link:
logger.warning(
f"NUL characters found in document link for document: {cleaned_doc.id}"
)
section.link = section.link.replace("\x00", "")
# since text can be longer, just replace to avoid double scan
if isinstance(section, TextSection) and section.text is not None:
section.text = section.text.replace("\x00", "")
section.text = section.text.replace("\x00", "")
cleaned_batch.append(cleaned_doc)
@@ -478,11 +479,7 @@ def _run_indexing(
doc_size = 0
for section in doc.sections:
if (
isinstance(section, TextSection)
and section.text is not None
):
doc_size += len(section.text)
doc_size += len(section.text)
if doc_size > INDEXING_SIZE_WARNING_THRESHOLD:
logger.warning(

View File

@@ -15,8 +15,6 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@@ -1,13 +1,10 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = Union[
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
]
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
"Manual LLM citation wasn't able to close brackets"
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
continue
link = context_llm_doc.link
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
link = context_llm_doc.link
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
last_citation_end = end + length_to_add
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,45 +333,4 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
AGENT_MAX_TOKENS_VALIDATION = int(
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
)
GRAPH_VERSION_NAME: str = "a"

View File

@@ -8,9 +8,6 @@ from onyx.configs.constants import AuthType
from onyx.configs.constants import DocumentIndexType
from onyx.configs.constants import QueryHistoryType
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
#####
# App Configs
@@ -646,24 +643,3 @@ MOCK_LLM_RESPONSE = (
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
# Number of pre-provisioned tenants to maintain
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
# Image summarization configuration
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
)
# The user prompt for image summarization - the image filename will be automatically prepended
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
"IMAGE_SUMMARIZATION_USER_PROMPT",
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
)
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
)

View File

@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"
@@ -322,8 +321,6 @@ class OnyxRedisLocks:
"da_lock:check_connector_external_group_sync_beat"
)
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
"da_lock:connector_doc_permissions_sync"
@@ -386,7 +383,6 @@ class OnyxCeleryTask:
CLOUD_MONITOR_CELERY_QUEUES = (
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
)
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
@@ -403,9 +399,6 @@ class OnyxCeleryTask:
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
# Tenant pre-provisioning
PRE_PROVISION_TENANT = "pre_provision_tenant"
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
"connector_permission_sync_generator_task"

View File

@@ -4,7 +4,6 @@ from concurrent.futures import Future
from concurrent.futures import ThreadPoolExecutor
from io import BytesIO
from typing import Any
from typing import cast
import requests
from pyairtable import Api as AirtableApi
@@ -17,8 +16,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.utils.logger import setup_logger
@@ -269,7 +267,7 @@ class AirtableConnector(LoadConnector):
table_id: str,
view_id: str | None,
record_id: str,
) -> tuple[list[TextSection], dict[str, str | list[str]]]:
) -> tuple[list[Section], dict[str, str | list[str]]]:
"""
Process a single Airtable field and return sections or metadata.
@@ -307,7 +305,7 @@ class AirtableConnector(LoadConnector):
# Otherwise, create relevant sections
sections = [
TextSection(
Section(
link=link,
text=(
f"{field_name}:\n"
@@ -342,7 +340,7 @@ class AirtableConnector(LoadConnector):
table_name = table_schema.name
record_id = record["id"]
fields = record["fields"]
sections: list[TextSection] = []
sections: list[Section] = []
metadata: dict[str, str | list[str]] = {}
# Get primary field value if it exists
@@ -386,7 +384,7 @@ class AirtableConnector(LoadConnector):
return Document(
id=f"airtable__{record_id}",
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
source=DocumentSource.AIRTABLE,
semantic_identifier=semantic_id,
metadata=metadata,

View File

@@ -10,7 +10,7 @@ from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,7 +82,7 @@ class AsanaConnector(LoadConnector, PollConnector):
logger.debug(f"Converting Asana task {task.id} to Document")
return Document(
id=task.id,
sections=[TextSection(link=task.link, text=task.text)],
sections=[Section(link=task.link, text=task.text)],
doc_updated_at=task.last_modified,
source=DocumentSource.ASANA,
semantic_identifier=task.title,

View File

@@ -20,7 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -221,7 +221,7 @@ def _get_forums(
def _translate_forum_to_doc(af: AxeroForum) -> Document:
doc = Document(
id=af.doc_id,
sections=[TextSection(link=af.link, text=reply) for reply in af.responses],
sections=[Section(link=af.link, text=reply) for reply in af.responses],
source=DocumentSource.AXERO,
semantic_identifier=af.title,
doc_updated_at=af.last_update,
@@ -244,7 +244,7 @@ def _translate_content_to_doc(content: dict) -> Document:
doc = Document(
id="AXERO_" + str(content["ContentID"]),
sections=[TextSection(link=content["ContentURL"], text=page_text)],
sections=[Section(link=content["ContentURL"], text=page_text)],
source=DocumentSource.AXERO,
semantic_identifier=content["ContentTitle"],
doc_updated_at=time_str_to_utc(content["DateUpdated"]),

View File

@@ -25,7 +25,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -208,7 +208,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"{self.bucket_type}:{self.bucket_name}:{obj['Key']}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource(self.bucket_type.value),
semantic_identifier=name,
doc_updated_at=last_modified,
@@ -341,14 +341,7 @@ if __name__ == "__main__":
print("Sections:")
for section in doc.sections:
print(f" - Link: {section.link}")
if isinstance(section, TextSection) and section.text is not None:
print(f" - Text: {section.text[:100]}...")
elif (
hasattr(section, "image_file_name") and section.image_file_name
):
print(f" - Image: {section.image_file_name}")
else:
print("Error: Unknown section type")
print(f" - Text: {section.text[:100]}...")
print("---")
break

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
@@ -81,7 +81,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="book__" + str(book.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Book: " + title,
title=title,
@@ -110,7 +110,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="chapter__" + str(chapter.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Chapter: " + title,
title=title,
@@ -134,7 +134,7 @@ class BookstackConnector(LoadConnector, PollConnector):
)
return Document(
id="shelf:" + str(shelf.get("id")),
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Shelf: " + title,
title=title,
@@ -167,7 +167,7 @@ class BookstackConnector(LoadConnector, PollConnector):
time.sleep(0.1)
return Document(
id="page:" + page_id,
sections=[TextSection(link=url, text=text)],
sections=[Section(link=url, text=text)],
source=DocumentSource.BOOKSTACK,
semantic_identifier="Page: " + str(title),
title=str(title),

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.retry_wrapper import retry_builder
@@ -62,11 +62,11 @@ class ClickupConnector(LoadConnector, PollConnector):
return response.json()
def _get_task_comments(self, task_id: str) -> list[TextSection]:
def _get_task_comments(self, task_id: str) -> list[Section]:
url_endpoint = f"/task/{task_id}/comment"
response = self._make_request(url_endpoint)
comments = [
TextSection(
Section(
link=f'https://app.clickup.com/t/{task_id}?comment={comment_dict["id"]}',
text=comment_dict["comment_text"],
)
@@ -133,7 +133,7 @@ class ClickupConnector(LoadConnector, PollConnector):
],
title=task["name"],
sections=[
TextSection(
Section(
link=task["url"],
text=(
task["markdown_description"]

View File

@@ -33,9 +33,9 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -66,6 +66,9 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
@@ -85,6 +88,7 @@ class ConfluenceConnector(
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
):
def __init__(
self,
@@ -115,6 +119,9 @@ class ConfluenceConnector(
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
@@ -241,16 +248,12 @@ class ConfluenceConnector(
)
# Create the main section for the page content
sections: list[TextSection | ImageSection] = [
TextSection(text=page_content, link=page_url)
]
sections = [Section(text=page_content, link=page_url)]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(
TextSection(text=comment_text, link=f"{page_url}#comments")
)
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
# Process attachments
if "children" in page and "attachment" in page["children"]:
@@ -263,27 +266,21 @@ class ConfluenceConnector(
result = process_attachment(
self.confluence_client,
attachment,
page_id,
page_title,
self.image_analysis_llm,
)
if result and result.text:
if result.text:
# Create a section for the attachment text
attachment_section = TextSection(
attachment_section = Section(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
)
sections.append(attachment_section)
elif result and result.file_name:
# Create an ImageSection for image attachments
image_section = ImageSection(
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(image_section)
else:
sections.append(attachment_section)
elif result.error:
logger.warning(
f"Error processing attachment '{attachment.get('title')}':",
f"{result.error if result else 'Unknown error'}",
f"Error processing attachment '{attachment.get('title')}': {result.error}"
)
# Extract metadata
@@ -308,9 +305,7 @@ class ConfluenceConnector(
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -354,7 +349,7 @@ class ConfluenceConnector(
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
page.get("body", {}).get("storage", {}).get("value", "")
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
@@ -362,7 +357,7 @@ class ConfluenceConnector(
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue
@@ -372,26 +367,23 @@ class ConfluenceConnector(
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_id=page["id"],
page_context=confluence_xml,
llm=self.image_analysis_llm,
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
self.wiki_base, page["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
TextSection(
Section(
text=content_text,
link=object_url,
)
)
elif file_storage_name:
doc.sections.append(
ImageSection(
link=object_url,
image_file_name=file_storage_name,
)
)
@@ -471,7 +463,7 @@ class ConfluenceConnector(
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment,
attachment, self.image_analysis_llm
):
continue

View File

@@ -1,3 +1,4 @@
import io
import json
import time
from collections.abc import Callable
@@ -18,11 +19,17 @@ from requests import HTTPError
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.connectors.confluence.utils import _handle_http_error
from onyx.connectors.confluence.utils import confluence_refresh_tokens
from onyx.connectors.confluence.utils import get_start_param_from_url
from onyx.connectors.confluence.utils import update_param_in_path
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.html_utils import format_document_soup
from onyx.redis.redis_pool import get_redis_client
from onyx.utils.logger import setup_logger
@@ -801,6 +808,65 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
def attachment_to_content(
confluence_client: OnyxConfluence,
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
"""If it returns None, assume that we should skip this attachment."""
if not validate_attachment_filetype(attachment):
return None
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
attachment_size = attachment["extensions"]["fileSize"]
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return None
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
# why are we using session.get here? we probably won't retry these ... is that ok?
response = confluence_client._session.get(download_link)
if response.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
)
return None
extracted_text = extract_file_text(
io.BytesIO(response.content),
file_name=attachment["title"],
break_on_unprocessable=False,
)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
logger.warning(
f"Skipping {download_link} due to char count. "
f"char count={len(extracted_text)} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
)
return None
return extracted_text
def extract_text_from_confluence_html(
confluence_client: OnyxConfluence,
confluence_object: dict[str, Any],

View File

@@ -22,7 +22,6 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
from onyx.configs.constants import FileOrigin
if TYPE_CHECKING:
@@ -36,6 +35,7 @@ from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -53,16 +53,17 @@ class TokenResponse(BaseModel):
def validate_attachment_filetype(
attachment: dict[str, Any],
attachment: dict[str, Any], llm: LLM | None = None
) -> bool:
"""
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return is_valid_image_type(media_type)
return llm is not None and is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
@@ -83,103 +84,55 @@ class AttachmentProcessingResult(BaseModel):
error: str | None = None
def _make_attachment_link(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None = None,
) -> str | None:
download_link = ""
if "api.atlassian.com" in confluence_client.url:
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
if not parent_content_id:
logger.warning(
"parent_content_id is required to download attachments from Confluence Cloud!"
)
return None
download_link = (
confluence_client.url
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
else:
download_link = confluence_client.url + attachment["_links"]["download"]
return download_link
return None
return resp.content
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
parent_content_id: str | None,
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image, stores it for later analysis. Returns a structured result.
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment):
if not validate_attachment_filetype(attachment, llm):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
attachment_link = _make_attachment_link(
confluence_client, attachment, parent_content_id
)
if not attachment_link:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to make attachment link"
)
attachment_size = attachment["extensions"]["fileSize"]
if not media_type.startswith("image/"):
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
logger.warning(
f"Skipping {attachment_link} due to size. "
f"size={attachment_size} "
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
)
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {attachment_size} chars",
)
logger.info(
f"Downloading attachment: "
f"title={attachment['title']} "
f"length={attachment_size} "
f"link={attachment_link}"
)
# Download the attachment
resp: requests.Response = confluence_client._session.get(attachment_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
)
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment download status code is {resp.status_code}",
text=None, file_name=None, error="Failed to download attachment"
)
raw_bytes = resp.content
if not raw_bytes:
return AttachmentProcessingResult(
text=None, file_name=None, error="attachment.content is None"
)
# Process image attachments
if media_type.startswith("image/"):
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
return _process_image_attachment(
confluence_client, attachment, raw_bytes, media_type
confluence_client, attachment, page_context, llm, raw_bytes, media_type
)
# Process document attachments
@@ -212,10 +165,12 @@ def process_attachment(
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it without generating a summary."""
"""Process an image attachment by saving it and generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
@@ -225,14 +180,15 @@ def _process_image_attachment(
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
logger.info(f"Stored image attachment with file name: {file_name}")
# Return empty text but include the file_name for later processing
return AttachmentProcessingResult(text="", file_name=file_name, error=None)
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
except Exception as e:
msg = f"Image storage failed for {attachment['title']}: {e}"
msg = f"Image summarization failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
@@ -293,12 +249,13 @@ def _process_text_attachment(
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_id: str,
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts content or stores image for later processing
2. Extracts or summarizes content
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment["metadata"]["mediaType"]
@@ -309,7 +266,7 @@ def convert_attachment_to_content(
)
return None
result = process_attachment(confluence_client, attachment, page_id)
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"

View File

@@ -4,7 +4,6 @@ from collections.abc import Iterable
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
from discord import Client
from discord.channel import TextChannel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -34,7 +32,7 @@ _SNIPPET_LENGTH = 30
def _convert_message_to_document(
message: DiscordMessage,
sections: list[TextSection],
sections: list[Section],
) -> Document:
"""
Convert a discord message to a document
@@ -80,7 +78,7 @@ def _convert_message_to_document(
semantic_identifier=semantic_identifier,
doc_updated_at=message.edited_at,
title=title,
sections=(cast(list[TextSection | ImageSection], sections)),
sections=sections,
metadata=metadata,
)
@@ -125,8 +123,8 @@ async def _fetch_documents_from_channel(
if channel_message.type != MessageType.default:
continue
sections: list[TextSection] = [
TextSection(
sections: list[Section] = [
Section(
text=channel_message.content,
link=channel_message.jump_url,
)
@@ -144,7 +142,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)
@@ -162,7 +160,7 @@ async def _fetch_documents_from_channel(
continue
sections = [
TextSection(
Section(
text=thread_message.content,
link=thread_message.jump_url,
)

View File

@@ -3,7 +3,6 @@ import urllib.parse
from datetime import datetime
from datetime import timezone
from typing import Any
from typing import cast
import requests
from pydantic import BaseModel
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -114,7 +112,7 @@ class DiscourseConnector(PollConnector):
responders.append(BasicExpertInfo(display_name=responder_name))
sections.append(
TextSection(link=topic_url, text=parse_html_page_basic(post["cooked"]))
Section(link=topic_url, text=parse_html_page_basic(post["cooked"]))
)
category_name = self.category_id_map.get(topic["category_id"])
@@ -131,7 +129,7 @@ class DiscourseConnector(PollConnector):
doc = Document(
id="_".join([DocumentSource.DISCOURSE.value, str(topic["id"])]),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.DISCOURSE,
semantic_identifier=topic["title"],
doc_updated_at=time_str_to_utc(topic["last_posted_at"]),

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.retry_wrapper import retry_builder
@@ -158,7 +158,7 @@ class Document360Connector(LoadConnector, PollConnector):
document = Document(
id=article_details["id"],
sections=[TextSection(link=doc_link, text=doc_text)],
sections=[Section(link=doc_link, text=doc_text)],
source=DocumentSource.DOCUMENT360,
semantic_identifier=article_details["title"],
doc_updated_at=updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -108,7 +108,7 @@ class DropboxConnector(LoadConnector, PollConnector):
batch.append(
Document(
id=f"doc:{entry.id}",
sections=[TextSection(link=link, text=text)],
sections=[Section(link=link, text=text)],
source=DocumentSource.DROPBOX,
semantic_identifier=entry.name,
doc_updated_at=modified_time,

View File

@@ -24,7 +24,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
@@ -111,7 +111,7 @@ def _process_egnyte_file(
# Create the document
return Document(
id=f"egnyte-{file_metadata['entry_id']}",
sections=[TextSection(text=file_content_raw.strip(), link=web_url)],
sections=[Section(text=file_content_raw.strip(), link=web_url)],
source=DocumentSource.EGNYTE,
semantic_identifier=file_name,
metadata=metadata,

View File

@@ -16,8 +16,8 @@ from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
@@ -26,6 +26,7 @@ from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -58,44 +59,32 @@ def _read_files_and_metadata(
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
link: str | None = None,
idx: int = 0,
) -> tuple[ImageSection, str | None]:
) -> tuple[Section, str | None]:
"""
Creates an ImageSection for an image file or embedded image.
Stores the image in PGFileStore but does not generate a summary.
Args:
image_data: Raw image bytes
db_session: Database session
parent_file_name: Name of the parent file (for embedded images)
display_name: Display name for the image
idx: Index for embedded images
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Returns:
Tuple of (ImageSection, stored_file_name or None)
tuple: (Section object, file_name in PGFileStore or None if storage failed)
"""
# Create a unique identifier for the image
file_name = f"{parent_file_name}_embedded_{idx}" if idx > 0 else parent_file_name
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Store the image and create a section
try:
section, stored_file_name = store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
link=link,
file_origin=FileOrigin.CONNECTOR,
)
return section, stored_file_name
except Exception as e:
logger.error(f"Failed to store image {display_name}: {e}")
raise e
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
def _process_file(
@@ -104,16 +93,12 @@ def _process_file(
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
) -> list[Document]:
"""
Process a file and return a list of Documents.
For images, creates ImageSection objects without summarization.
For documents with embedded images, extracts and stores the images.
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
"""
if metadata is None:
metadata = {}
# Get file extension and determine file type
extension = get_file_ext(file_name)
# Fetch the DB record so we know the ID for internal URL
@@ -129,6 +114,8 @@ def _process_file(
return []
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
# Timestamps
@@ -171,7 +158,6 @@ def _process_file(
"title",
"connector_type",
"pdf_password",
"mime_type",
]
}
@@ -184,45 +170,33 @@ def _process_file(
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
if extension in LoadConnector.IMAGE_EXTENSIONS:
# Read the image data
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
image_data = file.read()
if not image_data:
logger.warning(f"Empty image file: {file_name}")
return []
# Create an ImageSection for the image
try:
section, _ = _create_image_section(
image_data=image_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=title,
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
return [
Document(
id=doc_id,
sections=[section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
except Exception as e:
logger.error(f"Failed to process image file {file_name}: {e}")
return []
# 2) Otherwise: text-based approach. Possibly with embedded images.
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
# Extract text and images from the file
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
@@ -230,29 +204,24 @@ def _process_file(
)
# Build sections: first the text as a single Section
sections: list[TextSection | ImageSection] = []
sections = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(TextSection(link=link_in_meta, text=text_content.strip()))
sections.append(Section(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image reference
try:
image_section, _ = _create_image_section(
image_data=img_data,
db_session=db_session,
parent_file_name=pg_record.file_name,
display_name=f"{title} - image {idx}",
idx=idx,
)
sections.append(image_section)
except Exception as e:
logger.warning(
f"Failed to process embedded image {idx} in {file_name}: {e}"
)
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
return [
Document(
id=doc_id,
@@ -268,10 +237,10 @@ def _process_file(
]
class LocalFileConnector(LoadConnector):
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
"""
Connector that reads files from Postgres and yields Documents, including
embedded image extraction without summarization.
optional embedded image extraction.
"""
def __init__(
@@ -283,6 +252,9 @@ class LocalFileConnector(LoadConnector):
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
@@ -314,6 +286,7 @@ class LocalFileConnector(LoadConnector):
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
)
documents.extend(new_docs)

View File

@@ -1,7 +1,6 @@
from collections.abc import Iterator
from datetime import datetime
from datetime import timezone
from typing import cast
from typing import List
import requests
@@ -15,8 +14,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -47,7 +45,7 @@ _FIREFLIES_API_QUERY = """
def _create_doc_from_transcript(transcript: dict) -> Document | None:
sections: List[TextSection] = []
sections: List[Section] = []
current_speaker_name = None
current_link = ""
current_text = ""
@@ -59,7 +57,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
if sentence["speaker_name"] != current_speaker_name:
if current_speaker_name is not None:
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -73,7 +71,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
# Sometimes these links (links with a timestamp) do not work, it is a bug with Fireflies.
sections.append(
TextSection(
Section(
link=current_link,
text=current_text.strip(),
)
@@ -96,7 +94,7 @@ def _create_doc_from_transcript(transcript: dict) -> Document | None:
return Document(
id=fireflies_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.FIREFLIES,
semantic_identifier=meeting_title,
metadata={},

View File

@@ -14,7 +14,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -133,7 +133,7 @@ def _create_doc_from_ticket(ticket: dict, domain: str) -> Document:
return Document(
id=_FRESHDESK_ID_PREFIX + link,
sections=[
TextSection(
Section(
link=link,
text=text,
)

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -183,7 +183,7 @@ def _convert_page_to_document(
return Document(
id=f"gitbook-{space_id}-{page_id}",
sections=[
TextSection(
Section(
link=page.get("urls", {}).get("app", ""),
text=_extract_text_from_document(page_content),
)
@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -27,7 +27,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
@@ -87,9 +87,7 @@ def _batch_github_objects(
def _convert_pr_to_document(pull_request: PullRequest) -> Document:
return Document(
id=pull_request.html_url,
sections=[
TextSection(link=pull_request.html_url, text=pull_request.body or "")
],
sections=[Section(link=pull_request.html_url, text=pull_request.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=pull_request.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -111,7 +109,7 @@ def _fetch_issue_comments(issue: Issue) -> str:
def _convert_issue_to_document(issue: Issue) -> Document:
return Document(
id=issue.html_url,
sections=[TextSection(link=issue.html_url, text=issue.body or "")],
sections=[Section(link=issue.html_url, text=issue.body or "")],
source=DocumentSource.GITHUB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware

View File

@@ -21,7 +21,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -56,7 +56,7 @@ def get_author(author: Any) -> BasicExpertInfo:
def _convert_merge_request_to_document(mr: Any) -> Document:
doc = Document(
id=mr.web_url,
sections=[TextSection(link=mr.web_url, text=mr.description or "")],
sections=[Section(link=mr.web_url, text=mr.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=mr.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -72,7 +72,7 @@ def _convert_merge_request_to_document(mr: Any) -> Document:
def _convert_issue_to_document(issue: Any) -> Document:
doc = Document(
id=issue.web_url,
sections=[TextSection(link=issue.web_url, text=issue.description or "")],
sections=[Section(link=issue.web_url, text=issue.description or "")],
source=DocumentSource.GITLAB,
semantic_identifier=issue.title,
# updated_at is UTC time but is timezone unaware, explicitly add UTC
@@ -99,7 +99,7 @@ def _convert_code_to_document(
file_url = f"{url}/{projectOwner}/{projectName}/-/blob/master/{file['path']}" # Construct the file URL
doc = Document(
id=file["id"],
sections=[TextSection(link=file_url, text=file_content)],
sections=[Section(link=file_url, text=file_content)],
source=DocumentSource.GITLAB,
semantic_identifier=file["name"],
doc_updated_at=datetime.now().replace(

View File

@@ -1,6 +1,5 @@
from base64 import urlsafe_b64decode
from typing import Any
from typing import cast
from typing import Dict
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
@@ -29,9 +28,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -117,7 +115,7 @@ def _get_message_body(payload: dict[str, Any]) -> str:
return message_body
def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str, str]]:
def message_to_section(message: Dict[str, Any]) -> tuple[Section, dict[str, str]]:
link = f"https://mail.google.com/mail/u/0/#inbox/{message['id']}"
payload = message.get("payload", {})
@@ -144,7 +142,7 @@ def message_to_section(message: Dict[str, Any]) -> tuple[TextSection, dict[str,
message_body_text: str = _get_message_body(payload)
return TextSection(link=link, text=message_body_text + message_data), metadata
return Section(link=link, text=message_body_text + message_data), metadata
def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
@@ -194,7 +192,7 @@ def thread_to_document(full_thread: Dict[str, Any]) -> Document | None:
return Document(
id=id,
semantic_identifier=semantic_identifier,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.GMAIL,
# This is used to perform permission sync
primary_owners=primary_owners,

View File

@@ -18,7 +18,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -243,7 +243,7 @@ class GongConnector(LoadConnector, PollConnector):
Document(
id=call_id,
sections=[
TextSection(link=call_metadata["url"], text=transcript_text)
Section(link=call_metadata["url"], text=transcript_text)
],
source=DocumentSource.GONG,
# Should not ever be Untitled as a call cannot be made without a Title

View File

@@ -43,7 +43,9 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -66,6 +68,7 @@ def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -74,6 +77,7 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
@@ -112,7 +116,9 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
def __init__(
self,
include_shared_drives: bool = False,
@@ -145,6 +151,9 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
and not include_my_drives
@@ -307,9 +316,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
# default is ~17mins of retries, don't do that here for cases so we don't
# waste 17mins everytime we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
@@ -530,6 +537,7 @@ class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
)
# Fetch files in batches

View File

@@ -1,50 +1,40 @@
import io
from datetime import datetime
from typing import cast
from datetime import timezone
from tempfile import NamedTemporaryFile
from googleapiclient.http import MediaIoBaseDownload # type: ignore
import openpyxl # type: ignore
from googleapiclient.discovery import build # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
from onyx.connectors.google_drive.models import GDriveMimeType
from onyx.connectors.google_drive.models import GoogleDriveFileType
from onyx.connectors.google_drive.section_extraction import get_document_sections
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import xlsx_to_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Mapping of Google Drive mime types to export formats
GOOGLE_MIME_TYPES_TO_EXPORT = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
# Define Google MIME types mapping
GOOGLE_MIME_TYPES = {
GDriveMimeType.DOC.value: "text/plain",
GDriveMimeType.SPREADSHEET.value: "text/csv",
GDriveMimeType.PPT.value: "text/plain",
}
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
@@ -76,137 +66,259 @@ def is_gdrive_image_mime_type(mime_type: str) -> bool:
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
) -> list[TextSection | ImageSection]:
"""Extract text and images from a Google Drive file."""
file_id = file["id"]
file_name = file["name"]
image_analysis_llm: LLM | None = None,
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
mime_type = file["mimeType"]
link = file.get("webViewLink", "")
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
try:
# For Google Docs, Sheets, and Slides, export as plain text
if mime_type in GOOGLE_MIME_TYPES_TO_EXPORT:
export_mime_type = GOOGLE_MIME_TYPES_TO_EXPORT[mime_type]
# Use the correct API call for exporting files
request = service.files().export_media(
fileId=file_id, mimeType=export_mime_type
)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to export {file_name} as {export_mime_type}")
return []
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
# For other file types, download the file
# Use the correct API call for downloading files
request = service.files().get_media(fileId=file_id)
response_bytes = io.BytesIO()
downloader = MediaIoBaseDownload(response_bytes, request)
done = False
while not done:
_, done = downloader.next_chunk()
response = response_bytes.getvalue()
if not response:
logger.warning(f"Failed to download {file_name}")
return []
# Process based on mime type
if mime_type == "text/plain":
text = response.decode("utf-8")
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.wordprocessingml.document"
):
text, _ = docx_to_text_and_images(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.spreadsheetml.sheet"
):
text = xlsx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif (
mime_type
== "application/vnd.openxmlformats-officedocument.presentationml.presentation"
):
text = pptx_to_text(io.BytesIO(response))
return [TextSection(link=link, text=text)]
elif is_gdrive_image_mime_type(mime_type):
# For images, store them for later processing
sections: list[TextSection | ImageSection] = []
try:
with get_session_with_current_tenant() as db_session:
section, embedded_id = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file_id,
display_name=file_name,
media_type=mime_type,
file_origin=FileOrigin.CONNECTOR,
link=link,
)
sections.append(section)
except Exception as e:
logger.error(f"Failed to process image {file_name}: {e}")
return sections
elif mime_type == "application/pdf":
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response))
pdf_sections: list[TextSection | ImageSection] = [
TextSection(link=link, text=text)
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
# Process embedded images in the PDF
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
try:
# ---------------------------
# Google Sheets extraction
if mime_type == GDriveMimeType.SPREADSHEET.value:
try:
sheets_service = build(
"sheets", "v4", credentials=service._http.credentials
)
spreadsheet = (
sheets_service.spreadsheets()
.get(spreadsheetId=file["id"])
.execute()
)
sections = []
for sheet in spreadsheet["sheets"]:
sheet_name = sheet["properties"]["title"]
sheet_id = sheet["properties"]["sheetId"]
# Get sheet dimensions
grid_properties = sheet["properties"].get("gridProperties", {})
row_count = grid_properties.get("rowCount", 1000)
column_count = grid_properties.get("columnCount", 26)
# Convert column count to letter (e.g., 26 -> Z, 27 -> AA)
end_column = ""
while column_count:
column_count, remainder = divmod(column_count - 1, 26)
end_column = chr(65 + remainder) + end_column
range_name = f"'{sheet_name}'!A1:{end_column}{row_count}"
try:
result = (
sheets_service.spreadsheets()
.values()
.get(spreadsheetId=file["id"], range=range_name)
.execute()
)
values = result.get("values", [])
if values:
text = f"Sheet: {sheet_name}\n"
for row in values:
text += "\t".join(str(cell) for cell in row) + "\n"
sections.append(
Section(
link=f"{link}#gid={sheet_id}",
text=text,
)
)
except HttpError as e:
logger.warning(
f"Error fetching data for sheet '{sheet_name}': {e}"
)
continue
return sections
except Exception as e:
logger.warning(
f"Ran into exception '{e}' when pulling data from Google Sheet '{file['name']}'."
" Falling back to basic extraction."
)
# ---------------------------
# Microsoft Excel (.xlsx or .xls) extraction branch
elif mime_type in [
GDriveMimeType.SPREADSHEET_OPEN_FORMAT.value,
GDriveMimeType.SPREADSHEET_MS_EXCEL.value,
]:
try:
response = service.files().get_media(fileId=file["id"]).execute()
with NamedTemporaryFile(suffix=".xlsx", delete=True) as tmp:
tmp.write(response)
tmp_path = tmp.name
section_separator = "\n\n"
workbook = openpyxl.load_workbook(tmp_path, read_only=True)
# Work similarly to the xlsx_to_text function used for file connector
# but returns Sections instead of a string
sections = [
Section(
link=link,
text=(
f"Sheet: {sheet.title}\n\n"
+ section_separator.join(
",".join(map(str, row))
for row in sheet.iter_rows(
min_row=1, values_only=True
)
if row
)
),
)
for sheet in workbook.worksheets
]
return sections
except Exception as e:
logger.warning(
f"Error extracting data from Excel file '{file['name']}': {e}"
)
return [
Section(link=link, text="Error extracting data from Excel file")
]
# ---------------------------
# Export for Google Docs, PPT, and fallback for spreadsheets
if mime_type in [
GDriveMimeType.DOC.value,
GDriveMimeType.PPT.value,
GDriveMimeType.SPREADSHEET.value,
]:
export_mime_type = (
"text/plain"
if mime_type != GDriveMimeType.SPREADSHEET.value
else "text/csv"
)
text = (
service.files()
.export(fileId=file["id"], mimeType=export_mime_type)
.execute()
.decode("utf-8")
)
return [Section(link=link, text=text)]
# ---------------------------
# Plain text and Markdown files
elif mime_type in [
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
# ---------------------------
# Word, PowerPoint, PDF files
elif mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
if get_unstructured_api_key():
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
if mime_type == GDriveMimeType.WORD_DOC.value:
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(images):
section, embedded_id = store_image_and_create_section(
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=f"{file_id}_img_{idx}",
file_name=embedded_id,
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
pdf_sections.append(section)
except Exception as e:
logger.error(f"Failed to process PDF images in {file_name}: {e}")
return pdf_sections
sections.append(section)
return sections
else:
# For unsupported file types, try to extract text
try:
text = extract_file_text(io.BytesIO(response), file_name)
return [TextSection(link=link, text=text)]
except Exception as e:
logger.warning(f"Failed to extract text from {file_name}: {e}")
return []
elif mime_type == GDriveMimeType.PDF.value:
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
error_message = f"Unsupported file type: {mime_type}"
logger.error(error_message)
raise ValueError(error_message)
except Exception as e:
logger.error(f"Error processing file {file_name}: {e}")
return []
logger.exception(f"Error extracting sections from file: {e}")
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# skip shortcuts or folders
@@ -215,50 +327,44 @@ def convert_drive_item_to_document(
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[TextSection | ImageSection] = []
# Try to get sections using the advanced method first
sections: list[Section] = []
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
doc_sections = get_document_sections(
docs_service=docs_service, doc_id=file.get("id", "")
)
if doc_sections:
sections = cast(list[TextSection | ImageSection], doc_sections)
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Error in advanced parsing: {e}. Falling back to basic extraction."
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
)
# If we don't have sections yet, use the basic extraction method
# If not a doc, or if we failed above, do our 'basic' approach
if not sections:
sections = _extract_sections_basic(file, drive_service)
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
# If we still don't have any sections, skip this file
if not sections:
logger.warning(f"No content extracted from {file.get('name')}. Skipping.")
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
# Create the document
return Document(
id=doc_id,
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file.get("name", ""),
metadata={
"owner_names": ", ".join(
owner.get("displayName", "") for owner in file.get("owners", [])
),
},
doc_updated_at=datetime.fromisoformat(
file.get("modifiedTime", "").replace("Z", "+00:00")
),
semantic_identifier=file["name"],
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
additional_info=file.get("id"),
)
except Exception as e:
logger.error(f"Error converting file {file.get('name')}: {e}")
return None
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:

View File

@@ -3,7 +3,7 @@ from typing import Any
from pydantic import BaseModel
from onyx.connectors.google_utils.resources import GoogleDocsService
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
class CurrentHeading(BaseModel):
@@ -37,7 +37,7 @@ def _extract_text_from_paragraph(paragraph: dict[str, Any]) -> str:
def get_document_sections(
docs_service: GoogleDocsService,
doc_id: str,
) -> list[TextSection]:
) -> list[Section]:
"""Extracts sections from a Google Doc, including their headings and content"""
# Fetch the document structure
doc = docs_service.documents().get(documentId=doc_id).execute()
@@ -45,7 +45,7 @@ def get_document_sections(
# Get the content
content = doc.get("body", {}).get("content", [])
sections: list[TextSection] = []
sections: list[Section] = []
current_section: list[str] = []
current_heading: CurrentHeading | None = None
@@ -70,7 +70,7 @@ def get_document_sections(
heading_text = current_heading.text
section_text = f"{heading_text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)
@@ -96,7 +96,7 @@ def get_document_sections(
if current_heading is not None and current_section:
section_text = f"{current_heading.text}\n" + "\n".join(current_section)
sections.append(
TextSection(
Section(
text=section_text.strip(),
link=_build_gdoc_section_link(doc_id, current_heading.id),
)

View File

@@ -12,7 +12,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.db.engine import get_sqlalchemy_engine
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.extract_file_text import read_text_file
@@ -118,7 +118,7 @@ class GoogleSitesConnector(LoadConnector):
source=DocumentSource.GOOGLE_SITES,
semantic_identifier=title,
sections=[
TextSection(
Section(
link=(self.base_url.rstrip("/") + "/" + path.lstrip("/"))
if path
else "",

View File

@@ -15,7 +15,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.utils.logger import setup_logger
@@ -120,7 +120,7 @@ class GuruConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=card["id"],
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.GURU,
semantic_identifier=title,
doc_updated_at=latest_time,

View File

@@ -13,7 +13,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
HUBSPOT_BASE_URL = "https://app.hubspot.com/contacts/"
@@ -108,7 +108,7 @@ class HubSpotConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=ticket.id,
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.HUBSPOT,
semantic_identifier=title,
# Is already in tzutc, just replacing the timezone format

View File

@@ -24,8 +24,6 @@ CheckpointOutput = Generator[Document | ConnectorFailure, None, ConnectorCheckpo
class BaseConnector(abc.ABC):
REDIS_KEY_PREFIX = "da_connector_data:"
# Common image file extensions supported across connectors
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp", ".gif"}
@abc.abstractmethod
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:

View File

@@ -21,8 +21,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import request_with_retries
@@ -238,30 +237,22 @@ class LinearConnector(LoadConnector, PollConnector, OAuthConnector):
documents: list[Document] = []
for edge in edges:
node = edge["node"]
# Create sections for description and comments
sections = [
TextSection(
link=node["url"],
text=node["description"] or "",
)
]
# Add comment sections
for comment in node["comments"]["nodes"]:
sections.append(
TextSection(
link=node["url"],
text=comment["body"] or "",
)
)
# Cast the sections list to the expected type
typed_sections = cast(list[TextSection | ImageSection], sections)
documents.append(
Document(
id=node["id"],
sections=typed_sections,
sections=[
Section(
link=node["url"],
text=node["description"] or "",
)
]
+ [
Section(
link=node["url"],
text=comment["body"] or "",
)
for comment in node["comments"]["nodes"]
],
source=DocumentSource.LINEAR,
semantic_identifier=f"[{node['identifier']}] {node['title']}",
title=node["title"],

View File

@@ -17,7 +17,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.html_utils import parse_html_page_basic
from onyx.file_processing.html_utils import strip_excessive_newlines_and_spaces
from onyx.utils.logger import setup_logger
@@ -162,7 +162,7 @@ class LoopioConnector(LoadConnector, PollConnector):
doc_batch.append(
Document(
id=str(entry["id"]),
sections=[TextSection(link=link, text=content_text)],
sections=[Section(link=link, text=content_text)],
source=DocumentSource.LOOPIO,
semantic_identifier=questions[0],
doc_updated_at=latest_time,

View File

@@ -6,7 +6,6 @@ import tempfile
from collections.abc import Generator
from collections.abc import Iterator
from typing import Any
from typing import cast
from typing import ClassVar
import pywikibot.time # type: ignore[import-untyped]
@@ -21,8 +20,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.mediawiki.family import family_class_dispatch
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -62,14 +60,14 @@ def get_doc_from_page(
sections_extracted: textlib.Content = textlib.extract_sections(page_text, site)
sections = [
TextSection(
Section(
link=f"{page.full_url()}#" + section.heading.replace(" ", "_"),
text=section.title + section.content,
)
for section in sections_extracted.sections
]
sections.append(
TextSection(
Section(
link=page.full_url(),
text=sections_extracted.header,
)
@@ -81,7 +79,7 @@ def get_doc_from_page(
doc_updated_at=pywikibot_timestamp_to_utc_datetime(
page.latest_revision.timestamp
),
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
semantic_identifier=page.title(),
metadata={"categories": [category.title() for category in page.categories()]},
id=f"MEDIAWIKI_{page.pageid}_{page.full_url()}",

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -28,25 +27,9 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
"""Base section class with common attributes"""
link: str | None = None
text: str | None = None
image_file_name: str | None = None
class TextSection(Section):
"""Section containing text content"""
text: str
link: str | None = None
class ImageSection(Section):
"""Section containing an image reference"""
image_file_name: str
link: str | None = None
image_file_name: str | None = None
class BasicExpertInfo(BaseModel):
@@ -116,7 +99,7 @@ class DocumentBase(BaseModel):
"""Used for Onyx ingestion api, the ID is inferred before use if not provided"""
id: str | None = None
sections: list[TextSection | ImageSection]
sections: list[Section]
source: DocumentSource | None = None
semantic_identifier: str # displayed in the UI as the main identifier for the doc
metadata: dict[str, str | list[str]]
@@ -166,11 +149,19 @@ class DocumentBase(BaseModel):
class Document(DocumentBase):
"""Used for Onyx ingestion api, the ID is required"""
id: str
id: str # This must be unique or during indexing/reindexing, chunks will be overwritten
source: DocumentSource
def get_total_char_length(self) -> int:
"""Calculate the total character length of the document including sections, metadata, and identifiers."""
section_length = sum(len(section.text) for section in self.sections)
identifier_length = len(self.semantic_identifier) + len(self.title or "")
metadata_length = sum(
len(k) + len(v) if isinstance(v, str) else len(k) + sum(len(x) for x in v)
for k, v in self.metadata.items()
)
return section_length + identifier_length + metadata_length
def to_short_descriptor(self) -> str:
"""Used when logging the identity of a document"""
return f"ID: '{self.id}'; Semantic ID: '{self.semantic_identifier}'"
@@ -193,32 +184,6 @@ class Document(DocumentBase):
)
class IndexingDocument(Document):
"""Document with processed sections for indexing"""
processed_sections: list[Section] = []
def get_total_char_length(self) -> int:
"""Get the total character length of the document including processed sections"""
title_len = len(self.title or self.semantic_identifier)
# Use processed_sections if available, otherwise fall back to original sections
if self.processed_sections:
section_len = sum(
len(section.text) if section.text is not None else 0
for section in self.processed_sections
)
else:
section_len = sum(
len(section.text)
if isinstance(section, TextSection) and section.text is not None
else 0
for section in self.sections
)
return title_len + section_len
class SlimDocument(BaseModel):
id: str
perm_sync_data: Any | None = None
@@ -239,15 +204,6 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,3 +1,4 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -25,13 +26,12 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.batching import batch_generator
from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -475,7 +475,7 @@ class NotionConnector(LoadConnector, PollConnector):
Document(
id=page.id,
sections=[
TextSection(
Section(
link=f"{page.url}#{block.id.replace('-', '')}",
text=block.prefix + block.text,
)
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -23,8 +23,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.onyx_jira.utils import best_effort_basic_expert_info
from onyx.connectors.onyx_jira.utils import best_effort_get_field_from_issue
from onyx.connectors.onyx_jira.utils import build_jira_client
@@ -145,7 +145,7 @@ def fetch_jira_issues_batch(
yield Document(
id=page_url,
sections=[TextSection(link=page_url, text=ticket_content)],
sections=[Section(link=page_url, text=ticket_content)],
source=DocumentSource.JIRA,
semantic_identifier=f"{issue.key}: {issue.fields.summary}",
title=f"{issue.key} {issue.fields.summary}",

View File

@@ -16,7 +16,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.utils.logger import setup_logger
@@ -110,7 +110,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=feature["id"],
sections=[
TextSection(
Section(
link=feature["links"]["html"],
text=self._parse_description_html(feature["description"]),
)
@@ -133,7 +133,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=component["id"],
sections=[
TextSection(
Section(
link=component["links"]["html"],
text=self._parse_description_html(component["description"]),
)
@@ -159,7 +159,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=product["id"],
sections=[
TextSection(
Section(
link=product["links"]["html"],
text=self._parse_description_html(product["description"]),
)
@@ -189,7 +189,7 @@ class ProductboardConnector(PollConnector):
yield Document(
id=objective["id"],
sections=[
TextSection(
Section(
link=objective["links"]["html"],
text=self._parse_description_html(objective["description"]),
)

View File

@@ -13,7 +13,6 @@ from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.connectors.salesforce.doc_conversion import convert_sf_object_to_doc
from onyx.connectors.salesforce.doc_conversion import ID_PREFIX
from onyx.connectors.salesforce.salesforce_calls import fetch_all_csvs_in_parallel
@@ -49,12 +48,10 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
self,
credentials: dict[str, Any],
) -> dict[str, Any] | None:
domain = "test" if credentials.get("is_sandbox") else None
self._sf_client = Salesforce(
username=credentials["sf_username"],
password=credentials["sf_password"],
security_token=credentials["sf_security_token"],
domain=domain,
)
return None
@@ -219,8 +216,7 @@ if __name__ == "__main__":
for doc in doc_batch:
section_count += len(doc.sections)
for section in doc.sections:
if isinstance(section, TextSection) and section.text is not None:
text_count += len(section.text)
text_count += len(section.text)
end_time = time.time()
print(f"Doc count: {doc_count}")

View File

@@ -1,12 +1,10 @@
import re
from typing import cast
from onyx.configs.constants import DocumentSource
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import ImageSection
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.connectors.salesforce.sqlite_functions import get_child_ids
from onyx.connectors.salesforce.sqlite_functions import get_record
from onyx.connectors.salesforce.utils import SalesforceObject
@@ -116,8 +114,8 @@ def _extract_dict_text(raw_dict: dict) -> str:
return natural_language_for_dict
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> TextSection:
return TextSection(
def _extract_section(salesforce_object: SalesforceObject, base_url: str) -> Section:
return Section(
text=_extract_dict_text(salesforce_object.data),
link=f"{base_url}/{salesforce_object.id}",
)
@@ -177,7 +175,7 @@ def convert_sf_object_to_doc(
doc = Document(
id=onyx_salesforce_id,
sections=cast(list[TextSection | ImageSection], sections),
sections=sections,
source=DocumentSource.SALESFORCE,
semantic_identifier=extracted_semantic_identifier,
doc_updated_at=extracted_doc_updated_at,

View File

@@ -19,7 +19,7 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import TextSection
from onyx.connectors.models import Section
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.utils.logger import setup_logger
@@ -55,7 +55,7 @@ def _convert_driveitem_to_document(
doc = Document(
id=driveitem.id,
sections=[TextSection(link=driveitem.web_url, text=file_text)],
sections=[Section(link=driveitem.web_url, text=file_text)],
source=DocumentSource.SHAREPOINT,
semantic_identifier=driveitem.name,
doc_updated_at=driveitem.last_modified_datetime.replace(tzinfo=timezone.utc),

View File

@@ -19,8 +19,8 @@ from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.models import TextSection
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
@@ -212,7 +212,7 @@ class SlabConnector(LoadConnector, PollConnector, SlimConnector):
doc_batch.append(
Document(
id=post_id, # can't be url as this changes with the post title
sections=[TextSection(link=page_url, text=content_text)],
sections=[Section(link=page_url, text=content_text)],
source=DocumentSource.SLAB,
semantic_identifier=post["title"],
metadata={},

Some files were not shown because too many files have changed in this diff Show More