Compare commits

..

3 Commits

Author SHA1 Message Date
pablonyx
75103d2f8b quick nit 2025-02-28 15:05:39 -08:00
pablonyx
e8f7c34a72 k 2025-02-28 14:18:52 -08:00
pablonyx
1a378448f4 k 2025-02-28 14:10:33 -08:00
290 changed files with 2667 additions and 9838 deletions

View File

@@ -12,40 +12,29 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
# 1) Preliminary job to check if the changed files are relevant
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: "true"
changed: ${{ steps.check.outputs.changed }}
steps:
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -1,7 +1,6 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -48,13 +47,11 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -79,7 +76,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -114,4 +114,3 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -1,125 +0,0 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -1,55 +0,0 @@
"""add background_reindex_enabled field
Revision ID: b7c2b63c4a03
Revises: f11b408e39d3
Create Date: 2024-03-26 12:34:56.789012
"""
from alembic import op
import sqlalchemy as sa
from onyx.db.enums import EmbeddingPrecision
# revision identifiers, used by Alembic.
revision = "b7c2b63c4a03"
down_revision = "f11b408e39d3"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add background_reindex_enabled column with default value of True
op.add_column(
"search_settings",
sa.Column(
"background_reindex_enabled",
sa.Boolean(),
nullable=False,
server_default="true",
),
)
# Add embedding_precision column with default value of FLOAT
op.add_column(
"search_settings",
sa.Column(
"embedding_precision",
sa.Enum(EmbeddingPrecision, native_enum=False),
nullable=False,
server_default=EmbeddingPrecision.FLOAT.name,
),
)
# Add reduced_dimension column with default value of None
op.add_column(
"search_settings",
sa.Column("reduced_dimension", sa.Integer(), nullable=True),
)
def downgrade() -> None:
# Remove the background_reindex_enabled column
op.drop_column("search_settings", "background_reindex_enabled")
op.drop_column("search_settings", "embedding_precision")
op.drop_column("search_settings", "reduced_dimension")

View File

@@ -1,51 +0,0 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -4,8 +4,7 @@ from ee.onyx.server.reporting.usage_export_generation import create_new_usage_re
from onyx.background.celery.apps.primary import celery_app
from onyx.background.task_utils import build_celery_task_wrapper
from onyx.configs.app_configs import JOB_TIMEOUT
from onyx.db.chat import delete_chat_session
from onyx.db.chat import get_chat_sessions_older_than
from onyx.db.chat import delete_chat_sessions_older_than
from onyx.db.engine import get_session_with_current_tenant
from onyx.server.settings.store import load_settings
from onyx.utils.logger import setup_logger
@@ -19,26 +18,7 @@ logger = setup_logger()
@celery_app.task(soft_time_limit=JOB_TIMEOUT)
def perform_ttl_management_task(retention_limit_days: int, *, tenant_id: str) -> None:
with get_session_with_current_tenant() as db_session:
old_chat_sessions = get_chat_sessions_older_than(
retention_limit_days, db_session
)
for user_id, session_id in old_chat_sessions:
# one session per delete so that we don't blow up if a deletion fails.
with get_session_with_current_tenant() as db_session:
try:
delete_chat_session(
user_id,
session_id,
db_session,
include_deleted=True,
hard_delete=True,
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
delete_chat_sessions_older_than(retention_limit_days, db_session)
#####

View File

@@ -134,9 +134,7 @@ def fetch_chat_sessions_eagerly_by_time(
limit: int | None = 500,
initial_time: datetime | None = None,
) -> list[ChatSession]:
"""Sorted by oldest to newest, then by message id"""
asc_time_order: UnaryExpression = asc(ChatSession.time_created)
time_order: UnaryExpression = desc(ChatSession.time_created)
message_order: UnaryExpression = asc(ChatMessage.id)
filters: list[ColumnElement | BinaryExpression] = [
@@ -149,7 +147,8 @@ def fetch_chat_sessions_eagerly_by_time(
subquery = (
db_session.query(ChatSession.id, ChatSession.time_created)
.filter(*filters)
.order_by(asc_time_order)
.order_by(ChatSession.id, time_order)
.distinct(ChatSession.id)
.limit(limit)
.subquery()
)
@@ -165,7 +164,7 @@ def fetch_chat_sessions_eagerly_by_time(
ChatMessage.chat_message_feedbacks
),
)
.order_by(asc_time_order, message_order)
.order_by(time_order, message_order)
)
chat_sessions = query.all()

View File

@@ -16,20 +16,13 @@ from onyx.db.models import UsageReport
from onyx.file_store.file_store import get_default_file_store
# Gets skeletons of all messages in the given range
# Gets skeletons of all message
def get_empty_chat_messages_entries__paginated(
db_session: Session,
period: tuple[datetime, datetime],
limit: int | None = 500,
initial_time: datetime | None = None,
) -> tuple[Optional[datetime], list[ChatMessageSkeleton]]:
"""Returns a tuple where:
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],
end=period[1],
@@ -59,17 +52,18 @@ def get_empty_chat_messages_entries__paginated(
if len(chat_sessions) == 0:
return None, []
return chat_sessions[-1].time_created, message_skeletons
return chat_sessions[0].time_created, message_skeletons
def get_all_empty_chat_message_entries(
db_session: Session,
period: tuple[datetime, datetime],
) -> Generator[list[ChatMessageSkeleton], None, None]:
"""period is the range of time over which to fetch messages."""
initial_time: Optional[datetime] = period[0]
ind = 0
while True:
# iterate from oldest to newest
ind += 1
time_created, message_skeletons = get_empty_chat_messages_entries__paginated(
db_session,
period,

View File

@@ -424,7 +424,7 @@ def _validate_curator_status__no_commit(
)
# if the user is a curator in any of their groups, set their role to CURATOR
# otherwise, set their role to BASIC only if they were previously a CURATOR
# otherwise, set their role to BASIC
if curator_relationships:
user.role = UserRole.CURATOR
elif user.role == UserRole.CURATOR:
@@ -631,16 +631,7 @@ def update_user_group(
removed_users = db_session.scalars(
select(User).where(User.id.in_(removed_user_ids)) # type: ignore
).unique()
# Filter out admin and global curator users before validating curator status
users_to_validate = [
user
for user in removed_users
if user.role not in [UserRole.ADMIN, UserRole.GLOBAL_CURATOR]
]
if users_to_validate:
_validate_curator_status__no_commit(db_session, users_to_validate)
_validate_curator_status__no_commit(db_session, list(removed_users))
# update "time_updated" to now
db_user_group.time_last_modified_by_user = func.now()

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
include_router_with_global_prefix_prepended(application, oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -22,7 +22,7 @@ from onyx.onyxbot.slack.blocks import get_restate_blocks
from onyx.onyxbot.slack.constants import GENERATE_ANSWER_BUTTON_ACTION_ID
from onyx.onyxbot.slack.handlers.utils import send_team_member_message
from onyx.onyxbot.slack.models import SlackMessageInfo
from onyx.onyxbot.slack.utils import respond_in_thread_or_channel
from onyx.onyxbot.slack.utils import respond_in_thread
from onyx.onyxbot.slack.utils import update_emote_react
from onyx.utils.logger import OnyxLoggingAdapter
from onyx.utils.logger import setup_logger
@@ -216,7 +216,7 @@ def _handle_standard_answers(
all_blocks = restate_question_blocks + answer_blocks
try:
respond_in_thread_or_channel(
respond_in_thread(
client=client,
channel=message_info.channel_to_respond,
receiver_ids=receiver_ids,
@@ -231,7 +231,6 @@ def _handle_standard_answers(
client=client,
channel=message_info.channel_to_respond,
thread_ts=slack_thread_id,
receiver_ids=receiver_ids,
)
return True

View File

@@ -80,7 +80,6 @@ class ConfluenceCloudOAuth:
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)

View File

@@ -1,14 +1,10 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -18,19 +14,13 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -99,12 +89,6 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -113,15 +97,6 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -138,104 +113,11 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,5 +1,3 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -11,7 +9,6 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -91,64 +88,6 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -168,12 +107,6 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -48,15 +48,10 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -251,8 +246,6 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

View File

@@ -1,45 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -1,98 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,24 +1,269 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,96 +0,0 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,30 +67,3 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -48,5 +48,4 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))

View File

@@ -4,7 +4,6 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -15,7 +14,6 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -57,11 +55,7 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
"""Get existing tenant ID for an email or create a new tenant if none exists."""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -110,14 +104,14 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
status_code=409, detail="User already belongs to an organization"
)
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
logger.info(f"Provisioning tenant: {tenant_id}")
token = None
try:
if not create_schema_if_not_exists(tenant_id):
logger.debug(f"Created schema for tenant {tenant_id}")
logger.info(f"Created schema for tenant {tenant_id}")
else:
logger.debug(f"Schema already exists for tenant {tenant_id}")
logger.info(f"Schema already exists for tenant {tenant_id}")
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
@@ -355,47 +349,3 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None

View File

@@ -1,67 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -1,39 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -1,90 +0,0 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,56 +1,27 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
logger = logging.getLogger(__name__)
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -70,9 +41,7 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
)
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
@@ -107,187 +76,3 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -6,7 +6,7 @@ MODEL_WARM_UP_STRING = "hi " * 512
DEFAULT_OPENAI_MODEL = "text-embedding-3-small"
DEFAULT_COHERE_MODEL = "embed-english-light-v3.0"
DEFAULT_VOYAGE_MODEL = "voyage-large-2-instruct"
DEFAULT_VERTEX_MODEL = "text-embedding-005"
DEFAULT_VERTEX_MODEL = "text-embedding-004"
class EmbeddingModelTextType:

View File

@@ -5,7 +5,6 @@ from types import TracebackType
from typing import cast
from typing import Optional
import aioboto3 # type: ignore
import httpx
import openai
import vertexai # type: ignore
@@ -29,13 +28,11 @@ from model_server.constants import DEFAULT_VERTEX_MODEL
from model_server.constants import DEFAULT_VOYAGE_MODEL
from model_server.constants import EmbeddingModelTextType
from model_server.constants import EmbeddingProvider
from model_server.utils import pass_aws_key
from model_server.utils import simple_log_function_time
from onyx.utils.logger import setup_logger
from shared_configs.configs import API_BASED_EMBEDDING_TIMEOUT
from shared_configs.configs import INDEXING_ONLY
from shared_configs.configs import OPENAI_EMBEDDING_TIMEOUT
from shared_configs.configs import VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
from shared_configs.enums import EmbedTextType
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
@@ -62,60 +59,6 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -135,7 +78,7 @@ class CloudEmbedding:
self._closed = False
async def _embed_openai(
self, texts: list[str], model: str | None, reduced_dimension: int | None
self, texts: list[str], model: str | None
) -> list[Embedding]:
if not model:
model = DEFAULT_OPENAI_MODEL
@@ -146,17 +89,27 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(input=text_batch, model=model)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
)
logger.error(error_string)
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -195,6 +148,7 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -224,24 +178,17 @@ class CloudEmbedding:
vertexai.init(project=project_id, credentials=credentials)
client = TextEmbeddingModel.from_pretrained(model)
inputs = [TextEmbeddingInput(text, embedding_type) for text in texts]
# Split into batches of 25 texts
max_texts_per_batch = VERTEXAI_EMBEDDING_LOCAL_BATCH_SIZE
batches = [
inputs[i : i + max_texts_per_batch]
for i in range(0, len(inputs), max_texts_per_batch)
]
# Dispatch all embedding calls asynchronously at once
tasks = [
client.get_embeddings_async(batch, auto_truncate=True) for batch in batches
]
# Wait for all tasks to complete in parallel
results = await asyncio.gather(*tasks)
return [embedding.values for batch in results for embedding in batch]
embeddings = await client.get_embeddings_async(
[
TextEmbeddingInput(
text,
embedding_type,
)
for text in texts
],
auto_truncate=True, # This is the default
)
return [embedding.values for embedding in embeddings]
async def _embed_litellm_proxy(
self, texts: list[str], model_name: str | None
@@ -276,53 +223,23 @@ class CloudEmbedding:
text_type: EmbedTextType,
model_name: str | None = None,
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
@staticmethod
def create(
@@ -409,7 +326,6 @@ async def embed_text(
prefix: str | None,
api_url: str | None,
api_version: str | None,
reduced_dimension: int | None,
gpu_type: str = "UNKNOWN",
) -> list[Embedding]:
if not all(texts):
@@ -453,7 +369,6 @@ async def embed_text(
model_name=model_name,
deployment_name=deployment_name,
text_type=text_type,
reduced_dimension=reduced_dimension,
)
if any(embedding is None for embedding in embeddings):
@@ -525,7 +440,7 @@ async def local_rerank(query: str, docs: list[str], model_name: str) -> list[flo
)
async def cohere_rerank_api(
async def cohere_rerank(
query: str, docs: list[str], model_name: str, api_key: str
) -> list[float]:
cohere_client = CohereAsyncClient(api_key=api_key)
@@ -535,45 +450,6 @@ async def cohere_rerank_api(
return [result.relevance_score for result in sorted_results]
async def cohere_rerank_aws(
query: str,
docs: list[str],
model_name: str,
region_name: str,
aws_access_key_id: str,
aws_secret_access_key: str,
) -> list[float]:
session = aioboto3.Session(
aws_access_key_id=aws_access_key_id, aws_secret_access_key=aws_secret_access_key
)
async with session.client(
"bedrock-runtime", region_name=region_name
) as bedrock_client:
body = json.dumps(
{
"query": query,
"documents": docs,
"api_version": 2,
}
)
# Invoke the Bedrock model asynchronously
response = await bedrock_client.invoke_model(
modelId=model_name,
accept="application/json",
contentType="application/json",
body=body,
)
# Read the response asynchronously
response_body = json.loads(await response["body"].read())
# Extract and sort the results
results = response_body.get("results", [])
sorted_results = sorted(results, key=lambda item: item["index"])
return [result["relevance_score"] for result in sorted_results]
async def litellm_rerank(
query: str, docs: list[str], api_url: str, model_name: str, api_key: str | None
) -> list[float]:
@@ -632,18 +508,10 @@ async def process_embed_request(
text_type=embed_request.text_type,
api_url=embed_request.api_url,
api_version=embed_request.api_version,
reduced_dimension=embed_request.reduced_dimension,
prefix=prefix,
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,
@@ -696,32 +564,15 @@ async def process_rerank_request(rerank_request: RerankRequest) -> RerankRespons
elif rerank_request.provider_type == RerankerProvider.COHERE:
if rerank_request.api_key is None:
raise RuntimeError("Cohere Rerank Requires an API Key")
sim_scores = await cohere_rerank_api(
sim_scores = await cohere_rerank(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
api_key=rerank_request.api_key,
)
return RerankResponse(scores=sim_scores)
elif rerank_request.provider_type == RerankerProvider.BEDROCK:
if rerank_request.api_key is None:
raise RuntimeError("Bedrock Rerank Requires an API Key")
aws_access_key_id, aws_secret_access_key, aws_region = pass_aws_key(
rerank_request.api_key
)
sim_scores = await cohere_rerank_aws(
query=rerank_request.query,
docs=rerank_request.documents,
model_name=rerank_request.model_name,
region_name=aws_region,
aws_access_key_id=aws_access_key_id,
aws_secret_access_key=aws_secret_access_key,
)
return RerankResponse(scores=sim_scores)
else:
raise ValueError(f"Unsupported provider: {rerank_request.provider_type}")
except Exception as e:
logger.exception(f"Error during reranking process:\n{str(e)}")
raise HTTPException(

View File

@@ -70,32 +70,3 @@ def get_gpu_type() -> str:
return GPUStatus.MAC_MPS
return GPUStatus.NONE
def pass_aws_key(api_key: str) -> tuple[str, str, str]:
"""Parse AWS API key string into components.
Args:
api_key: String in format 'aws_ACCESSKEY_SECRETKEY_REGION'
Returns:
Tuple of (access_key, secret_key, region)
Raises:
ValueError: If key format is invalid
"""
if not api_key.startswith("aws"):
raise ValueError("API key must start with 'aws' prefix")
parts = api_key.split("_")
if len(parts) != 4:
raise ValueError(
f"API key must be in format 'aws_ACCESSKEY_SECRETKEY_REGION', got {len(parts) - 1} parts"
"this is an onyx specific format for formatting the aws secrets for bedrock"
)
try:
_, aws_access_key_id, aws_secret_access_key, aws_region = parts
return aws_access_key_id, aws_secret_access_key, aws_region
except Exception as e:
raise ValueError(f"Failed to parse AWS key components: {str(e)}")

View File

@@ -31,7 +31,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,7 +92,6 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,7 +46,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -120,7 +119,6 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -63,7 +62,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -155,9 +153,8 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -281,9 +278,6 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,7 +34,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -142,7 +141,6 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -113,7 +112,6 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,7 +43,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -145,7 +144,6 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,7 +50,13 @@ def decide_refinement_need(
)
]
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)

View File

@@ -21,7 +21,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -97,7 +96,6 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,7 +46,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -69,8 +68,6 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -182,9 +179,8 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -306,11 +302,7 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -417,7 +409,6 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,6 +13,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -143,6 +144,8 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
assert query_info is not None, "must have query info"
return query_info
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)

View File

@@ -33,7 +33,6 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -97,7 +96,6 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,9 +56,8 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0] if callback_container else []
pre_rerank_docs = callback_container[0]
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,7 +25,6 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -94,7 +93,6 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,9 +44,7 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_runner = ToolRunner(tool, tool_args)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,17 +15,8 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -34,7 +25,6 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -47,31 +37,6 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -82,6 +47,7 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -105,22 +71,11 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -143,16 +98,8 @@ def choose_tool(
# For tool calling LLMs, we want to insert the task prompt as part of this flow, this is because the LLM
# may choose to not call any tools and just generate the answer, in which case the task prompt is needed.
prompt=built_prompt,
tools=(
[tool.tool_definition() for tool in tools] or None
if using_tool_calling_llm
else None
),
tool_choice=(
"required"
if tools and force_use_tool.force_use and using_tool_calling_llm
else None
),
tools=[tool.tool_definition() for tool in tools] or None,
tool_choice=("required" if tools and force_use_tool.force_use else None),
structured_response_format=structured_response_format,
)
@@ -198,22 +145,10 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -9,23 +9,18 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
SEARCH_DOC_CONTENT_ID,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -55,13 +50,11 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,7 +2,6 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -36,7 +35,6 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,11 +13,6 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,7 +42,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -62,7 +61,6 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -404,7 +402,6 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -508,9 +505,3 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -153,8 +153,7 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
if mail_from:
msg["From"] = mail_from
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,6 +1,5 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -19,17 +18,3 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,7 +100,6 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -524,7 +523,6 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
user_count = await get_user_count()
logger.debug(f"Current tenant user count: {user_count}")
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
if user_count == 1:
@@ -546,7 +544,7 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
logger.debug(f"User {user.id} has registered.")
logger.notice(f"User {user.id} has registered.")
optional_telemetry(
record_type=RecordType.SIGN_UP,
data={"action": "create"},
@@ -588,20 +586,14 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -895,7 +887,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accessible_user(
async def current_chat_accesssible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1096,12 +1088,6 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1133,14 +1119,9 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1152,7 +1133,6 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -111,6 +111,5 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -1,73 +0,0 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -23,10 +23,9 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import _should_index
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
from onyx.background.celery.tasks.indexing.utils import try_creating_indexing_task
from onyx.background.celery.tasks.indexing.utils import validate_indexing_fences
from onyx.background.indexing.checkpointing_utils import cleanup_checkpoint
@@ -62,7 +61,7 @@ from onyx.db.index_attempt import mark_attempt_canceled
from onyx.db.index_attempt import mark_attempt_failed
from onyx.db.search_settings import get_active_search_settings_list
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_and_perform_index_swap
from onyx.db.swap_index import check_index_swap
from onyx.natural_language_processing.search_nlp_models import EmbeddingModel
from onyx.natural_language_processing.search_nlp_models import warm_up_bi_encoder
from onyx.redis.redis_connector import RedisConnector
@@ -407,7 +406,7 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
# check for search settings swap
with get_session_with_current_tenant() as db_session:
old_search_settings = check_and_perform_index_swap(db_session=db_session)
old_search_settings = check_index_swap(db_session=db_session)
current_search_settings = get_current_search_settings(db_session)
# So that the first time users aren't surprised by really slow speed of first
# batch of documents indexed
@@ -440,15 +439,6 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
with get_session_with_current_tenant() as db_session:
search_settings_list = get_active_search_settings_list(db_session)
for search_settings_instance in search_settings_list:
# skip non-live search settings that don't have background reindex enabled
# those should just auto-change to live shortly after creation without
# requiring any indexing till that point
if (
not search_settings_instance.status.is_current()
and not search_settings_instance.background_reindex_enabled
):
continue
redis_connector_index = redis_connector.new_index(
search_settings_instance.id
)
@@ -466,18 +456,23 @@ def check_for_indexing(self: Task, *, tenant_id: str) -> int | None:
cc_pair.id, search_settings_instance.id, db_session
)
if not should_index(
search_settings_primary = False
if search_settings_instance.id == search_settings_list[0].id:
search_settings_primary = True
if not _should_index(
cc_pair=cc_pair,
last_index=last_attempt,
search_settings_instance=search_settings_instance,
search_settings_primary=search_settings_primary,
secondary_index_building=len(search_settings_list) > 1,
db_session=db_session,
):
continue
reindex = False
if search_settings_instance.status.is_current():
# the indexing trigger is only checked and cleared with the current search settings
if search_settings_instance.id == search_settings_list[0].id:
# the indexing trigger is only checked and cleared with the primary search settings
if cc_pair.indexing_trigger is not None:
if cc_pair.indexing_trigger == IndexingMode.REINDEX:
reindex = True
@@ -985,9 +980,6 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1028,23 +1020,6 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1191,7 +1166,6 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1239,7 +1213,6 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -346,10 +346,11 @@ def validate_indexing_fences(
return
def should_index(
def _should_index(
cc_pair: ConnectorCredentialPair,
last_index: IndexAttempt | None,
search_settings_instance: SearchSettings,
search_settings_primary: bool,
secondary_index_building: bool,
db_session: Session,
) -> bool:
@@ -414,9 +415,9 @@ def should_index(
):
return False
if search_settings_instance.status.is_current():
if search_settings_primary:
if cc_pair.indexing_trigger is not None:
# if a manual indexing trigger is on the cc pair, honor it for live search settings
# if a manual indexing trigger is on the cc pair, honor it for primary search settings
return True
# if no attempt has ever occurred, we should index regardless of refresh_freq

View File

@@ -22,7 +22,6 @@ from onyx.configs.constants import DocumentSource
from onyx.configs.constants import MilestoneRecordType
from onyx.connectors.connector_runner import ConnectorRunner
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.factory import instantiate_connector
from onyx.connectors.models import ConnectorCheckpoint
from onyx.connectors.models import ConnectorFailure
@@ -93,13 +92,8 @@ def _get_connector_runner(
if not INTEGRATION_TESTS_MODE:
runnable_connector.validate_connector_settings()
except UnexpectedValidationError as e:
logger.exception(
"Unable to instantiate connector due to an unexpected temporary issue."
)
raise e
except Exception as e:
logger.exception("Unable to instantiate connector. Pausing until fixed.")
logger.exception("Unable to instantiate connector.")
# since we failed to even instantiate the connector, we pause the CCPair since
# it will never succeed

View File

@@ -15,8 +15,6 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@@ -1,13 +1,10 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -47,44 +44,9 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -374,8 +336,6 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -387,13 +347,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = Union[
AgentSearchPacket = (
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
]
)
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -756,7 +756,6 @@ def stream_chat_message_objects(
)
# LLM prompt building, response capturing, etc.
answer = Answer(
prompt_builder=prompt_builder,
is_connected=is_connected,

View File

@@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
"Manual LLM citation wasn't able to close brackets"
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
continue
link = context_llm_doc.link
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
link = context_llm_doc.link
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
last_citation_end = end + length_to_add
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,45 +333,4 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
AGENT_MAX_TOKENS_VALIDATION = int(
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
)
GRAPH_VERSION_NAME: str = "a"

View File

@@ -640,6 +640,3 @@ TEST_ENV = os.environ.get("TEST_ENV", "").lower() == "true"
MOCK_LLM_RESPONSE = (
os.environ.get("MOCK_LLM_RESPONSE") if os.environ.get("MOCK_LLM_RESPONSE") else None
)
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20

View File

@@ -76,7 +76,6 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"

View File

@@ -1,38 +0,0 @@
from onyx.configs.app_configs import DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB
from onyx.server.settings.store import load_settings
def get_image_extraction_and_analysis_enabled() -> bool:
"""Get image extraction and analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.image_extraction_and_analysis_enabled is not None:
return settings.image_extraction_and_analysis_enabled
except Exception:
pass
return False
def get_search_time_image_analysis_enabled() -> bool:
"""Get search time image analysis enabled setting from workspace settings or fallback to False"""
try:
settings = load_settings()
if settings.search_time_image_analysis_enabled is not None:
return settings.search_time_image_analysis_enabled
except Exception:
pass
return False
def get_image_analysis_max_size_mb() -> int:
"""Get image analysis max size MB setting from workspace settings or fallback to environment variable"""
try:
settings = load_settings()
if settings.image_analysis_max_size_mb is not None:
return settings.image_analysis_max_size_mb
except Exception:
pass
return DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB

View File

@@ -200,6 +200,7 @@ class AirtableConnector(LoadConnector):
return attachment_response.content
logger.error(f"Failed to refresh attachment for {filename}")
raise
attachment_content = get_attachment_with_retry(url, record_id)

View File

@@ -18,7 +18,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -310,7 +310,7 @@ class BlobStorageConnector(LoadConnector, PollConnector):
# Catch-all for anything not captured by the above
# Since we are unsure of the error and it may not disable the connector,
# raise an unexpected error (does not disable connector)
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected error during blob storage settings validation: {e}"
)

View File

@@ -11,17 +11,18 @@ from onyx.configs.app_configs import CONFLUENCE_TIMEZONE_OFFSET
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.connectors.confluence.onyx_confluence import extract_text_from_confluence_html
from onyx.connectors.confluence.onyx_confluence import attachment_to_content
from onyx.connectors.confluence.onyx_confluence import (
extract_text_from_confluence_html,
)
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.connectors.confluence.utils import build_confluence_document_id
from onyx.connectors.confluence.utils import convert_attachment_to_content
from onyx.connectors.confluence.utils import datetime_from_string
from onyx.connectors.confluence.utils import process_attachment
from onyx.connectors.confluence.utils import validate_attachment_filetype
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CredentialsConnector
from onyx.connectors.interfaces import CredentialsProviderInterface
from onyx.connectors.interfaces import GenerateDocumentsOutput
@@ -35,26 +36,28 @@ from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.utils.logger import setup_logger
logger = setup_logger()
# Potential Improvements
# 1. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
# 1. Include attachments, etc
# 2. Segment into Sections for more accurate linking, can split by headers but make sure no text/ordering is lost
_COMMENT_EXPANSION_FIELDS = ["body.storage.value"]
_PAGE_EXPANSION_FIELDS = [
"body.storage.value",
"version",
"space",
"metadata.labels",
"history.lastUpdated",
]
_ATTACHMENT_EXPANSION_FIELDS = [
"version",
"space",
"metadata.labels",
]
_RESTRICTIONS_EXPANSION_FIELDS = [
"space",
"restrictions.read.restrictions.user",
@@ -66,6 +69,9 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
@@ -81,11 +87,7 @@ _FULL_EXTENSION_FILTER_STRING = "".join(
class ConfluenceConnector(
LoadConnector,
PollConnector,
SlimConnector,
CredentialsConnector,
VisionEnabledConnector,
LoadConnector, PollConnector, SlimConnector, CredentialsConnector
):
def __init__(
self,
@@ -103,24 +105,13 @@ class ConfluenceConnector(
labels_to_skip: list[str] = CONFLUENCE_CONNECTOR_LABELS_TO_SKIP,
timezone_offset: float = CONFLUENCE_TIMEZONE_OFFSET,
) -> None:
self.wiki_base = wiki_base
self.is_cloud = is_cloud
self.space = space
self.page_id = page_id
self.index_recursively = index_recursively
self.cql_query = cql_query
self.batch_size = batch_size
self.continue_on_failure = continue_on_failure
self.labels_to_skip = labels_to_skip
self.timezone_offset = timezone_offset
self._confluence_client: OnyxConfluence | None = None
self._fetched_titles: set[str] = set()
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
self.is_cloud = is_cloud
# Remove trailing slash from wiki_base if present
self.wiki_base = wiki_base.rstrip("/")
"""
If nothing is provided, we default to fetching all pages
Only one or none of the following options should be specified so
@@ -162,6 +153,8 @@ class ConfluenceConnector(
"max_backoff_seconds": 60,
}
self._confluence_client: OnyxConfluence | None = None
@property
def confluence_client(self) -> OnyxConfluence:
if self._confluence_client is None:
@@ -191,6 +184,7 @@ class ConfluenceConnector(
end: SecondsSinceUnixEpoch | None = None,
) -> str:
page_query = self.base_cql_page_query + self.cql_label_filter
# Add time filters
if start:
formatted_start_time = datetime.fromtimestamp(
@@ -202,6 +196,7 @@ class ConfluenceConnector(
"%Y-%m-%d %H:%M"
)
page_query += f" and lastmodified <= '{formatted_end_time}'"
return page_query
def _construct_attachment_query(self, confluence_page_id: str) -> str:
@@ -212,10 +207,11 @@ class ConfluenceConnector(
def _get_comment_string_for_page_id(self, page_id: str) -> str:
comment_string = ""
comment_cql = f"type=comment and container='{page_id}'"
comment_cql += self.cql_label_filter
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
expand = ",".join(_COMMENT_EXPANSION_FIELDS)
for comment in self.confluence_client.paginated_cql_retrieval(
cql=comment_cql,
expand=expand,
@@ -226,179 +222,123 @@ class ConfluenceConnector(
confluence_object=comment,
fetched_titles=set(),
)
return comment_string
def _convert_page_to_document(self, page: dict[str, Any]) -> Document | None:
def _convert_object_to_document(
self,
confluence_object: dict[str, Any],
parent_content_id: str | None = None,
) -> Document | None:
"""
Converts a Confluence page to a Document object.
Includes the page content, comments, and attachments.
"""
try:
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}{page['_links']['webui']}"
Takes in a confluence object, extracts all metadata, and converts it into a document.
If its a page, it extracts the text, adds the comments for the document text.
If its an attachment, it just downloads the attachment and converts that into a document.
# Get the page content
page_content = extract_text_from_confluence_html(
self.confluence_client, page, self._fetched_titles
parent_content_id: if the object is an attachment, specifies the content id that
the attachment is attached to
"""
# The url and the id are the same
object_url = build_confluence_document_id(
self.wiki_base, confluence_object["_links"]["webui"], self.is_cloud
)
object_text = None
# Extract text from page
if confluence_object["type"] == "page":
object_text = extract_text_from_confluence_html(
confluence_client=self.confluence_client,
confluence_object=confluence_object,
fetched_titles={confluence_object.get("title", "")},
)
# Add comments to text
object_text += self._get_comment_string_for_page_id(confluence_object["id"])
elif confluence_object["type"] == "attachment":
object_text = attachment_to_content(
confluence_client=self.confluence_client,
attachment=confluence_object,
parent_content_id=parent_content_id,
)
# Create the main section for the page content
sections = [Section(text=page_content, link=page_url)]
# Process comments if available
comment_text = self._get_comment_string_for_page_id(page_id)
if comment_text:
sections.append(Section(text=comment_text, link=f"{page_url}#comments"))
# Process attachments
if "children" in page and "attachment" in page["children"]:
attachments = self.confluence_client.get_attachments_for_page(
page_id, expand="metadata"
)
for attachment in attachments.get("results", []):
# Process each attachment
result = process_attachment(
self.confluence_client,
attachment,
page_title,
self.image_analysis_llm,
)
if result.text:
# Create a section for the attachment text
attachment_section = Section(
text=result.text,
link=f"{page_url}#attachment-{attachment['id']}",
image_file_name=result.file_name,
)
sections.append(attachment_section)
elif result.error:
logger.warning(
f"Error processing attachment '{attachment.get('title')}': {result.error}"
)
# Extract metadata
metadata = {}
if "space" in page:
metadata["space"] = page["space"].get("name", "")
# Extract labels
labels = []
if "metadata" in page and "labels" in page["metadata"]:
for label in page["metadata"]["labels"].get("results", []):
labels.append(label.get("name", ""))
if labels:
metadata["labels"] = labels
# Extract owners
primary_owners = []
if "version" in page and "by" in page["version"]:
author = page["version"]["by"]
display_name = author.get("displayName", "Unknown")
primary_owners.append(BasicExpertInfo(display_name=display_name))
# Create the document
return Document(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
metadata=metadata,
doc_updated_at=datetime_from_string(page["version"]["when"]),
primary_owners=primary_owners if primary_owners else None,
)
except Exception as e:
logger.error(f"Error converting page {page.get('id', 'unknown')}: {e}")
if not self.continue_on_failure:
raise
if object_text is None:
# This only happens for attachments that are not parseable
return None
# Get space name
doc_metadata: dict[str, str | list[str]] = {
"Wiki Space Name": confluence_object["space"]["name"]
}
# Get labels
label_dicts = (
confluence_object.get("metadata", {}).get("labels", {}).get("results", [])
)
page_labels = [label.get("name") for label in label_dicts if label.get("name")]
if page_labels:
doc_metadata["labels"] = page_labels
# Get last modified and author email
version_dict = confluence_object.get("version", {})
last_modified = (
datetime_from_string(version_dict.get("when"))
if version_dict.get("when")
else None
)
author_email = version_dict.get("by", {}).get("email")
title = confluence_object.get("title", "Untitled Document")
return Document(
id=object_url,
sections=[Section(link=object_url, text=object_text)],
source=DocumentSource.CONFLUENCE,
semantic_identifier=title,
doc_updated_at=last_modified,
primary_owners=(
[BasicExpertInfo(email=author_email)] if author_email else None
),
metadata=doc_metadata,
)
def _fetch_document_batches(
self,
start: SecondsSinceUnixEpoch | None = None,
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
"""
Yields batches of Documents. For each page:
- Create a Document with 1 Section for the page text/comments
- Then fetch attachments. For each attachment:
- Attempt to convert it with convert_attachment_to_content(...)
- If successful, create a new Section with the extracted text or summary.
"""
doc_batch: list[Document] = []
confluence_page_ids: list[str] = []
page_query = self._construct_page_query(start, end)
logger.debug(f"page_query: {page_query}")
# Fetch pages as Documents
for page in self.confluence_client.paginated_cql_retrieval(
cql=page_query,
expand=",".join(_PAGE_EXPANSION_FIELDS),
limit=self.batch_size,
):
# Build doc from page
doc = self._convert_page_to_document(page)
if not doc:
continue
# Now get attachments for that page:
attachment_query = self._construct_attachment_query(page["id"])
# We'll use the page's XML to provide context if we summarize an image
confluence_xml = page.get("body", {}).get("storage", {}).get("value", "")
logger.debug(f"_fetch_document_batches: {page['id']}")
confluence_page_ids.append(page["id"])
doc = self._convert_object_to_document(page)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
# Fetch attachments as Documents
for confluence_page_id in confluence_page_ids:
attachment_query = self._construct_attachment_query(confluence_page_id)
# TODO: maybe should add time filter as well?
for attachment in self.confluence_client.paginated_cql_retrieval(
cql=attachment_query,
expand=",".join(_ATTACHMENT_EXPANSION_FIELDS),
):
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
):
continue
# Attempt to get textual content or image summarization:
try:
logger.info(f"Processing attachment: {attachment['title']}")
response = convert_attachment_to_content(
confluence_client=self.confluence_client,
attachment=attachment,
page_context=confluence_xml,
llm=self.image_analysis_llm,
)
if response is None:
continue
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:
doc.sections.append(
Section(
text=content_text,
link=object_url,
image_file_name=file_storage_name,
)
)
except Exception as e:
logger.error(
f"Failed to extract/summarize attachment {attachment['title']}",
exc_info=e,
)
if not self.continue_on_failure:
raise
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
doc = self._convert_object_to_document(attachment, confluence_page_id)
if doc is not None:
doc_batch.append(doc)
if len(doc_batch) >= self.batch_size:
yield doc_batch
doc_batch = []
if doc_batch:
yield doc_batch
@@ -419,63 +359,55 @@ class ConfluenceConnector(
end: SecondsSinceUnixEpoch | None = None,
callback: IndexingHeartbeatInterface | None = None,
) -> GenerateSlimDocumentOutput:
"""
Return 'slim' docs (IDs + minimal permission data).
Does not fetch actual text. Used primarily for incremental permission sync.
"""
doc_metadata_list: list[SlimDocument] = []
restrictions_expand = ",".join(_RESTRICTIONS_EXPANSION_FIELDS)
# Query pages
page_query = self.base_cql_page_query + self.cql_label_filter
for page in self.confluence_client.cql_paginate_all_expansions(
cql=page_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If the page has restrictions, add them to the perm_sync_data
# These will be used by doc_sync.py to sync permissions
page_restrictions = page.get("restrictions")
page_space_key = page.get("space", {}).get("key")
page_ancestors = page.get("ancestors", [])
page_perm_sync_data = {
"restrictions": page_restrictions or {},
"space_key": page_space_key,
"ancestors": page_ancestors,
"ancestors": page_ancestors or [],
}
doc_metadata_list.append(
SlimDocument(
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
self.wiki_base,
page["_links"]["webui"],
self.is_cloud,
),
perm_sync_data=page_perm_sync_data,
)
)
# Query attachments for each page
attachment_query = self._construct_attachment_query(page["id"])
for attachment in self.confluence_client.cql_paginate_all_expansions(
cql=attachment_query,
expand=restrictions_expand,
limit=_SLIM_DOC_BATCH_SIZE,
):
# If you skip images, you'll skip them in the permission sync
attachment["metadata"].get("mediaType", "")
if not validate_attachment_filetype(
attachment, self.image_analysis_llm
):
if not validate_attachment_filetype(attachment):
continue
attachment_restrictions = attachment.get("restrictions", {})
attachment_restrictions = attachment.get("restrictions")
if not attachment_restrictions:
attachment_restrictions = page_restrictions or {}
attachment_restrictions = page_restrictions
attachment_space_key = attachment.get("space", {}).get("key")
if not attachment_space_key:
attachment_space_key = page_space_key
attachment_perm_sync_data = {
"restrictions": attachment_restrictions,
"restrictions": attachment_restrictions or {},
"space_key": attachment_space_key,
}
@@ -489,16 +421,16 @@ class ConfluenceConnector(
perm_sync_data=attachment_perm_sync_data,
)
)
if len(doc_metadata_list) > _SLIM_DOC_BATCH_SIZE:
yield doc_metadata_list[:_SLIM_DOC_BATCH_SIZE]
doc_metadata_list = doc_metadata_list[_SLIM_DOC_BATCH_SIZE:]
if callback and callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
if callback:
if callback.should_stop():
raise RuntimeError(
"retrieve_all_slim_documents: Stop signal detected"
)
callback.progress("retrieve_all_slim_documents", 1)
yield doc_metadata_list
@@ -519,11 +451,11 @@ class ConfluenceConnector(
raise InsufficientPermissionsError(
"Insufficient permissions to access Confluence resources (HTTP 403)."
)
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected Confluence error (status={status_code}): {e}"
)
except Exception as e:
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected error while validating Confluence settings: {e}"
)

View File

@@ -144,12 +144,6 @@ class OnyxConfluence:
self.static_credentials = credential_json
return credential_json, False
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_ID:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_ID must be set!")
if not OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET:
raise RuntimeError("OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET must be set!")
# check if we should refresh tokens. we're deciding to refresh halfway
# to expiration
now = datetime.now(timezone.utc)

View File

@@ -1,12 +1,9 @@
import io
import math
import time
from collections.abc import Callable
from datetime import datetime
from datetime import timedelta
from datetime import timezone
from io import BytesIO
from pathlib import Path
from typing import Any
from typing import cast
from typing import TYPE_CHECKING
@@ -15,28 +12,14 @@ from urllib.parse import parse_qs
from urllib.parse import quote
from urllib.parse import urlparse
import bs4
import requests
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.configs.app_configs import (
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
)
from onyx.configs.constants import FileOrigin
from onyx.utils.logger import setup_logger
if TYPE_CHECKING:
from onyx.connectors.confluence.onyx_confluence import OnyxConfluence
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import PGFileStore
from onyx.db.pg_file_store import create_populate_lobj
from onyx.db.pg_file_store import save_bytes_to_pgfilestore
from onyx.db.pg_file_store import upsert_pgfilestore
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
pass
logger = setup_logger()
@@ -52,229 +35,15 @@ class TokenResponse(BaseModel):
scope: str
def validate_attachment_filetype(
attachment: dict[str, Any], llm: LLM | None = None
) -> bool:
"""
Validates if the attachment is a supported file type.
If LLM is provided, also checks if it's an image that can be processed.
"""
attachment.get("metadata", {})
media_type = attachment.get("metadata", {}).get("mediaType", "")
if media_type.startswith("image/"):
return llm is not None and is_valid_image_type(media_type)
# For non-image files, check if we support the extension
title = attachment.get("title", "")
extension = Path(title).suffix.lstrip(".").lower() if "." in title else ""
return extension in ["pdf", "doc", "docx", "txt", "md", "rtf"]
class AttachmentProcessingResult(BaseModel):
"""
A container for results after processing a Confluence attachment.
'text' is the textual content of the attachment.
'file_name' is the final file name used in PGFileStore to store the content.
'error' holds an exception or string if something failed.
"""
text: str | None
file_name: str | None
error: str | None = None
def _download_attachment(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> bytes | None:
"""
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
"""
download_link = confluence_client.url + attachment["_links"]["download"]
resp = confluence_client._session.get(download_link)
if resp.status_code != 200:
logger.warning(
f"Failed to fetch {download_link} with status code {resp.status_code}"
)
return None
return resp.content
def process_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM | None,
) -> AttachmentProcessingResult:
"""
Processes a Confluence attachment. If it's a document, extracts text,
or if it's an image and an LLM is available, summarizes it. Returns a structured result.
"""
try:
# Get the media type from the attachment metadata
media_type = attachment.get("metadata", {}).get("mediaType", "")
# Validate the attachment type
if not validate_attachment_filetype(attachment, llm):
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Unsupported file type: {media_type}",
)
# Download the attachment
raw_bytes = _download_attachment(confluence_client, attachment)
if raw_bytes is None:
return AttachmentProcessingResult(
text=None, file_name=None, error="Failed to download attachment"
)
# Process image attachments with LLM if available
if media_type.startswith("image/") and llm:
return _process_image_attachment(
confluence_client, attachment, page_context, llm, raw_bytes, media_type
)
# Process document attachments
try:
text = extract_file_text(
file=BytesIO(raw_bytes),
file_name=attachment["title"],
)
# Skip if the text is too long
if len(text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
return AttachmentProcessingResult(
text=None,
file_name=None,
error=f"Attachment text too long: {len(text)} chars",
)
return AttachmentProcessingResult(text=text, file_name=None, error=None)
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to extract text: {e}"
)
except Exception as e:
return AttachmentProcessingResult(
text=None, file_name=None, error=f"Failed to process attachment: {e}"
)
def _process_image_attachment(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM,
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process an image attachment by saving it and generating a summary."""
try:
# Use the standardized image storage and section creation
with get_session_with_current_tenant() as db_session:
section, file_name = store_image_and_create_section(
db_session=db_session,
image_data=raw_bytes,
file_name=Path(attachment["id"]).name,
display_name=attachment["title"],
media_type=media_type,
llm=llm,
file_origin=FileOrigin.CONNECTOR,
)
return AttachmentProcessingResult(
text=section.text, file_name=file_name, error=None
)
except Exception as e:
msg = f"Image summarization failed for {attachment['title']}: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
def _process_text_attachment(
attachment: dict[str, Any],
raw_bytes: bytes,
media_type: str,
) -> AttachmentProcessingResult:
"""Process a text-based attachment by extracting its content."""
try:
extracted_text = extract_file_text(
io.BytesIO(raw_bytes),
file_name=attachment["title"],
break_on_unprocessable=False,
)
except Exception as e:
msg = f"Failed to extract text for '{attachment['title']}': {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
# Check length constraints
if extracted_text is None or len(extracted_text) == 0:
msg = f"No text extracted for {attachment['title']}"
logger.warning(msg)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
msg = (
f"Skipping attachment {attachment['title']} due to char count "
f"({len(extracted_text)} > {CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD})"
)
logger.warning(msg)
return AttachmentProcessingResult(text=None, file_name=None, error=msg)
# Save the attachment
try:
with get_session_with_current_tenant() as db_session:
saved_record = save_bytes_to_pgfilestore(
db_session=db_session,
raw_bytes=raw_bytes,
media_type=media_type,
identifier=attachment["id"],
display_name=attachment["title"],
)
except Exception as e:
msg = f"Failed to save attachment '{attachment['title']}' to PG: {e}"
logger.error(msg, exc_info=e)
return AttachmentProcessingResult(
text=extracted_text, file_name=None, error=msg
)
return AttachmentProcessingResult(
text=extracted_text, file_name=saved_record.file_name, error=None
)
def convert_attachment_to_content(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
page_context: str,
llm: LLM | None,
) -> tuple[str | None, str | None] | None:
"""
Facade function which:
1. Validates attachment type
2. Extracts or summarizes content
3. Returns (content_text, stored_file_name) or None if we should skip it
"""
media_type = attachment["metadata"]["mediaType"]
# Quick check for unsupported types:
if media_type.startswith("video/") or media_type == "application/gliffy+json":
logger.warning(
f"Skipping unsupported attachment type: '{media_type}' for {attachment['title']}"
)
return None
result = process_attachment(confluence_client, attachment, page_context, llm)
if result.error is not None:
logger.warning(
f"Attachment {attachment['title']} encountered error: {result.error}"
)
return None
# Return the text and the file name
return result.text, result.file_name
def validate_attachment_filetype(attachment: dict[str, Any]) -> bool:
return attachment["metadata"]["mediaType"] not in [
"image/jpeg",
"image/png",
"image/gif",
"image/svg+xml",
"video/mp4",
"video/quicktime",
]
def build_confluence_document_id(
@@ -295,6 +64,23 @@ def build_confluence_document_id(
return f"{base_url}{content_url}"
def _extract_referenced_attachment_names(page_text: str) -> list[str]:
"""Parse a Confluence html page to generate a list of current
attachments in use
Args:
text (str): The page content
Returns:
list[str]: List of filenames currently in use by the page text
"""
referenced_attachment_filenames = []
soup = bs4.BeautifulSoup(page_text, "html.parser")
for attachment in soup.findAll("ri:attachment"):
referenced_attachment_filenames.append(attachment.attrs["ri:filename"])
return referenced_attachment_filenames
def datetime_from_string(datetime_string: str) -> datetime:
datetime_object = datetime.fromisoformat(datetime_string)
@@ -466,37 +252,3 @@ def update_param_in_path(path: str, param: str, value: str) -> str:
+ "?"
+ "&".join(f"{k}={quote(v[0])}" for k, v in query_params.items())
)
def attachment_to_file_record(
confluence_client: "OnyxConfluence",
attachment: dict[str, Any],
db_session: Session,
) -> tuple[PGFileStore, bytes]:
"""Save an attachment to the file store and return the file record."""
download_link = _attachment_to_download_link(confluence_client, attachment)
image_data = confluence_client.get(
download_link, absolute=True, not_json_response=True
)
# Save image to file store
file_name = f"confluence_attachment_{attachment['id']}"
lobj_oid = create_populate_lobj(BytesIO(image_data), db_session)
pgfilestore = upsert_pgfilestore(
file_name=file_name,
display_name=attachment["title"],
file_origin=FileOrigin.OTHER,
file_type=attachment["metadata"]["mediaType"],
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,
)
return pgfilestore, image_data
def _attachment_to_download_link(
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
) -> str:
"""Extracts the download link to images."""
return confluence_client.url + attachment["_links"]["download"]

View File

@@ -14,15 +14,12 @@ class ConnectorValidationError(ValidationError):
super().__init__(self.message)
class UnexpectedValidationError(ValidationError):
class UnexpectedError(ValidationError):
"""Raised when an unexpected error occurs during connector validation.
Unexpected errors don't necessarily mean the credential is invalid,
but rather that there was an error during the validation process
or we encountered a currently unhandled error case.
Currently, unexpected validation errors are defined as transient and should not be
used to disable the connector.
"""
def __init__(self, message: str = "Unexpected error during connector validation"):

View File

@@ -10,23 +10,22 @@ from sqlalchemy.orm import Session
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_to_utc
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import BasicExpertInfo
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.pg_file_store import get_pgfilestore_by_file_name
from onyx.file_processing.extract_file_text import extract_text_and_images
from onyx.file_processing.extract_file_text import detect_encoding
from onyx.file_processing.extract_file_text import extract_file_text
from onyx.file_processing.extract_file_text import get_file_ext
from onyx.file_processing.extract_file_text import is_text_file_extension
from onyx.file_processing.extract_file_text import is_valid_file_ext
from onyx.file_processing.extract_file_text import load_files_from_zip
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.extract_file_text import read_text_file
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -36,115 +35,81 @@ def _read_files_and_metadata(
file_name: str,
db_session: Session,
) -> Iterator[tuple[str, IO, dict[str, Any]]]:
"""
Reads the file from Postgres. If the file is a .zip, yields subfiles.
"""
"""Reads the file into IO, in the case of a zip file, yields each individual
file contained within, also includes the metadata dict if packaged in the zip"""
extension = get_file_ext(file_name)
metadata: dict[str, Any] = {}
directory_path = os.path.dirname(file_name)
# Read file from Postgres store
file_content = get_default_file_store(db_session).read_file(file_name, mode="b")
# If it's a zip, expand it
if extension == ".zip":
for file_info, subfile, metadata in load_files_from_zip(
for file_info, file, metadata in load_files_from_zip(
file_content, ignore_dirs=True
):
yield os.path.join(directory_path, file_info.filename), subfile, metadata
yield os.path.join(directory_path, file_info.filename), file, metadata
elif is_valid_file_ext(extension):
yield file_name, file_content, metadata
else:
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
def _create_image_section(
llm: LLM | None,
image_data: bytes,
db_session: Session,
parent_file_name: str,
display_name: str,
idx: int = 0,
) -> tuple[Section, str | None]:
"""
Create a Section object for a single image and store the image in PGFileStore.
If summarization is enabled and we have an LLM, summarize the image.
Returns:
tuple: (Section object, file_name in PGFileStore or None if storage failed)
"""
# Create a unique file name for the embedded image
file_name = f"{parent_file_name}_embedded_{idx}"
# Use the standardized utility to store the image and create a section
return store_image_and_create_section(
db_session=db_session,
image_data=image_data,
file_name=file_name,
display_name=display_name,
llm=llm,
file_origin=FileOrigin.OTHER,
)
def _process_file(
file_name: str,
file: IO[Any],
metadata: dict[str, Any] | None,
pdf_pass: str | None,
db_session: Session,
llm: LLM | None,
metadata: dict[str, Any] | None = None,
pdf_pass: str | None = None,
) -> list[Document]:
"""
Processes a single file, returning a list of Documents (typically one).
Also handles embedded images if 'EMBEDDED_IMAGE_EXTRACTION_ENABLED' is true.
"""
extension = get_file_ext(file_name)
# Fetch the DB record so we know the ID for internal URL
pg_record = get_pgfilestore_by_file_name(file_name=file_name, db_session=db_session)
if not pg_record:
logger.warning(f"No file record found for '{file_name}' in PG; skipping.")
return []
if not is_valid_file_ext(extension):
logger.warning(
f"Skipping file '{file_name}' with unrecognized extension '{extension}'"
)
logger.warning(f"Skipping file '{file_name}' with extension '{extension}'")
return []
# Prepare doc metadata
if metadata is None:
metadata = {}
file_display_name = metadata.get("file_display_name") or os.path.basename(file_name)
file_metadata: dict[str, Any] = {}
# Timestamps
current_datetime = datetime.now(timezone.utc)
time_updated = metadata.get("time_updated", current_datetime)
if is_text_file_extension(file_name):
encoding = detect_encoding(file)
file_content_raw, file_metadata = read_text_file(
file, encoding=encoding, ignore_onyx_metadata=False
)
# Using the PDF reader function directly to pass in password cleanly
elif extension == ".pdf" and pdf_pass is not None:
file_content_raw, file_metadata = read_pdf_file(file=file, pdf_pass=pdf_pass)
else:
file_content_raw = extract_file_text(
file=file,
file_name=file_name,
break_on_unprocessable=True,
)
all_metadata = {**metadata, **file_metadata} if metadata else file_metadata
# add a prefix to avoid conflicts with other connectors
doc_id = f"FILE_CONNECTOR__{file_name}"
if metadata:
doc_id = metadata.get("document_id") or doc_id
# If this is set, we will show this in the UI as the "name" of the file
file_display_name = all_metadata.get("file_display_name") or os.path.basename(
file_name
)
title = (
all_metadata["title"] or "" if "title" in all_metadata else file_display_name
)
time_updated = all_metadata.get("time_updated", datetime.now(timezone.utc))
if isinstance(time_updated, str):
time_updated = time_str_to_utc(time_updated)
dt_str = metadata.get("doc_updated_at")
dt_str = all_metadata.get("doc_updated_at")
final_time_updated = time_str_to_utc(dt_str) if dt_str else time_updated
# Collect owners
p_owner_names = metadata.get("primary_owners")
s_owner_names = metadata.get("secondary_owners")
p_owners = (
[BasicExpertInfo(display_name=name) for name in p_owner_names]
if p_owner_names
else None
)
s_owners = (
[BasicExpertInfo(display_name=name) for name in s_owner_names]
if s_owner_names
else None
)
# Additional tags we store as doc metadata
# Metadata tags separate from the Onyx specific fields
metadata_tags = {
k: v
for k, v in metadata.items()
for k, v in all_metadata.items()
if k
not in [
"document_id",
@@ -157,142 +122,77 @@ def _process_file(
"file_display_name",
"title",
"connector_type",
"pdf_password",
]
}
source_type_str = metadata.get("connector_type")
source_type = (
DocumentSource(source_type_str) if source_type_str else DocumentSource.FILE
source_type_str = all_metadata.get("connector_type")
source_type = DocumentSource(source_type_str) if source_type_str else None
p_owner_names = all_metadata.get("primary_owners")
s_owner_names = all_metadata.get("secondary_owners")
p_owners = (
[BasicExpertInfo(display_name=name) for name in p_owner_names]
if p_owner_names
else None
)
s_owners = (
[BasicExpertInfo(display_name=name) for name in s_owner_names]
if s_owner_names
else None
)
doc_id = metadata.get("document_id") or f"FILE_CONNECTOR__{file_name}"
title = metadata.get("title") or file_display_name
# 1) If the file itself is an image, handle that scenario quickly
IMAGE_EXTENSIONS = {".jpg", ".jpeg", ".png", ".webp"}
if extension in IMAGE_EXTENSIONS:
# Summarize or produce empty doc
image_data = file.read()
image_section, _ = _create_image_section(
llm, image_data, db_session, pg_record.file_name, title
)
return [
Document(
id=doc_id,
sections=[image_section],
source=source_type,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
metadata=metadata_tags,
)
]
# 2) Otherwise: text-based approach. Possibly with embedded images if enabled.
# (For example .docx with inline images).
file.seek(0)
text_content = ""
embedded_images: list[tuple[bytes, str]] = []
text_content, embedded_images = extract_text_and_images(
file=file,
file_name=file_name,
pdf_pass=pdf_pass,
)
# Build sections: first the text as a single Section
sections = []
link_in_meta = metadata.get("link")
if text_content.strip():
sections.append(Section(link=link_in_meta, text=text_content.strip()))
# Then any extracted images from docx, etc.
for idx, (img_data, img_name) in enumerate(embedded_images, start=1):
# Store each embedded image as a separate file in PGFileStore
# and create a section with the image summary
image_section, _ = _create_image_section(
llm,
img_data,
db_session,
pg_record.file_name,
f"{title} - image {idx}",
idx,
)
sections.append(image_section)
return [
Document(
id=doc_id,
sections=sections,
source=source_type,
sections=[
Section(link=all_metadata.get("link"), text=file_content_raw.strip())
],
source=source_type or DocumentSource.FILE,
semantic_identifier=file_display_name,
title=title,
doc_updated_at=final_time_updated,
primary_owners=p_owners,
secondary_owners=s_owners,
# currently metadata just houses tags, other stuff like owners / updated at have dedicated fields
metadata=metadata_tags,
)
]
class LocalFileConnector(LoadConnector, VisionEnabledConnector):
"""
Connector that reads files from Postgres and yields Documents, including
optional embedded image extraction.
"""
class LocalFileConnector(LoadConnector):
def __init__(
self,
file_locations: list[Path | str],
batch_size: int = INDEX_BATCH_SIZE,
) -> None:
self.file_locations = [str(loc) for loc in file_locations]
self.file_locations = [Path(file_location) for file_location in file_locations]
self.batch_size = batch_size
self.pdf_pass: str | None = None
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
def load_credentials(self, credentials: dict[str, Any]) -> dict[str, Any] | None:
self.pdf_pass = credentials.get("pdf_password")
return None
def load_from_state(self) -> GenerateDocumentsOutput:
"""
Iterates over each file path, fetches from Postgres, tries to parse text
or images, and yields Document batches.
"""
documents: list[Document] = []
with get_session_with_current_tenant() as db_session:
for file_path in self.file_locations:
current_datetime = datetime.now(timezone.utc)
files_iter = _read_files_and_metadata(
file_name=file_path,
db_session=db_session,
files = _read_files_and_metadata(
file_name=str(file_path), db_session=db_session
)
for actual_file_name, file, metadata in files_iter:
for file_name, file, metadata in files:
metadata["time_updated"] = metadata.get(
"time_updated", current_datetime
)
new_docs = _process_file(
file_name=actual_file_name,
file=file,
metadata=metadata,
pdf_pass=self.pdf_pass,
db_session=db_session,
llm=self.image_analysis_llm,
documents.extend(
_process_file(file_name, file, metadata, self.pdf_pass)
)
documents.extend(new_docs)
if len(documents) >= self.batch_size:
yield documents
documents = []
if documents:
@@ -301,7 +201,7 @@ class LocalFileConnector(LoadConnector, VisionEnabledConnector):
if __name__ == "__main__":
connector = LocalFileConnector(file_locations=[os.environ["TEST_FILE"]])
connector.load_credentials({"pdf_password": os.environ.get("PDF_PASSWORD")})
doc_batches = connector.load_from_state()
for batch in doc_batches:
print("BATCH:", batch)
connector.load_credentials({"pdf_password": os.environ["PDF_PASSWORD"]})
document_batches = connector.load_from_state()
print(next(document_batches))

View File

@@ -228,15 +228,10 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
content = self.client.get(f"/spaces/{self.space_id}/content")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -20,7 +20,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repositories: str | None = None,
repo_name: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repositories = repositories
self.repo_name = repo_name
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
@@ -157,42 +157,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logger.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
@@ -220,17 +189,11 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
for repo in repos:
if self.include_prs:
@@ -305,48 +268,11 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
@@ -358,7 +284,7 @@ class GithubConnector(LoadConnector, PollConnector):
user.get_repos().totalCount # Just check if we can access repos
except RateLimitExceededException:
raise UnexpectedValidationError(
raise UnexpectedError(
"Validation failed due to GitHub rate-limits being exceeded. Please try again later."
)
@@ -372,15 +298,10 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
@@ -389,7 +310,6 @@ class GithubConnector(LoadConnector, PollConnector):
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
@@ -401,7 +321,7 @@ if __name__ == "__main__":
connector = GithubConnector(
repo_owner=os.environ["REPO_OWNER"],
repositories=os.environ["REPOSITORIES"],
repo_name=os.environ["REPO_NAME"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}

View File

@@ -4,12 +4,14 @@ from concurrent.futures import as_completed
from concurrent.futures import ThreadPoolExecutor
from functools import partial
from typing import Any
from typing import cast
from google.oauth2.credentials import Credentials as OAuthCredentials # type: ignore
from google.oauth2.service_account import Credentials as ServiceAccountCredentials # type: ignore
from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import INDEX_BATCH_SIZE
from onyx.configs.app_configs import MAX_FILE_SIZE_BYTES
from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
@@ -34,6 +36,7 @@ from onyx.connectors.google_utils.shared_constants import (
)
from onyx.connectors.google_utils.shared_constants import MISSING_SCOPES_ERROR_STR
from onyx.connectors.google_utils.shared_constants import ONYX_SCOPE_INSTRUCTIONS
from onyx.connectors.google_utils.shared_constants import SCOPE_DOC_URL
from onyx.connectors.google_utils.shared_constants import SLIM_BATCH_SIZE
from onyx.connectors.google_utils.shared_constants import USER_FIELDS
from onyx.connectors.interfaces import GenerateDocumentsOutput
@@ -43,9 +46,7 @@ from onyx.connectors.interfaces import PollConnector
from onyx.connectors.interfaces import SecondsSinceUnixEpoch
from onyx.connectors.interfaces import SlimConnector
from onyx.connectors.models import ConnectorMissingCredentialError
from onyx.connectors.vision_enabled_connector import VisionEnabledConnector
from onyx.indexing.indexing_heartbeat import IndexingHeartbeatInterface
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
from onyx.utils.retry_wrapper import retry_builder
@@ -65,10 +66,7 @@ def _extract_ids_from_urls(urls: list[str]) -> list[str]:
def _convert_single_file(
creds: Any,
primary_admin_email: str,
file: dict[str, Any],
image_analysis_llm: LLM | None,
creds: Any, primary_admin_email: str, file: dict[str, Any]
) -> Any:
user_email = file.get("owners", [{}])[0].get("emailAddress") or primary_admin_email
user_drive_service = get_drive_service(creds, user_email=user_email)
@@ -77,14 +75,11 @@ def _convert_single_file(
file=file,
drive_service=user_drive_service,
docs_service=docs_service,
image_analysis_llm=image_analysis_llm, # pass the LLM so doc_conversion can summarize images
)
def _process_files_batch(
files: list[GoogleDriveFileType],
convert_func: Callable[[GoogleDriveFileType], Any],
batch_size: int,
files: list[GoogleDriveFileType], convert_func: Callable, batch_size: int
) -> GenerateDocumentsOutput:
doc_batch = []
with ThreadPoolExecutor(max_workers=min(16, len(files))) as executor:
@@ -116,9 +111,7 @@ def _clean_requested_drive_ids(
return valid_requested_drive_ids, filtered_folder_ids
class GoogleDriveConnector(
LoadConnector, PollConnector, SlimConnector, VisionEnabledConnector
):
class GoogleDriveConnector(LoadConnector, PollConnector, SlimConnector):
def __init__(
self,
include_shared_drives: bool = False,
@@ -136,23 +129,23 @@ class GoogleDriveConnector(
continue_on_failure: bool | None = None,
) -> None:
# Check for old input parameters
if folder_paths is not None:
logger.warning(
"The 'folder_paths' parameter is deprecated. Use 'shared_folder_urls' instead."
if (
folder_paths is not None
or include_shared is not None
or follow_shortcuts is not None
or only_org_public is not None
or continue_on_failure is not None
):
logger.exception(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if include_shared is not None:
logger.warning(
"The 'include_shared' parameter is deprecated. Use 'include_files_shared_with_me' instead."
raise ConnectorValidationError(
"Google Drive connector received old input parameters. "
"Please visit the docs for help with the new setup: "
f"{SCOPE_DOC_URL}"
)
if follow_shortcuts is not None:
logger.warning("The 'follow_shortcuts' parameter is deprecated.")
if only_org_public is not None:
logger.warning("The 'only_org_public' parameter is deprecated.")
if continue_on_failure is not None:
logger.warning("The 'continue_on_failure' parameter is deprecated.")
# Initialize vision LLM using the mixin
self.initialize_vision_llm()
if (
not include_shared_drives
@@ -244,7 +237,6 @@ class GoogleDriveConnector(
credentials=credentials,
source=DocumentSource.GOOGLE_DRIVE,
)
return new_creds_dict
def _update_traversed_parent_ids(self, folder_id: str) -> None:
@@ -316,9 +308,7 @@ class GoogleDriveConnector(
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
# default is ~17mins of retries, don't do that here for cases so we don't
# waste 17mins everytime we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
retry_builder()(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue
@@ -533,53 +523,37 @@ class GoogleDriveConnector(
end: SecondsSinceUnixEpoch | None = None,
) -> GenerateDocumentsOutput:
# Create a larger process pool for file conversion
with ThreadPoolExecutor(max_workers=8) as executor:
# Prepare a partial function with the credentials and admin email
convert_func = partial(
_convert_single_file,
self.creds,
self.primary_admin_email,
image_analysis_llm=self.image_analysis_llm, # Use the mixin's LLM
convert_func = partial(
_convert_single_file, self.creds, self.primary_admin_email
)
# Process files in larger batches
LARGE_BATCH_SIZE = self.batch_size * 4
files_to_process = []
# Gather the files into batches to be processed in parallel
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
if (
file.get("size")
and int(cast(str, file.get("size"))) > MAX_FILE_SIZE_BYTES
):
logger.warning(
f"Skipping file {file.get('name', 'Unknown')} as it is too large: {file.get('size')} bytes"
)
continue
files_to_process.append(file)
if len(files_to_process) >= LARGE_BATCH_SIZE:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
files_to_process = []
# Process any remaining files
if files_to_process:
yield from _process_files_batch(
files_to_process, convert_func, self.batch_size
)
# Fetch files in batches
files_batch: list[GoogleDriveFileType] = []
for file in self._fetch_drive_items(is_slim=False, start=start, end=end):
files_batch.append(file)
if len(files_batch) >= self.batch_size:
# Process the batch
futures = [
executor.submit(convert_func, file) for file in files_batch
]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
if documents:
yield documents
files_batch = []
# Process any remaining files
if files_batch:
futures = [executor.submit(convert_func, file) for file in files_batch]
documents = []
for future in as_completed(futures):
try:
doc = future.result()
if doc is not None:
documents.append(doc)
except Exception as e:
logger.error(f"Error converting file: {e}")
if documents:
yield documents
def load_from_state(self) -> GenerateDocumentsOutput:
try:
yield from self._extract_docs_from_google_drive()

View File

@@ -9,7 +9,7 @@ from googleapiclient.errors import HttpError # type: ignore
from onyx.configs.app_configs import CONTINUE_ON_CONNECTOR_FAILURE
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import IGNORE_FOR_QA
from onyx.connectors.google_drive.constants import DRIVE_FOLDER_TYPE
from onyx.connectors.google_drive.constants import DRIVE_SHORTCUT_TYPE
from onyx.connectors.google_drive.constants import UNSUPPORTED_FILE_TYPE_CONTENT
@@ -21,88 +21,32 @@ from onyx.connectors.google_utils.resources import GoogleDriveService
from onyx.connectors.models import Document
from onyx.connectors.models import Section
from onyx.connectors.models import SlimDocument
from onyx.db.engine import get_session_with_current_tenant
from onyx.file_processing.extract_file_text import docx_to_text_and_images
from onyx.file_processing.extract_file_text import docx_to_text
from onyx.file_processing.extract_file_text import pptx_to_text
from onyx.file_processing.extract_file_text import read_pdf_file
from onyx.file_processing.file_validation import is_valid_image_type
from onyx.file_processing.image_summarization import summarize_image_with_error_handling
from onyx.file_processing.image_utils import store_image_and_create_section
from onyx.file_processing.unstructured import get_unstructured_api_key
from onyx.file_processing.unstructured import unstructured_to_text
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
def _summarize_drive_image(
image_data: bytes, image_name: str, image_analysis_llm: LLM | None
) -> str:
"""
Summarize the given image using the provided LLM.
"""
if not image_analysis_llm:
return ""
return (
summarize_image_with_error_handling(
llm=image_analysis_llm,
image_data=image_data,
context_name=image_name,
)
or ""
)
def is_gdrive_image_mime_type(mime_type: str) -> bool:
"""
Return True if the mime_type is a common image type in GDrive.
(e.g. 'image/png', 'image/jpeg')
"""
return is_valid_image_type(mime_type)
# these errors don't represent a failure in the connector, but simply files
# that can't / shouldn't be indexed
ERRORS_TO_CONTINUE_ON = [
"cannotExportFile",
"exportSizeLimitExceeded",
"cannotDownloadFile",
]
def _extract_sections_basic(
file: dict[str, str],
service: GoogleDriveService,
image_analysis_llm: LLM | None = None,
file: dict[str, str], service: GoogleDriveService
) -> list[Section]:
"""
Extends the existing logic to handle either a docx with embedded images
or standalone images (PNG, JPG, etc).
"""
mime_type = file["mimeType"]
link = file["webViewLink"]
file_name = file.get("name", file["id"])
supported_file_types = set(item.value for item in GDriveMimeType)
# 1) If the file is an image, retrieve the raw bytes, optionally summarize
if is_gdrive_image_mime_type(mime_type):
try:
response = service.files().get_media(fileId=file["id"]).execute()
with get_session_with_current_tenant() as db_session:
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=response,
file_name=file["id"],
display_name=file_name,
media_type=mime_type,
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
return [section]
except Exception as e:
logger.warning(f"Failed to fetch or summarize image: {e}")
return [
Section(
link=link,
text="",
image_file_name=link,
)
]
if mime_type not in supported_file_types:
# Unsupported file types can still have a title, finding this way is still useful
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
@@ -241,63 +185,45 @@ def _extract_sections_basic(
GDriveMimeType.PLAIN_TEXT.value,
GDriveMimeType.MARKDOWN.value,
]:
text_data = (
service.files().get_media(fileId=file["id"]).execute().decode("utf-8")
)
return [Section(link=link, text=text_data)]
return [
Section(
link=link,
text=service.files()
.get_media(fileId=file["id"])
.execute()
.decode("utf-8"),
)
]
# ---------------------------
# Word, PowerPoint, PDF files
elif mime_type in [
if mime_type in [
GDriveMimeType.WORD_DOC.value,
GDriveMimeType.POWERPOINT.value,
GDriveMimeType.PDF.value,
]:
response_bytes = service.files().get_media(fileId=file["id"]).execute()
# Optionally use Unstructured
response = service.files().get_media(fileId=file["id"]).execute()
if get_unstructured_api_key():
text = unstructured_to_text(
file=io.BytesIO(response_bytes),
file_name=file_name,
)
return [Section(link=link, text=text)]
return [
Section(
link=link,
text=unstructured_to_text(
file=io.BytesIO(response),
file_name=file.get("name", file["id"]),
),
)
]
if mime_type == GDriveMimeType.WORD_DOC.value:
# Use docx_to_text_and_images to get text plus embedded images
text, embedded_images = docx_to_text_and_images(
file=io.BytesIO(response_bytes),
)
sections = []
if text.strip():
sections.append(Section(link=link, text=text.strip()))
# Process each embedded image using the standardized function
with get_session_with_current_tenant() as db_session:
for idx, (img_data, img_name) in enumerate(
embedded_images, start=1
):
# Create a unique identifier for the embedded image
embedded_id = f"{file['id']}_embedded_{idx}"
section, _ = store_image_and_create_section(
db_session=db_session,
image_data=img_data,
file_name=embedded_id,
display_name=img_name or f"{file_name} - image {idx}",
llm=image_analysis_llm,
file_origin=FileOrigin.CONNECTOR,
)
sections.append(section)
return sections
return [
Section(link=link, text=docx_to_text(file=io.BytesIO(response)))
]
elif mime_type == GDriveMimeType.PDF.value:
text, _pdf_meta, images = read_pdf_file(io.BytesIO(response_bytes))
text, _ = read_pdf_file(file=io.BytesIO(response))
return [Section(link=link, text=text)]
elif mime_type == GDriveMimeType.POWERPOINT.value:
text_data = pptx_to_text(io.BytesIO(response_bytes))
return [Section(link=link, text=text_data)]
return [
Section(link=link, text=pptx_to_text(file=io.BytesIO(response)))
]
# Catch-all case, should not happen since there should be specific handling
# for each of the supported file types
@@ -305,8 +231,7 @@ def _extract_sections_basic(
logger.error(error_message)
raise ValueError(error_message)
except Exception as e:
logger.exception(f"Error extracting sections from file: {e}")
except Exception:
return [Section(link=link, text=UNSUPPORTED_FILE_TYPE_CONTENT)]
@@ -314,62 +239,74 @@ def convert_drive_item_to_document(
file: GoogleDriveFileType,
drive_service: GoogleDriveService,
docs_service: GoogleDocsService,
image_analysis_llm: LLM | None,
) -> Document | None:
"""
Main entry point for converting a Google Drive file => Document object.
Now we accept an optional `llm` to pass to `_extract_sections_basic`.
"""
try:
# skip shortcuts or folders
if file.get("mimeType") in [DRIVE_SHORTCUT_TYPE, DRIVE_FOLDER_TYPE]:
logger.info("Skipping shortcut/folder.")
# Skip files that are shortcuts
if file.get("mimeType") == DRIVE_SHORTCUT_TYPE:
logger.info("Ignoring Drive Shortcut Filetype")
return None
# Skip files that are folders
if file.get("mimeType") == DRIVE_FOLDER_TYPE:
logger.info("Ignoring Drive Folder Filetype")
return None
# If it's a Google Doc, we might do advanced parsing
sections: list[Section] = []
# Special handling for Google Docs to preserve structure, link
# to headers
if file.get("mimeType") == GDriveMimeType.DOC.value:
try:
# get_document_sections is the advanced approach for Google Docs
sections = get_document_sections(docs_service, file["id"])
except Exception as e:
logger.warning(
f"Failed to pull google doc sections from '{file['name']}': {e}. "
"Falling back to basic extraction."
f"Ran into exception '{e}' when pulling sections from Google Doc '{file['name']}'."
" Falling back to basic extraction."
)
# If not a doc, or if we failed above, do our 'basic' approach
# NOTE: this will run for either (1) the above failed or (2) the file is not a Google Doc
if not sections:
sections = _extract_sections_basic(file, drive_service, image_analysis_llm)
try:
# For all other file types just extract the text
sections = _extract_sections_basic(file, drive_service)
except HttpError as e:
reason = e.error_details[0]["reason"] if e.error_details else e.reason
message = e.error_details[0]["message"] if e.error_details else e.reason
if e.status_code == 403 and reason in ERRORS_TO_CONTINUE_ON:
logger.warning(
f"Could not export file '{file['name']}' due to '{message}', skipping..."
)
return None
raise
if not sections:
return None
doc_id = file["webViewLink"]
updated_time = datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
)
return Document(
id=doc_id,
id=file["webViewLink"],
sections=sections,
source=DocumentSource.GOOGLE_DRIVE,
semantic_identifier=file["name"],
doc_updated_at=updated_time,
metadata={}, # or any metadata from 'file'
doc_updated_at=datetime.fromisoformat(file["modifiedTime"]).astimezone(
timezone.utc
),
metadata={}
if any(section.text for section in sections)
else {IGNORE_FOR_QA: "True"},
additional_info=file.get("id"),
)
except Exception as e:
logger.exception(f"Error converting file '{file.get('name')}' to Document: {e}")
if not CONTINUE_ON_CONNECTOR_FAILURE:
raise
raise e
logger.exception("Ran into exception when pulling a file from Google Drive")
return None
def build_slim_document(file: GoogleDriveFileType) -> SlimDocument | None:
# Skip files that are folders or shortcuts
if file.get("mimeType") in [DRIVE_FOLDER_TYPE, DRIVE_SHORTCUT_TYPE]:
return None
return SlimDocument(
id=file["webViewLink"],
perm_sync_data={

View File

@@ -1,4 +1,3 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -29,8 +28,7 @@ class ConnectorMissingCredentialError(PermissionError):
class Section(BaseModel):
text: str
link: str | None = None
image_file_name: str | None = None
link: str | None
class BasicExpertInfo(BaseModel):
@@ -205,15 +203,6 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,3 +1,4 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -18,7 +19,7 @@ from onyx.connectors.cross_connector_utils.rate_limit_wrapper import (
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -31,7 +32,6 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": _NOTION_PAGE_SIZE,
"page_size": self.batch_size,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}
@@ -671,12 +671,12 @@ class NotionConnector(LoadConnector, PollConnector):
"Please try again later."
)
else:
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected Notion HTTP error (status={status_code}): {http_err}"
) from http_err
except Exception as exc:
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected error during Notion settings validation: {exc}"
)

View File

@@ -21,7 +21,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import CheckpointConnector
from onyx.connectors.interfaces import CheckpointOutput
from onyx.connectors.interfaces import GenerateSlimDocumentOutput
@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
3. Check that every channel specified in self.channels exists.
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -702,12 +702,10 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({error_msg})."
)
raise UnexpectedValidationError(
f"Slack API returned a failure: {error_msg}"
)
raise UnexpectedError(f"Slack API returned a failure: {error_msg}")
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
# 3) If channels are specified, verify each is accessible
if self.channels:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,
@@ -742,13 +740,13 @@ class SlackConnector(SlimConnector, CheckpointConnector):
raise CredentialExpiredError(
f"Invalid or expired Slack bot token ({slack_error})."
)
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected Slack error '{slack_error}' during settings validation."
)
except ConnectorValidationError as e:
raise e
except Exception as e:
raise UnexpectedValidationError(
raise UnexpectedError(
f"Unexpected error during Slack settings validation: {e}"
)

View File

@@ -72,7 +72,6 @@ def make_slack_api_rate_limited(
@wraps(call)
def rate_limited_call(**kwargs: Any) -> SlackResponse:
last_exception = None
for _ in range(max_retries):
try:
# Make the API call

View File

@@ -16,7 +16,7 @@ from onyx.connectors.cross_connector_utils.miscellaneous_utils import time_str_t
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.interfaces import PollConnector
@@ -302,7 +302,7 @@ class TeamsConnector(LoadConnector, PollConnector):
raise InsufficientPermissionsError(
"Your app lacks sufficient permissions to read Teams (403 Forbidden)."
)
raise UnexpectedValidationError(f"Unexpected error retrieving teams: {e}")
raise UnexpectedError(f"Unexpected error retrieving teams: {e}")
except Exception as e:
error_str = str(e).lower()

View File

@@ -1,45 +0,0 @@
"""
Mixin for connectors that need vision capabilities.
"""
from onyx.configs.llm_configs import get_image_extraction_and_analysis_enabled
from onyx.llm.factory import get_default_llm_with_vision
from onyx.llm.interfaces import LLM
from onyx.utils.logger import setup_logger
logger = setup_logger()
class VisionEnabledConnector:
"""
Mixin for connectors that need vision capabilities.
This mixin provides a standard way to initialize a vision-capable LLM
for image analysis during indexing.
Usage:
class MyConnector(LoadConnector, VisionEnabledConnector):
def __init__(self, ...):
super().__init__(...)
self.initialize_vision_llm()
"""
def initialize_vision_llm(self) -> None:
"""
Initialize a vision-capable LLM if enabled by configuration.
Sets self.image_analysis_llm to the LLM instance or None if disabled.
"""
self.image_analysis_llm: LLM | None = None
if get_image_extraction_and_analysis_enabled():
try:
self.image_analysis_llm = get_default_llm_with_vision()
if self.image_analysis_llm is None:
logger.warning(
"No LLM with vision found; image summarization will be disabled"
)
except Exception as e:
logger.warning(
f"Failed to initialize vision LLM due to an error: {str(e)}. "
"Image summarization will be disabled."
)
self.image_analysis_llm = None

View File

@@ -28,7 +28,7 @@ from onyx.configs.constants import DocumentSource
from onyx.connectors.exceptions import ConnectorValidationError
from onyx.connectors.exceptions import CredentialExpiredError
from onyx.connectors.exceptions import InsufficientPermissionsError
from onyx.connectors.exceptions import UnexpectedValidationError
from onyx.connectors.exceptions import UnexpectedError
from onyx.connectors.interfaces import GenerateDocumentsOutput
from onyx.connectors.interfaces import LoadConnector
from onyx.connectors.models import Document
@@ -42,10 +42,6 @@ from shared_configs.configs import MULTI_TENANT
logger = setup_logger()
WEB_CONNECTOR_MAX_SCROLL_ATTEMPTS = 20
# Threshold for determining when to replace vs append iframe content
IFRAME_TEXT_LENGTH_THRESHOLD = 700
# Message indicating JavaScript is disabled, which often appears when scraping fails
JAVASCRIPT_DISABLED_MESSAGE = "You have JavaScript disabled in your browser"
class WEB_CONNECTOR_VALID_SETTINGS(str, Enum):
@@ -142,8 +138,7 @@ def get_internal_links(
# Account for malformed backslashes in URLs
href = href.replace("\\", "/")
# "#!" indicates the page is using a hashbang URL, which is a client-side routing technique
if should_ignore_pound and "#" in href and "#!" not in href:
if should_ignore_pound and "#" in href:
href = href.split("#")[0]
if not is_valid_url(href):
@@ -157,7 +152,6 @@ def get_internal_links(
def start_playwright() -> Tuple[Playwright, BrowserContext]:
playwright = sync_playwright().start()
browser = playwright.chromium.launch(headless=True)
context = browser.new_context()
@@ -294,7 +288,6 @@ class WebConnector(LoadConnector):
and converts them into documents"""
visited_links: set[str] = set()
to_visit: list[str] = self.to_visit_list
content_hashes = set()
if not to_visit:
raise ValueError("No URLs to visit")
@@ -321,8 +314,7 @@ class WebConnector(LoadConnector):
logger.warning(last_error)
continue
index = len(visited_links)
logger.info(f"{index}: Visiting {initial_url}")
logger.info(f"{len(visited_links)}: Visiting {initial_url}")
try:
check_internet_connection(initial_url)
@@ -333,7 +325,7 @@ class WebConnector(LoadConnector):
if initial_url.split(".")[-1] == "pdf":
# PDF files are not checked for links
response = requests.get(initial_url)
page_text, metadata, images = read_pdf_file(
page_text, metadata = read_pdf_file(
file=io.BytesIO(response.content)
)
last_modified = response.headers.get("Last-Modified")
@@ -355,13 +347,7 @@ class WebConnector(LoadConnector):
continue
page = context.new_page()
# Can't use wait_until="networkidle" because it interferes with the scrolling behavior
page_response = page.goto(
initial_url,
timeout=30000, # 30 seconds
)
page_response = page.goto(initial_url)
last_modified = (
page_response.header_value("Last-Modified")
if page_response
@@ -373,10 +359,12 @@ class WebConnector(LoadConnector):
initial_url = final_url
if initial_url in visited_links:
logger.info(
f"{index}: {initial_url} redirected to {final_url} - already indexed"
f"{len(visited_links)}: {initial_url} redirected to {final_url} - already indexed"
)
continue
logger.info(f"{index}: {initial_url} redirected to {final_url}")
logger.info(
f"{len(visited_links)}: {initial_url} redirected to {final_url}"
)
visited_links.add(initial_url)
if self.scroll_before_scraping:
@@ -407,38 +395,6 @@ class WebConnector(LoadConnector):
parsed_html = web_html_cleanup(soup, self.mintlify_cleanup)
"""For websites containing iframes that need to be scraped,
the code below can extract text from within these iframes.
"""
logger.debug(
f"{index}: Length of cleaned text {len(parsed_html.cleaned_text)}"
)
if JAVASCRIPT_DISABLED_MESSAGE in parsed_html.cleaned_text:
iframe_count = page.frame_locator("iframe").locator("html").count()
if iframe_count > 0:
iframe_texts = (
page.frame_locator("iframe")
.locator("html")
.all_inner_texts()
)
document_text = "\n".join(iframe_texts)
""" 700 is the threshold value for the length of the text extracted
from the iframe based on the issue faced """
if len(parsed_html.cleaned_text) < IFRAME_TEXT_LENGTH_THRESHOLD:
parsed_html.cleaned_text = document_text
else:
parsed_html.cleaned_text += "\n" + document_text
# Sometimes pages with #! will serve duplicate content
# There are also just other ways this can happen
hashed_text = hash((parsed_html.title, parsed_html.cleaned_text))
if hashed_text in content_hashes:
logger.info(
f"{index}: Skipping duplicate title + content for {initial_url}"
)
continue
content_hashes.add(hashed_text)
doc_batch.append(
Document(
id=initial_url,
@@ -529,9 +485,7 @@ class WebConnector(LoadConnector):
)
else:
# Could be a 5xx or another error, treat as unexpected
raise UnexpectedValidationError(
f"Unexpected error validating '{test_url}': {e}"
)
raise UnexpectedError(f"Unexpected error validating '{test_url}': {e}")
if __name__ == "__main__":

View File

@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@@ -76,10 +76,6 @@ class SavedSearchSettings(InferenceSettings, IndexingSetting):
provider_type=search_settings.provider_type,
index_name=search_settings.index_name,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
# Whether switching to this model requires re-indexing
background_reindex_enabled=search_settings.background_reindex_enabled,
# Reranking Details
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
@@ -151,10 +147,6 @@ class SearchRequest(ChunkContext):
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
model_config = ConfigDict(arbitrary_types_allowed=True)
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
@@ -179,8 +171,6 @@ class SearchQuery(ChunkContext):
offset: int = 0
model_config = ConfigDict(frozen=True)
precomputed_query_embedding: Embedding | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@@ -331,14 +331,6 @@ class SearchPipeline:
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def retrieved_sections(self) -> list[InferenceSection]:
if self._retrieved_sections is not None:
return self._retrieved_sections
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -351,7 +343,7 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self.retrieved_sections
retrieved_sections = self._get_sections()
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)

View File

@@ -1,17 +1,12 @@
import base64
from collections.abc import Callable
from collections.abc import Iterator
from typing import cast
import numpy
from langchain_core.messages import BaseMessage
from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from onyx.chat.models import SectionRelevancePiece
from onyx.configs.app_configs import BLURB_SIZE
from onyx.configs.constants import RETURN_SEPARATOR
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from onyx.context.search.enums import LLMEvaluationType
@@ -23,15 +18,11 @@ from onyx.context.search.models import MAX_METRICS_CONTENT
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import RerankMetricsContainer
from onyx.context.search.models import SearchQuery
from onyx.db.engine import get_session_with_current_tenant
from onyx.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from onyx.file_store.file_store import get_default_file_store
from onyx.llm.interfaces import LLM
from onyx.llm.utils import message_to_string
from onyx.natural_language_processing.search_nlp_models import RerankingModel
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import FunctionCall
@@ -39,124 +30,6 @@ from onyx.utils.threadpool_concurrency import run_functions_in_parallel
from onyx.utils.timing import log_function_time
def update_image_sections_with_query(
sections: list[InferenceSection],
query: str,
llm: LLM,
) -> None:
"""
For each chunk in each section that has an image URL, call an LLM to produce
a new 'content' string that directly addresses the user's query about that image.
This implementation uses parallel processing for efficiency.
"""
logger = setup_logger()
logger.debug(f"Starting image section update with query: {query}")
chunks_with_images = []
for section in sections:
for chunk in section.chunks:
if chunk.image_file_name:
chunks_with_images.append(chunk)
if not chunks_with_images:
logger.debug("No images to process in the sections")
return # No images to process
logger.info(f"Found {len(chunks_with_images)} chunks with images to process")
def process_image_chunk(chunk: InferenceChunk) -> tuple[str, str]:
try:
logger.debug(
f"Processing image chunk with ID: {chunk.unique_id}, image: {chunk.image_file_name}"
)
with get_session_with_current_tenant() as db_session:
file_record = get_default_file_store(db_session).read_file(
cast(str, chunk.image_file_name), mode="b"
)
if not file_record:
logger.error(f"Image file not found: {chunk.image_file_name}")
raise Exception("File not found")
file_content = file_record.read()
image_base64 = base64.b64encode(file_content).decode()
logger.debug(
f"Successfully loaded image data for {chunk.image_file_name}"
)
messages: list[BaseMessage] = [
SystemMessage(content=IMAGE_ANALYSIS_SYSTEM_PROMPT),
HumanMessage(
content=[
{
"type": "text",
"text": (
f"The user's question is: '{query}'. "
"Please analyze the following image in that context:\n"
),
},
{
"type": "image_url",
"image_url": {
"url": f"data:image/jpeg;base64,{image_base64}",
},
},
]
),
]
raw_response = llm.invoke(messages)
answer_text = message_to_string(raw_response).strip()
return (
chunk.unique_id,
answer_text if answer_text else "No relevant info found.",
)
except Exception:
logger.exception(
f"Error updating image section with query source image url: {chunk.image_file_name}"
)
return chunk.unique_id, "Error analyzing image."
image_processing_tasks = [
FunctionCall(process_image_chunk, (chunk,)) for chunk in chunks_with_images
]
logger.info(
f"Starting parallel processing of {len(image_processing_tasks)} image tasks"
)
image_processing_results = run_functions_in_parallel(image_processing_tasks)
logger.info(
f"Completed parallel processing with {len(image_processing_results)} results"
)
# Create a mapping of chunk IDs to their processed content
chunk_id_to_content = {}
success_count = 0
for task_id, result in image_processing_results.items():
if result:
chunk_id, content = result
chunk_id_to_content[chunk_id] = content
success_count += 1
else:
logger.error(f"Task {task_id} failed to return a valid result")
logger.info(
f"Successfully processed {success_count}/{len(image_processing_results)} images"
)
# Update the chunks with the processed content
updated_count = 0
for section in sections:
for chunk in section.chunks:
if chunk.unique_id in chunk_id_to_content:
chunk.content = chunk_id_to_content[chunk.unique_id]
updated_count += 1
logger.info(
f"Updated content for {updated_count} chunks with image analysis results"
)
logger = setup_logger()
@@ -413,10 +286,6 @@ def search_postprocessing(
# NOTE: if we don't rerank, we can return the chunks immediately
# since we know this is the final order.
# This way the user experience isn't delayed by the LLM step
if get_search_time_image_analysis_enabled():
update_image_sections_with_query(
retrieved_sections, search_query.query, llm
)
_log_top_section_links(search_query.search_type.value, retrieved_sections)
yield retrieved_sections
sections_yielded = True
@@ -454,13 +323,6 @@ def search_postprocessing(
)
else:
_log_top_section_links(search_query.search_type.value, reranked_sections)
# Add the image processing step here
if get_search_time_image_analysis_enabled():
update_image_sections_with_query(
reranked_sections, search_query.query, llm
)
yield reranked_sections
llm_selected_section_ids = (

View File

@@ -117,12 +117,8 @@ def retrieval_preprocessing(
else None
)
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
# latency, and in that case we don't need to run the model again
run_query_analysis = (
None
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
else FunctionCall(query_analysis, (query,), {})
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
)
functions_to_run = [
@@ -147,12 +143,11 @@ def retrieval_preprocessing(
# The extracted keywords right now are not very reliable, not using for now
# Can maybe use for highlighting
is_keyword, _extracted_keywords = False, None
if search_request.precomputed_is_keyword is not None:
is_keyword = search_request.precomputed_is_keyword
_extracted_keywords = search_request.precomputed_keywords
elif run_query_analysis:
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
)
all_query_terms = query.split()
processed_keywords = (
@@ -252,5 +247,4 @@ def retrieval_preprocessing(
chunks_above=chunks_above,
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
)

View File

@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -109,20 +109,6 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -135,10 +121,17 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
query_embedding = query.precomputed_query_embedding or get_query_embedding(
query.query, db_session
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
@@ -256,16 +249,7 @@ def retrieve_chunks(
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.model_copy(
update={
"query": rephrase,
# need to recompute for each rephrase
# note that `SearchQuery` is a frozen model, so we can't update
# it below
"precomputed_query_embedding": None,
},
deep=True,
)
q_copy = query.copy(update={"query": rephrase}, deep=True)
run_queries.append(
(
doc_index_retrieval,

View File

@@ -3,7 +3,6 @@ from datetime import datetime
from datetime import timedelta
from typing import Any
from typing import cast
from typing import Tuple
from uuid import UUID
from fastapi import HTTPException
@@ -12,7 +11,6 @@ from sqlalchemy import desc
from sqlalchemy import func
from sqlalchemy import nullsfirst
from sqlalchemy import or_
from sqlalchemy import Row
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import MultipleResultsFound
@@ -377,33 +375,24 @@ def delete_chat_session(
db_session.commit()
def get_chat_sessions_older_than(
days_old: int, db_session: Session
) -> list[tuple[UUID | None, UUID]]:
"""
Retrieves chat sessions older than a specified number of days.
Args:
days_old: The number of days to consider as "old".
db_session: The database session.
Returns:
A list of tuples, where each tuple contains the user_id (can be None) and the chat_session_id of an old chat session.
"""
def delete_chat_sessions_older_than(days_old: int, db_session: Session) -> None:
cutoff_time = datetime.utcnow() - timedelta(days=days_old)
old_sessions: Sequence[Row[Tuple[UUID | None, UUID]]] = db_session.execute(
old_sessions = db_session.execute(
select(ChatSession.user_id, ChatSession.id).where(
ChatSession.time_created < cutoff_time
)
).fetchall()
# convert old_sessions to a conventional list of tuples
returned_sessions: list[tuple[UUID | None, UUID]] = [
(user_id, session_id) for user_id, session_id in old_sessions
]
return returned_sessions
for user_id, session_id in old_sessions:
try:
delete_chat_session(
user_id, session_id, db_session, include_deleted=True, hard_delete=True
)
except Exception:
logger.exception(
"delete_chat_session exceptioned. "
f"user_id={user_id} session_id={session_id}"
)
def get_chat_message(

View File

@@ -63,9 +63,6 @@ class IndexModelStatus(str, PyEnum):
PRESENT = "PRESENT"
FUTURE = "FUTURE"
def is_current(self) -> bool:
return self == IndexModelStatus.PRESENT
class ChatSessionSharedStatus(str, PyEnum):
PUBLIC = "public"
@@ -86,11 +83,3 @@ class AccessType(str, PyEnum):
PUBLIC = "public"
PRIVATE = "private"
SYNC = "sync"
class EmbeddingPrecision(str, PyEnum):
# matches vespa tensor type
# only support float / bfloat16 for now, since there's not a
# good reason to specify anything else
BFLOAT16 = "bfloat16"
FLOAT = "float"

View File

@@ -46,13 +46,7 @@ from onyx.configs.constants import DEFAULT_BOOST, MilestoneRecordType
from onyx.configs.constants import DocumentSource
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MessageType
from onyx.db.enums import (
AccessType,
EmbeddingPrecision,
IndexingMode,
SyncType,
SyncStatus,
)
from onyx.db.enums import AccessType, IndexingMode, SyncType, SyncStatus
from onyx.configs.constants import NotificationType
from onyx.configs.constants import SearchFeedbackType
from onyx.configs.constants import TokenRateLimitScope
@@ -722,23 +716,6 @@ class SearchSettings(Base):
ForeignKey("embedding_provider.provider_type"), nullable=True
)
# Whether switching to this model should re-index all connectors in the background
# if no re-index is needed, will be ignored. Only used during the switch-over process.
background_reindex_enabled: Mapped[bool] = mapped_column(Boolean, default=True)
# allows for quantization -> less memory usage for a small performance hit
embedding_precision: Mapped[EmbeddingPrecision] = mapped_column(
Enum(EmbeddingPrecision, native_enum=False)
)
# can be used to reduce dimensionality of vectors and save memory with
# a small performance hit. More details in the `Reducing embedding dimensions`
# section here:
# https://platform.openai.com/docs/guides/embeddings#embedding-models
# If not specified, will just use the model_dim without any reduction.
# NOTE: this is only currently available for OpenAI models
reduced_dimension: Mapped[int | None] = mapped_column(Integer, nullable=True)
# Mini and Large Chunks (large chunk also checks for model max context)
multipass_indexing: Mapped[bool] = mapped_column(Boolean, default=True)
@@ -820,12 +797,6 @@ class SearchSettings(Base):
self.multipass_indexing, self.model_name, self.provider_type
)
@property
def final_embedding_dim(self) -> int:
if self.reduced_dimension:
return self.reduced_dimension
return self.model_dim
@staticmethod
def can_use_large_chunks(
multipass: bool, model_name: str, provider_type: EmbeddingProvider | None
@@ -1790,7 +1761,6 @@ class ChannelConfig(TypedDict):
channel_name: str | None # None for default channel config
respond_tag_only: NotRequired[bool] # defaults to False
respond_to_bots: NotRequired[bool] # defaults to False
is_ephemeral: NotRequired[bool] # defaults to False
respond_member_group_list: NotRequired[list[str]]
answer_filters: NotRequired[list[AllowedAnswerFilters]]
# If None then no follow up
@@ -2295,14 +2265,15 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = ({"schema": "public"},)
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
@validates("email")
def validate_email(self, key: str, value: str) -> str:

View File

@@ -209,21 +209,13 @@ def create_update_persona(
if not all_prompt_ids:
raise ValueError("No prompt IDs provided")
is_default_persona: bool | None = create_persona_request.is_default_persona
# Default persona validation
if create_persona_request.is_default_persona:
if not create_persona_request.is_public:
raise ValueError("Cannot make a default persona non public")
if user:
# Curators can edit default personas, but not make them
if (
user.role == UserRole.CURATOR
or user.role == UserRole.GLOBAL_CURATOR
):
is_default_persona = None
elif user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
if user and user.role != UserRole.ADMIN:
raise ValueError("Only admins can make a default persona")
persona = upsert_persona(
persona_id=persona_id,
@@ -249,7 +241,7 @@ def create_update_persona(
num_chunks=create_persona_request.num_chunks,
llm_relevance_filter=create_persona_request.llm_relevance_filter,
llm_filter_extraction=create_persona_request.llm_filter_extraction,
is_default_persona=is_default_persona,
is_default_persona=create_persona_request.is_default_persona,
)
versioned_make_persona_private = fetch_versioned_implementation(
@@ -436,7 +428,7 @@ def upsert_persona(
remove_image: bool | None = None,
search_start_date: datetime | None = None,
builtin_persona: bool = False,
is_default_persona: bool | None = None,
is_default_persona: bool = False,
label_ids: list[int] | None = None,
chunks_above: int = CONTEXT_CHUNKS_ABOVE,
chunks_below: int = CONTEXT_CHUNKS_BELOW,
@@ -531,11 +523,7 @@ def upsert_persona(
existing_persona.is_visible = is_visible
existing_persona.search_start_date = search_start_date
existing_persona.labels = labels or []
existing_persona.is_default_persona = (
is_default_persona
if is_default_persona is not None
else existing_persona.is_default_persona
)
existing_persona.is_default_persona = is_default_persona
# Do not delete any associations manually added unless
# a new updated list is provided
if document_sets is not None:
@@ -587,9 +575,7 @@ def upsert_persona(
display_priority=display_priority,
is_visible=is_visible,
search_start_date=search_start_date,
is_default_persona=is_default_persona
if is_default_persona is not None
else False,
is_default_persona=is_default_persona,
labels=labels or [],
)
db_session.add(new_persona)

View File

@@ -148,28 +148,3 @@ def upsert_pgfilestore(
db_session.commit()
return pgfilestore
def save_bytes_to_pgfilestore(
db_session: Session,
raw_bytes: bytes,
media_type: str,
identifier: str,
display_name: str,
file_origin: FileOrigin = FileOrigin.OTHER,
) -> PGFileStore:
"""
Saves raw bytes to PGFileStore and returns the resulting record.
"""
file_name = f"{file_origin.name.lower()}_{identifier}"
lobj_oid = create_populate_lobj(BytesIO(raw_bytes), db_session)
pgfilestore = upsert_pgfilestore(
file_name=file_name,
display_name=display_name,
file_origin=file_origin,
file_type=media_type,
lobj_oid=lobj_oid,
db_session=db_session,
commit=True,
)
return pgfilestore

View File

@@ -14,7 +14,6 @@ from onyx.configs.model_configs import OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
from onyx.configs.model_configs import OLD_DEFAULT_MODEL_NORMALIZE_EMBEDDINGS
from onyx.context.search.models import SavedSearchSettings
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.enums import EmbeddingPrecision
from onyx.db.llm import fetch_embedding_provider
from onyx.db.models import CloudEmbeddingProvider
from onyx.db.models import IndexAttempt
@@ -60,15 +59,12 @@ def create_search_settings(
index_name=search_settings.index_name,
provider_type=search_settings.provider_type,
multipass_indexing=search_settings.multipass_indexing,
embedding_precision=search_settings.embedding_precision,
reduced_dimension=search_settings.reduced_dimension,
multilingual_expansion=search_settings.multilingual_expansion,
disable_rerank_for_streaming=search_settings.disable_rerank_for_streaming,
rerank_model_name=search_settings.rerank_model_name,
rerank_provider_type=search_settings.rerank_provider_type,
rerank_api_key=search_settings.rerank_api_key,
num_rerank=search_settings.num_rerank,
background_reindex_enabled=search_settings.background_reindex_enabled,
)
db_session.add(embedding_model)
@@ -309,7 +305,6 @@ def get_old_default_embedding_model() -> IndexingSetting:
model_dim=(
DOC_EMBEDDING_DIM if is_overridden else OLD_DEFAULT_MODEL_DOC_EMBEDDING_DIM
),
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=(
NORMALIZE_EMBEDDINGS
if is_overridden
@@ -327,7 +322,6 @@ def get_new_default_embedding_model() -> IndexingSetting:
return IndexingSetting(
model_name=DOCUMENT_ENCODER_MODEL,
model_dim=DOC_EMBEDDING_DIM,
embedding_precision=(EmbeddingPrecision.FLOAT),
normalize=NORMALIZE_EMBEDDINGS,
query_prefix=ASYM_QUERY_PREFIX,
passage_prefix=ASYM_PASSAGE_PREFIX,

View File

@@ -1,79 +0,0 @@
import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
num_sessions: the number of sessions to seed
num_messages: the number of messages to seed per sessions
days: the number of days looking backwards from the current time over which to randomize
the times.
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")
rows = db_session.query(ChatSession).all()
for x in range(0, len(rows)):
if x % 1024 == 0:
logger.info(f"Seeded messages for {x} sessions so far.")
row = rows[x]
row.time_created = datetime.utcnow() - timedelta(
days=random.randint(0, days)
)
row.time_updated = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
root_message = get_or_create_root_message(row.id, db_session)
current_message_type = MessageType.USER
parent_message = root_message
for x in range(0, num_messages):
if current_message_type == MessageType.USER:
msg = f"pytest_message_user_{x}"
else:
msg = f"pytest_message_assistant_{x}"
chat_message = create_new_chat_message(
row.id,
parent_message,
msg,
None,
0,
current_message_type,
db_session,
)
chat_message.time_sent = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
db_session.commit()
current_message_type = (
MessageType.ASSISTANT
if current_message_type == MessageType.USER
else MessageType.USER
)
parent_message = chat_message
db_session.commit()
logger.info(f"Seeded messages for {len(rows)} sessions. Finished.")

View File

@@ -8,12 +8,10 @@ from onyx.db.index_attempt import cancel_indexing_attempts_past_model
from onyx.db.index_attempt import (
count_unique_cc_pairs_with_successful_index_attempts,
)
from onyx.db.models import ConnectorCredentialPair
from onyx.db.models import SearchSettings
from onyx.db.search_settings import get_current_search_settings
from onyx.db.search_settings import get_secondary_search_settings
from onyx.db.search_settings import update_search_settings_status
from onyx.document_index.factory import get_default_document_index
from onyx.key_value_store.factory import get_kv_store
from onyx.utils.logger import setup_logger
@@ -21,49 +19,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
def _perform_index_swap(
db_session: Session,
current_search_settings: SearchSettings,
secondary_search_settings: SearchSettings,
all_cc_pairs: list[ConnectorCredentialPair],
) -> None:
"""Swap the indices and expire the old one."""
current_search_settings = get_current_search_settings(db_session)
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
)
update_search_settings_status(
search_settings=secondary_search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if len(all_cc_pairs) > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
# remove the old index from the vector db
document_index = get_default_document_index(secondary_search_settings, None)
document_index.ensure_indices_exist(
primary_embedding_dim=secondary_search_settings.final_embedding_dim,
primary_embedding_precision=secondary_search_settings.embedding_precision,
# just finished swap, no more secondary index
secondary_index_embedding_dim=None,
secondary_index_embedding_precision=None,
)
def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
def check_index_swap(db_session: Session) -> SearchSettings | None:
"""Get count of cc-pairs and count of successful index_attempts for the
new model grouped by connector + credential, if it's the same, then assume
new index is done building. If so, swap the indices and expire the old one.
@@ -71,45 +27,52 @@ def check_and_perform_index_swap(db_session: Session) -> SearchSettings | None:
Returns None if search settings did not change, or the old search settings if they
did change.
"""
old_search_settings = None
# Default CC-pair created for Ingestion API unused here
all_cc_pairs = get_connector_credential_pairs(db_session)
cc_pair_count = max(len(all_cc_pairs) - 1, 0)
secondary_search_settings = get_secondary_search_settings(db_session)
search_settings = get_secondary_search_settings(db_session)
if not secondary_search_settings:
if not search_settings:
return None
# If the secondary search settings are not configured to reindex in the background,
# we can just swap over instantly
if not secondary_search_settings.background_reindex_enabled:
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
return current_search_settings
unique_cc_indexings = count_unique_cc_pairs_with_successful_index_attempts(
search_settings_id=secondary_search_settings.id, db_session=db_session
search_settings_id=search_settings.id, db_session=db_session
)
# Index Attempts are cleaned up as well when the cc-pair is deleted so the logic in this
# function is correct. The unique_cc_indexings are specifically for the existing cc-pairs
old_search_settings = None
if unique_cc_indexings > cc_pair_count:
logger.error("More unique indexings than cc pairs, should not occur")
if cc_pair_count == 0 or cc_pair_count == unique_cc_indexings:
# Swap indices
current_search_settings = get_current_search_settings(db_session)
_perform_index_swap(
update_search_settings_status(
search_settings=current_search_settings,
new_status=IndexModelStatus.PAST,
db_session=db_session,
current_search_settings=current_search_settings,
secondary_search_settings=secondary_search_settings,
all_cc_pairs=all_cc_pairs,
)
old_search_settings = current_search_settings
update_search_settings_status(
search_settings=search_settings,
new_status=IndexModelStatus.PRESENT,
db_session=db_session,
)
if cc_pair_count > 0:
kv_store = get_kv_store()
kv_store.store(KV_REINDEX_KEY, False)
# Expire jobs for the now past index/embedding model
cancel_indexing_attempts_past_model(db_session)
# Recount aggregates
for cc_pair in all_cc_pairs:
resync_cc_pair(cc_pair, db_session=db_session)
old_search_settings = current_search_settings
return old_search_settings

View File

@@ -1,5 +1,6 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -148,10 +149,11 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -6,7 +6,6 @@ from typing import Any
from onyx.access.models import DocumentAccess
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceChunkUncleaned
from onyx.db.enums import EmbeddingPrecision
from onyx.indexing.models import DocMetadataAwareIndexChunk
from shared_configs.model_server_models import Embedding
@@ -146,21 +145,17 @@ class Verifiable(abc.ABC):
@abc.abstractmethod
def ensure_indices_exist(
self,
primary_embedding_dim: int,
primary_embedding_precision: EmbeddingPrecision,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
secondary_index_embedding_precision: EmbeddingPrecision | None,
) -> None:
"""
Verify that the document index exists and is consistent with the expectations in the code.
Parameters:
- primary_embedding_dim: Vector dimensionality for the vector similarity part of the search
- primary_embedding_precision: Precision of the vector similarity part of the search
- index_embedding_dim: Vector dimensionality for the vector similarity part of the search
- secondary_index_embedding_dim: Vector dimensionality of the secondary index being built
behind the scenes. The secondary index should only be built when switching
embedding models therefore this dim should be different from the primary index.
- secondary_index_embedding_precision: Precision of the vector similarity part of the secondary index
"""
raise NotImplementedError
@@ -169,7 +164,6 @@ class Verifiable(abc.ABC):
def register_multitenant_indices(
indices: list[str],
embedding_dims: list[int],
embedding_precisions: list[EmbeddingPrecision],
) -> None:
"""
Register multitenant indices with the document index.

View File

@@ -37,7 +37,7 @@ schema DANSWER_CHUNK_NAME {
summary: dynamic
}
# Title embedding (x1)
field title_embedding type tensor<EMBEDDING_PRECISION>(x[VARIABLE_DIM]) {
field title_embedding type tensor<float>(x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@@ -45,7 +45,7 @@ schema DANSWER_CHUNK_NAME {
}
# Content embeddings (chunk + optional mini chunks embeddings)
# "t" and "x" are arbitrary names, not special keywords
field embeddings type tensor<EMBEDDING_PRECISION>(t{},x[VARIABLE_DIM]) {
field embeddings type tensor<float>(t{},x[VARIABLE_DIM]) {
indexing: attribute | index
attribute {
distance-metric: angular
@@ -55,9 +55,6 @@ schema DANSWER_CHUNK_NAME {
field blurb type string {
indexing: summary | attribute
}
field image_file_name type string {
indexing: summary | attribute
}
# https://docs.vespa.ai/en/attributes.html potential enum store for speed, but probably not worth it
field source_type type string {
indexing: summary | attribute

Some files were not shown because too many files have changed in this diff Show More