mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 08:15:48 +00:00
Compare commits
42 Commits
github_lis
...
auto_refre
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
69638b4c4e | ||
|
|
8821f399f0 | ||
|
|
cebc341991 | ||
|
|
81cb98aaa7 | ||
|
|
38afc8fa3a | ||
|
|
185aa07526 | ||
|
|
3ba554843c | ||
|
|
5883336d5e | ||
|
|
0153ff6b51 | ||
|
|
2f8f0f01be | ||
|
|
a9e5ae2f11 | ||
|
|
997f40500d | ||
|
|
a918a84e7b | ||
|
|
090f3fe817 | ||
|
|
4e70f99214 | ||
|
|
ecbd4eb1ad | ||
|
|
f94d335d12 | ||
|
|
59a388ce0a | ||
|
|
9cd3cbb978 | ||
|
|
ab1b6b487e | ||
|
|
6ead9510a4 | ||
|
|
965f9e98bf | ||
|
|
426883bbf5 | ||
|
|
6ca400ced9 | ||
|
|
104c4b9f4d | ||
|
|
8b5e8bd5b9 | ||
|
|
7f7621d7c0 | ||
|
|
06dcc28d05 | ||
|
|
18df63dfd9 | ||
|
|
0d3c72acbf | ||
|
|
9217243e3e | ||
|
|
61ccba82a9 | ||
|
|
9e8eba23c3 | ||
|
|
0c29743538 | ||
|
|
08b2421947 | ||
|
|
ed518563db | ||
|
|
a32f7dc936 | ||
|
|
798e10c52f | ||
|
|
bf4983e35a | ||
|
|
b7da91e3ae | ||
|
|
29382656fc | ||
|
|
7d6db8d500 |
@@ -12,29 +12,40 @@ env:
|
||||
BUILDKIT_PROGRESS: plain
|
||||
|
||||
jobs:
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
|
||||
# Bypassing this for now as the idea of not building is glitching
|
||||
# releases and builds that depends on everything being tagged in docker
|
||||
# 1) Preliminary job to check if the changed files are relevant
|
||||
# check_model_server_changes:
|
||||
# runs-on: ubuntu-latest
|
||||
# outputs:
|
||||
# changed: ${{ steps.check.outputs.changed }}
|
||||
# steps:
|
||||
# - name: Checkout code
|
||||
# uses: actions/checkout@v4
|
||||
#
|
||||
# - name: Check if relevant files changed
|
||||
# id: check
|
||||
# run: |
|
||||
# # Default to "false"
|
||||
# echo "changed=false" >> $GITHUB_OUTPUT
|
||||
#
|
||||
# # Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# # set changed=true
|
||||
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
# echo "changed=true" >> $GITHUB_OUTPUT
|
||||
# fi
|
||||
|
||||
check_model_server_changes:
|
||||
runs-on: ubuntu-latest
|
||||
outputs:
|
||||
changed: ${{ steps.check.outputs.changed }}
|
||||
changed: "true"
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Check if relevant files changed
|
||||
id: check
|
||||
run: |
|
||||
# Default to "false"
|
||||
echo "changed=false" >> $GITHUB_OUTPUT
|
||||
|
||||
# Compare the previous commit (github.event.before) to the current one (github.sha)
|
||||
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
|
||||
# set changed=true
|
||||
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
|
||||
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
|
||||
echo "changed=true" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Bypass check and set output
|
||||
run: echo "changed=true" >> $GITHUB_OUTPUT
|
||||
|
||||
build-amd64:
|
||||
needs: [check_model_server_changes]
|
||||
if: needs.check_model_server_changes.outputs.changed == 'true'
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
@@ -47,11 +48,13 @@ env:
|
||||
# Gitbook
|
||||
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
|
||||
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
|
||||
# Notion
|
||||
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
|
||||
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -76,7 +79,7 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
@@ -114,3 +114,4 @@ To try the Onyx Enterprise Edition:
|
||||
|
||||
## 💡 Contributing
|
||||
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.
|
||||
|
||||
|
||||
@@ -5,7 +5,10 @@ Revises: f1ca58b2f2ec
|
||||
Create Date: 2025-01-29 07:48:46.784041
|
||||
|
||||
"""
|
||||
import logging
|
||||
from typing import cast
|
||||
from alembic import op
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.sql import text
|
||||
|
||||
|
||||
@@ -15,21 +18,45 @@ down_revision = "f1ca58b2f2ec"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get database connection
|
||||
"""Conflicts on lowercasing will result in the uppercased email getting a
|
||||
unique integer suffix when converted to lowercase."""
|
||||
|
||||
connection = op.get_bind()
|
||||
|
||||
# Update all user emails to lowercase
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE "user"
|
||||
SET email = LOWER(email)
|
||||
WHERE email != LOWER(email)
|
||||
"""
|
||||
)
|
||||
)
|
||||
# Fetch all user emails that are not already lowercase
|
||||
user_emails = connection.execute(
|
||||
text('SELECT id, email FROM "user" WHERE email != LOWER(email)')
|
||||
).fetchall()
|
||||
|
||||
for user_id, email in user_emails:
|
||||
email = cast(str, email)
|
||||
username, domain = email.rsplit("@", 1)
|
||||
new_email = f"{username.lower()}@{domain.lower()}"
|
||||
attempt = 1
|
||||
|
||||
while True:
|
||||
try:
|
||||
# Try updating the email
|
||||
connection.execute(
|
||||
text('UPDATE "user" SET email = :new_email WHERE id = :user_id'),
|
||||
{"new_email": new_email, "user_id": user_id},
|
||||
)
|
||||
break # Success, exit loop
|
||||
except IntegrityError:
|
||||
next_email = f"{username.lower()}_{attempt}@{domain.lower()}"
|
||||
# Email conflict occurred, append `_1`, `_2`, etc., to the username
|
||||
logger.warning(
|
||||
f"Conflict while lowercasing email: "
|
||||
f"old_email={email} "
|
||||
f"conflicting_email={new_email} "
|
||||
f"next_email={next_email}"
|
||||
)
|
||||
new_email = next_email
|
||||
attempt += 1
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
|
||||
@@ -0,0 +1,45 @@
|
||||
"""add_default_vision_provider_to_llm_provider
|
||||
|
||||
Revision ID: df46c75b714e
|
||||
Revises: 3934b1bc7b62
|
||||
Create Date: 2025-03-11 16:20:19.038945
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "df46c75b714e"
|
||||
down_revision = "3934b1bc7b62"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column(
|
||||
"llm_provider",
|
||||
sa.Column(
|
||||
"is_default_vision_provider",
|
||||
sa.Boolean(),
|
||||
nullable=True,
|
||||
server_default=sa.false(),
|
||||
),
|
||||
)
|
||||
op.add_column(
|
||||
"llm_provider", sa.Column("default_vision_model", sa.String(), nullable=True)
|
||||
)
|
||||
# Add unique constraint for is_default_vision_provider
|
||||
op.create_unique_constraint(
|
||||
"uq_llm_provider_is_default_vision_provider",
|
||||
"llm_provider",
|
||||
["is_default_vision_provider"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_constraint(
|
||||
"uq_llm_provider_is_default_vision_provider", "llm_provider", type_="unique"
|
||||
)
|
||||
op.drop_column("llm_provider", "default_vision_model")
|
||||
op.drop_column("llm_provider", "is_default_vision_provider")
|
||||
@@ -0,0 +1,33 @@
|
||||
"""add new available tenant table
|
||||
|
||||
Revision ID: 3b45e0018bf1
|
||||
Revises: ac842f85f932
|
||||
Create Date: 2025-03-06 09:55:18.229910
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3b45e0018bf1"
|
||||
down_revision = "ac842f85f932"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Create new_available_tenant table
|
||||
op.create_table(
|
||||
"available_tenant",
|
||||
sa.Column("tenant_id", sa.String(), nullable=False),
|
||||
sa.Column("alembic_version", sa.String(), nullable=False),
|
||||
sa.Column("date_created", sa.DateTime(), nullable=False),
|
||||
sa.PrimaryKeyConstraint("tenant_id"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop new_available_tenant table
|
||||
op.drop_table("available_tenant")
|
||||
@@ -0,0 +1,51 @@
|
||||
"""new column user tenant mapping
|
||||
|
||||
Revision ID: ac842f85f932
|
||||
Revises: 34e3630c7f32
|
||||
Create Date: 2025-03-03 13:30:14.802874
|
||||
|
||||
"""
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "ac842f85f932"
|
||||
down_revision = "34e3630c7f32"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add active column with default value of True
|
||||
op.add_column(
|
||||
"user_tenant_mapping",
|
||||
sa.Column(
|
||||
"active",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default="true",
|
||||
),
|
||||
schema="public",
|
||||
)
|
||||
|
||||
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
|
||||
|
||||
# Create a unique index for active=true records
|
||||
# This ensures a user can only be active in one tenant at a time
|
||||
op.execute(
|
||||
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Drop the unique index for active=true records
|
||||
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
|
||||
|
||||
op.create_unique_constraint(
|
||||
"uq_email", "user_tenant_mapping", ["email"], schema="public"
|
||||
)
|
||||
|
||||
# Remove the active column
|
||||
op.drop_column("user_tenant_mapping", "active", schema="public")
|
||||
@@ -27,6 +27,8 @@ def get_empty_chat_messages_entries__paginated(
|
||||
first element is the most recent timestamp out of the sessions iterated
|
||||
- this timestamp can be used to paginate forward in time
|
||||
second element is a list of messages belonging to all the sessions iterated
|
||||
|
||||
Only messages of type USER are returned
|
||||
"""
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=period[0],
|
||||
|
||||
@@ -64,7 +64,15 @@ def get_application() -> FastAPI:
|
||||
add_tenant_id_middleware(application, logger)
|
||||
|
||||
if AUTH_TYPE == AuthType.CLOUD:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -87,6 +95,16 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.OIDC:
|
||||
# Ensure we request offline_access for refresh tokens
|
||||
try:
|
||||
oidc_scopes = list(OIDC_SCOPE_OVERRIDE or BASE_SCOPES)
|
||||
if "offline_access" not in oidc_scopes:
|
||||
oidc_scopes.append("offline_access")
|
||||
except Exception as e:
|
||||
logger.warning(f"Error configuring OIDC scopes: {e}")
|
||||
# Fall back to default scopes if there's an error
|
||||
oidc_scopes = BASE_SCOPES
|
||||
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -94,8 +112,8 @@ def get_application() -> FastAPI:
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
OPENID_CONFIG_URL,
|
||||
# BASE_SCOPES is the same as not setting this
|
||||
base_scopes=OIDC_SCOPE_OVERRIDE or BASE_SCOPES,
|
||||
# Use the configured scopes
|
||||
base_scopes=oidc_scopes,
|
||||
),
|
||||
auth_backend,
|
||||
USER_AUTH_SECRET,
|
||||
|
||||
@@ -1,10 +1,14 @@
|
||||
import re
|
||||
from typing import cast
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.query_and_chat.models import AgentAnswer
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuery
|
||||
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
|
||||
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
|
||||
from ee.onyx.server.query_and_chat.models import (
|
||||
BasicCreateChatMessageWithHistoryRequest,
|
||||
@@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import combine_message_thread
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.models import AgentAnswerPiece
|
||||
from onyx.chat.models import AllCitations
|
||||
from onyx.chat.models import ExtendedToolResponse
|
||||
from onyx.chat.models import FinalUsedContextDocsResponse
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import LLMRelevanceFilterResponse
|
||||
from onyx.chat.models import OnyxAnswerPiece
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.chat.models import SubQueryPiece
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.chat.process_message import ChatPacketStream
|
||||
from onyx.chat.process_message import stream_chat_message_objects
|
||||
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
|
||||
@@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
|
||||
final_context_docs: list[LlmDoc] = []
|
||||
|
||||
answer = ""
|
||||
|
||||
# accumulate stream data with these dicts
|
||||
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
|
||||
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
|
||||
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
|
||||
|
||||
for packet in packets:
|
||||
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
|
||||
answer += packet.answer_piece
|
||||
@@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
|
||||
|
||||
# TODO: deprecate `simple_search_docs`
|
||||
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
|
||||
|
||||
# This is a no-op if agent_sub_questions hasn't already been filled
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if id in agent_sub_questions:
|
||||
agent_sub_questions[id].document_ids = [
|
||||
saved_search_doc.document_id
|
||||
for saved_search_doc in packet.top_documents
|
||||
]
|
||||
elif isinstance(packet, StreamingError):
|
||||
response.error_msg = packet.error
|
||||
elif isinstance(packet, ChatMessageDetail):
|
||||
@@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
|
||||
citation.citation_num: citation.document_id
|
||||
for citation in packet.citations
|
||||
}
|
||||
# agentic packets
|
||||
elif isinstance(packet, SubQuestionPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_sub_questions.get(id) is None:
|
||||
agent_sub_questions[id] = AgentSubQuestion(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_question=packet.sub_question,
|
||||
document_ids=[],
|
||||
)
|
||||
else:
|
||||
agent_sub_questions[id].sub_question += packet.sub_question
|
||||
|
||||
elif isinstance(packet, AgentAnswerPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
id = (packet.level, packet.level_question_num)
|
||||
if agent_answers.get(id) is None:
|
||||
agent_answers[id] = AgentAnswer(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
answer=packet.answer_piece,
|
||||
answer_type=packet.answer_type,
|
||||
)
|
||||
else:
|
||||
agent_answers[id].answer += packet.answer_piece
|
||||
elif isinstance(packet, SubQueryPiece):
|
||||
if packet.level is not None and packet.level_question_num is not None:
|
||||
sub_query_id = (
|
||||
packet.level,
|
||||
packet.level_question_num,
|
||||
packet.query_id,
|
||||
)
|
||||
if agent_sub_queries.get(sub_query_id) is None:
|
||||
agent_sub_queries[sub_query_id] = AgentSubQuery(
|
||||
level=packet.level,
|
||||
level_question_num=packet.level_question_num,
|
||||
sub_query=packet.sub_query,
|
||||
query_id=packet.query_id,
|
||||
)
|
||||
else:
|
||||
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
|
||||
elif isinstance(packet, ExtendedToolResponse):
|
||||
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
|
||||
logger.warning(
|
||||
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
|
||||
)
|
||||
elif isinstance(packet, RefinedAnswerImprovement):
|
||||
response.agent_refined_answer_improvement = (
|
||||
packet.refined_answer_improvement
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
|
||||
)
|
||||
|
||||
response.final_context_doc_indices = _get_final_context_doc_indices(
|
||||
final_context_docs, response.top_documents
|
||||
)
|
||||
|
||||
# organize / sort agent metadata for output
|
||||
if len(agent_sub_questions) > 0:
|
||||
response.agent_sub_questions = cast(
|
||||
dict[int, list[AgentSubQuestion]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
|
||||
)
|
||||
|
||||
if len(agent_answers) > 0:
|
||||
# return the agent_level_answer from the first level or the last one depending
|
||||
# on agent_refined_answer_improvement
|
||||
response.agent_answers = cast(
|
||||
dict[int, list[AgentAnswer]],
|
||||
SubQuestionIdentifier.make_dict_by_level(agent_answers),
|
||||
)
|
||||
if response.agent_answers:
|
||||
selected_answer_level = (
|
||||
0
|
||||
if not response.agent_refined_answer_improvement
|
||||
else len(response.agent_answers) - 1
|
||||
)
|
||||
level_answers = response.agent_answers[selected_answer_level]
|
||||
for level_answer in level_answers:
|
||||
if level_answer.answer_type != "agent_level_answer":
|
||||
continue
|
||||
|
||||
answer = level_answer.answer
|
||||
break
|
||||
|
||||
if len(agent_sub_queries) > 0:
|
||||
# subqueries are often emitted with trailing whitespace ... clean it up here
|
||||
# perhaps fix at the source?
|
||||
for v in agent_sub_queries.values():
|
||||
v.sub_query = v.sub_query.strip()
|
||||
|
||||
response.agent_sub_queries = (
|
||||
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
|
||||
)
|
||||
|
||||
response.answer = answer
|
||||
if answer:
|
||||
response.answer_citationless = remove_answer_citations(answer)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from collections import OrderedDict
|
||||
from typing import Literal
|
||||
from uuid import UUID
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PersonaOverrideConfig
|
||||
from onyx.chat.models import QADocsResponse
|
||||
from onyx.chat.models import SubQuestionIdentifier
|
||||
from onyx.chat.models import ThreadMessage
|
||||
from onyx.configs.constants import DocumentSource
|
||||
from onyx.context.search.enums import LLMEvaluationType
|
||||
@@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
|
||||
metadata: dict | None
|
||||
|
||||
|
||||
class AgentSubQuestion(SubQuestionIdentifier):
|
||||
sub_question: str
|
||||
document_ids: list[str]
|
||||
|
||||
|
||||
class AgentAnswer(SubQuestionIdentifier):
|
||||
answer: str
|
||||
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
|
||||
|
||||
|
||||
class AgentSubQuery(SubQuestionIdentifier):
|
||||
sub_query: str
|
||||
query_id: int
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level_and_question_index(
|
||||
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
|
||||
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
|
||||
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
|
||||
|
||||
returns a dict of level to dict[question num to list of query_id's]
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
# In this function, when we sort int | None, we deliberately push None to the end
|
||||
|
||||
# map entries to the level_question_dict
|
||||
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
|
||||
for k1, obj in original_dict.items():
|
||||
level = k1[0]
|
||||
question = k1[1]
|
||||
|
||||
if level not in level_question_dict:
|
||||
level_question_dict[level] = {}
|
||||
|
||||
if question not in level_question_dict[level]:
|
||||
level_question_dict[level][question] = []
|
||||
|
||||
level_question_dict[level][question].append(obj)
|
||||
|
||||
# sort each query_id list and question_index
|
||||
for key1, obj1 in level_question_dict.items():
|
||||
for key2, value2 in obj1.items():
|
||||
# sort the query_id list of each question_index
|
||||
level_question_dict[key1][key2] = sorted(
|
||||
value2, key=lambda o: o.query_id
|
||||
)
|
||||
# sort the question_index dict of level
|
||||
level_question_dict[key1] = OrderedDict(
|
||||
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
|
||||
# sort the top dict of levels
|
||||
sorted_dict = OrderedDict(
|
||||
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
|
||||
)
|
||||
return sorted_dict
|
||||
|
||||
|
||||
class ChatBasicResponse(BaseModel):
|
||||
# This is built piece by piece, any of these can be None as the flow could break
|
||||
answer: str | None = None
|
||||
@@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
|
||||
simple_search_docs: list[SimpleDoc] | None = None
|
||||
llm_chunks_indices: list[int] | None = None
|
||||
|
||||
# agentic fields
|
||||
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
|
||||
agent_answers: dict[int, list[AgentAnswer]] | None = None
|
||||
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
|
||||
agent_refined_answer_improvement: bool | None = None
|
||||
|
||||
|
||||
class OneShotQARequest(ChunkContext):
|
||||
# Supports simplier APIs that don't deal with chat histories or message edits
|
||||
|
||||
@@ -48,10 +48,15 @@ def fetch_and_process_chat_session_history(
|
||||
feedback_type: QAFeedbackType | None,
|
||||
limit: int | None = 500,
|
||||
) -> list[ChatSessionSnapshot]:
|
||||
# observed to be slow a scale of 8192 sessions and 4 messages per session
|
||||
|
||||
# this is a little slow (5 seconds)
|
||||
chat_sessions = fetch_chat_sessions_eagerly_by_time(
|
||||
start=start, end=end, db_session=db_session, limit=limit
|
||||
)
|
||||
|
||||
# this is VERY slow (80 seconds) due to create_chat_chain being called
|
||||
# for each session. Needs optimizing.
|
||||
chat_session_snapshots = [
|
||||
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
|
||||
for chat_session in chat_sessions
|
||||
@@ -246,6 +251,8 @@ def get_query_history_as_csv(
|
||||
detail="Query history has been disabled by the administrator.",
|
||||
)
|
||||
|
||||
# this call is very expensive and is timing out via endpoint
|
||||
# TODO: optimize call and/or generate via background task
|
||||
complete_chat_session_history = fetch_and_process_chat_session_history(
|
||||
db_session=db_session,
|
||||
start=start or datetime.fromtimestamp(0, tz=timezone.utc),
|
||||
|
||||
45
backend/ee/onyx/server/tenants/admin_api.py
Normal file
45
backend/ee/onyx/server/tenants/admin_api.py
Normal file
@@ -0,0 +1,45 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
98
backend/ee/onyx/server/tenants/anonymous_users_api.py
Normal file
98
backend/ee/onyx/server/tenants/anonymous_users_api.py
Normal file
@@ -0,0 +1,98 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
@@ -1,269 +1,24 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Response
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.auth.users import current_cloud_superuser
|
||||
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
|
||||
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import (
|
||||
get_tenant_id_for_anonymous_user_path,
|
||||
from ee.onyx.server.tenants.admin_api import router as admin_router
|
||||
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
|
||||
from ee.onyx.server.tenants.billing_api import router as billing_router
|
||||
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
|
||||
from ee.onyx.server.tenants.tenant_management_api import (
|
||||
router as tenant_management_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.user_invitations_api import (
|
||||
router as user_invitations_router,
|
||||
)
|
||||
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
|
||||
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import AnonymousUserPath
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ImpersonateRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import anonymous_user_enabled
|
||||
from onyx.auth.users import auth_backend
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import get_redis_strategy
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
router = APIRouter(prefix="/tenants")
|
||||
# Create a main router to include all sub-routers
|
||||
# Note: We don't add a prefix here as each router already has the /tenants prefix
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/anonymous-user-path")
|
||||
async def get_anonymous_user_path_api(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> AnonymousUserPath:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if tenant_id is None:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
current_path = get_anonymous_user_path(tenant_id, db_session)
|
||||
|
||||
return AnonymousUserPath(anonymous_user_path=current_path)
|
||||
|
||||
|
||||
@router.post("/anonymous-user-path")
|
||||
async def set_anonymous_user_path_api(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
try:
|
||||
validate_anonymous_user_path(anonymous_user_path)
|
||||
except ValueError as e:
|
||||
raise HTTPException(status_code=400, detail=str(e))
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
|
||||
except IntegrityError:
|
||||
raise HTTPException(
|
||||
status_code=409,
|
||||
detail="The anonymous user path is already in use. Please choose a different path.",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail="An unexpected error occurred while modifying the anonymous user path",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/anonymous-user")
|
||||
async def login_as_anonymous_user(
|
||||
anonymous_user_path: str,
|
||||
_: User | None = Depends(optional_user),
|
||||
) -> Response:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
tenant_id = get_tenant_id_for_anonymous_user_path(
|
||||
anonymous_user_path, db_session
|
||||
)
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=404, detail="Tenant not found")
|
||||
|
||||
if not anonymous_user_enabled(tenant_id=tenant_id):
|
||||
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
|
||||
|
||||
token = generate_anonymous_user_jwt_token(tenant_id)
|
||||
|
||||
response = Response()
|
||||
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
response.set_cookie(
|
||||
key=ANONYMOUS_USER_COOKIE_NAME,
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="strict",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/impersonate")
|
||||
async def impersonate_user(
|
||||
impersonate_request: ImpersonateRequest,
|
||||
_: User = Depends(current_cloud_superuser),
|
||||
) -> Response:
|
||||
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
|
||||
tenant_id = get_tenant_id_for_email(impersonate_request.email)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
|
||||
user_to_impersonate = get_user_by_email(
|
||||
impersonate_request.email, tenant_session
|
||||
)
|
||||
if user_to_impersonate is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
token = await get_redis_strategy().write_token(user_to_impersonate)
|
||||
|
||||
response = await auth_backend.transport.get_login_response(token)
|
||||
response.set_cookie(
|
||||
key="fastapiusersauth",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
)
|
||||
return response
|
||||
|
||||
|
||||
@router.post("/leave-organization")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
# Include all the individual routers
|
||||
router.include_router(admin_router)
|
||||
router.include_router(anonymous_users_router)
|
||||
router.include_router(billing_router)
|
||||
router.include_router(team_membership_router)
|
||||
router.include_router(tenant_management_router)
|
||||
router.include_router(user_invitations_router)
|
||||
|
||||
96
backend/ee/onyx/server/tenants/billing_api.py
Normal file
96
backend/ee/onyx/server/tenants/billing_api.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import stripe
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.auth.users import current_admin_user
|
||||
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
|
||||
from ee.onyx.server.tenants.access import control_plane_dep
|
||||
from ee.onyx.server.tenants.billing import fetch_billing_information
|
||||
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
|
||||
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
|
||||
from ee.onyx.server.tenants.models import BillingInformation
|
||||
from ee.onyx.server.tenants.models import ProductGatingRequest
|
||||
from ee.onyx.server.tenants.models import ProductGatingResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
|
||||
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
|
||||
from ee.onyx.server.tenants.product_gating import store_product_gating
|
||||
from onyx.auth.users import User
|
||||
from onyx.configs.app_configs import WEB_DOMAIN
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
stripe.api_key = STRIPE_SECRET_KEY
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/product-gating")
|
||||
def gate_product(
|
||||
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
|
||||
) -> ProductGatingResponse:
|
||||
"""
|
||||
Gating the product means that the product is not available to the tenant.
|
||||
They will be directed to the billing page.
|
||||
We gate the product when their subscription has ended.
|
||||
"""
|
||||
try:
|
||||
store_product_gating(
|
||||
product_gating_request.tenant_id, product_gating_request.application_status
|
||||
)
|
||||
return ProductGatingResponse(updated=True, error=None)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to gate product")
|
||||
return ProductGatingResponse(updated=False, error=str(e))
|
||||
|
||||
|
||||
@router.get("/billing-information")
|
||||
async def billing_information(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> BillingInformation | SubscriptionStatusResponse:
|
||||
logger.info("Fetching billing information")
|
||||
tenant_id = get_current_tenant_id()
|
||||
return fetch_billing_information(tenant_id)
|
||||
|
||||
|
||||
@router.post("/create-customer-portal-session")
|
||||
async def create_customer_portal_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> dict:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
try:
|
||||
stripe_info = fetch_tenant_stripe_information(tenant_id)
|
||||
stripe_customer_id = stripe_info.get("stripe_customer_id")
|
||||
if not stripe_customer_id:
|
||||
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
|
||||
logger.info(stripe_customer_id)
|
||||
|
||||
portal_session = stripe.billing_portal.Session.create(
|
||||
customer=stripe_customer_id,
|
||||
return_url=f"{WEB_DOMAIN}/admin/billing",
|
||||
)
|
||||
logger.info(portal_session)
|
||||
return {"url": portal_session.url}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create customer portal session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.post("/create-subscription-session")
|
||||
async def create_subscription_session(
|
||||
_: User = Depends(current_admin_user),
|
||||
) -> SubscriptionSessionResponse:
|
||||
try:
|
||||
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
|
||||
if not tenant_id:
|
||||
raise HTTPException(status_code=400, detail="Tenant ID not found")
|
||||
session_id = fetch_stripe_checkout_session(tenant_id)
|
||||
return SubscriptionSessionResponse(sessionId=session_id)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to create resubscription session")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
|
||||
|
||||
class SubscriptionSessionResponse(BaseModel):
|
||||
sessionId: str
|
||||
|
||||
|
||||
class TenantByDomainResponse(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
creator_email: str
|
||||
|
||||
|
||||
class TenantByDomainRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class RequestInviteRequest(BaseModel):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class RequestInviteResponse(BaseModel):
|
||||
success: bool
|
||||
message: str
|
||||
|
||||
|
||||
class PendingUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class ApproveUserRequest(BaseModel):
|
||||
email: str
|
||||
|
||||
@@ -4,6 +4,7 @@ import uuid
|
||||
|
||||
import aiohttp # Async HTTP client
|
||||
import httpx
|
||||
import requests
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Request
|
||||
from sqlalchemy import select
|
||||
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
|
||||
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
|
||||
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
|
||||
from ee.onyx.server.tenants.access import generate_data_plane_token
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.models import TenantCreationPayload
|
||||
from ee.onyx.server.tenants.models import TenantDeletionPayload
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
@@ -26,11 +28,12 @@ from onyx.auth.users import exceptions
|
||||
from onyx.configs.app_configs import CONTROL_PLANE_API_BASE_URL
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.constants import MilestoneRecordType
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import upsert_cloud_embedding_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.db.models import IndexModelStatus
|
||||
from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import UserTenantMapping
|
||||
@@ -60,42 +63,72 @@ async def get_or_provision_tenant(
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
# Early return for non-multi-tenant mode
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
if referral_source and request:
|
||||
await submit_to_hubspot(email, referral_source, request)
|
||||
|
||||
# First, check if the user already has a tenant
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = get_tenant_id_for_email(email)
|
||||
return tenant_id
|
||||
except exceptions.UserNotExists:
|
||||
# If tenant does not exist and in Multi tenant mode, provision a new tenant
|
||||
try:
|
||||
# User doesn't exist, so we need to create a new tenant or assign an existing one
|
||||
pass
|
||||
|
||||
try:
|
||||
# Try to get a pre-provisioned tenant
|
||||
tenant_id = await get_available_tenant()
|
||||
|
||||
if tenant_id:
|
||||
# If we have a pre-provisioned tenant, assign it to the user
|
||||
await assign_tenant_to_user(tenant_id, email, referral_source)
|
||||
logger.info(f"Assigned pre-provisioned tenant {tenant_id} to user {email}")
|
||||
return tenant_id
|
||||
else:
|
||||
# If no pre-provisioned tenant is available, create a new one on-demand
|
||||
tenant_id = await create_tenant(email, referral_source)
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
return tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
except Exception as e:
|
||||
# If we've encountered an error, log and raise an exception
|
||||
error_msg = "Failed to provision tenant"
|
||||
logger.error(error_msg, exc_info=e)
|
||||
raise HTTPException(
|
||||
status_code=401, detail="User does not belong to an organization"
|
||||
status_code=500,
|
||||
detail="Failed to provision tenant. Please try again later.",
|
||||
)
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
async def create_tenant(email: str, referral_source: str | None = None) -> str:
|
||||
"""
|
||||
Create a new tenant on-demand when no pre-provisioned tenants are available.
|
||||
This is the fallback method when we can't use a pre-provisioned tenant.
|
||||
|
||||
"""
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
logger.info(f"Creating new tenant {tenant_id} for user {email}")
|
||||
|
||||
try:
|
||||
# Provision tenant on data plane
|
||||
await provision_tenant(tenant_id, email)
|
||||
# Notify control plane
|
||||
if not DEV_MODE:
|
||||
|
||||
# Notify control plane if not already done in provision_tenant
|
||||
if not DEV_MODE and referral_source:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Tenant provisioning failed: {e}")
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
logger.exception(f"Tenant provisioning failed: {str(e)}")
|
||||
# Attempt to rollback the tenant provisioning
|
||||
try:
|
||||
await rollback_tenant_provisioning(tenant_id)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to rollback tenant provisioning for {tenant_id}")
|
||||
raise HTTPException(status_code=500, detail="Failed to provision tenant.")
|
||||
|
||||
return tenant_id
|
||||
|
||||
|
||||
@@ -109,54 +142,25 @@ async def provision_tenant(tenant_id: str, email: str) -> None:
|
||||
)
|
||||
|
||||
logger.debug(f"Provisioning tenant {tenant_id} for user {email}")
|
||||
token = None
|
||||
|
||||
try:
|
||||
# Create the schema for the tenant
|
||||
if not create_schema_if_not_exists(tenant_id):
|
||||
logger.debug(f"Created schema for tenant {tenant_id}")
|
||||
else:
|
||||
logger.debug(f"Schema already exists for tenant {tenant_id}")
|
||||
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
# Set up the tenant with all necessary configurations
|
||||
await setup_tenant(tenant_id)
|
||||
|
||||
# Await the Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
# Assign the tenant to the user
|
||||
await assign_tenant_to_user(tenant_id, email)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to create tenant {tenant_id}")
|
||||
raise HTTPException(
|
||||
status_code=500, detail=f"Failed to create tenant: {str(e)}"
|
||||
)
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def notify_control_plane(
|
||||
@@ -187,20 +191,74 @@ async def notify_control_plane(
|
||||
|
||||
|
||||
async def rollback_tenant_provisioning(tenant_id: str) -> None:
|
||||
# Logic to rollback tenant provisioning on data plane
|
||||
"""
|
||||
Logic to rollback tenant provisioning on data plane.
|
||||
Handles each step independently to ensure maximum cleanup even if some steps fail.
|
||||
"""
|
||||
logger.info(f"Rolling back tenant provisioning for tenant_id: {tenant_id}")
|
||||
try:
|
||||
# Drop the tenant's schema to rollback provisioning
|
||||
drop_schema(tenant_id)
|
||||
|
||||
# Remove tenant mapping
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
# Track if any part of the rollback fails
|
||||
rollback_errors = []
|
||||
|
||||
# 1. Try to drop the tenant's schema
|
||||
try:
|
||||
drop_schema(tenant_id)
|
||||
logger.info(f"Successfully dropped schema for tenant {tenant_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to rollback tenant provisioning: {e}")
|
||||
error_msg = f"Failed to drop schema for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 2. Try to remove tenant mapping
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Successfully removed user mappings for tenant {tenant_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove user mappings for tenant {tenant_id}: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# 3. If this tenant was in the available tenants table, remove it
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
db_session.begin()
|
||||
try:
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.filter(AvailableTenant.tenant_id == tenant_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(
|
||||
f"Removed tenant {tenant_id} from available tenants table"
|
||||
)
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
raise e
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to remove tenant {tenant_id} from available tenants table: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
rollback_errors.append(error_msg)
|
||||
|
||||
# Log summary of rollback operation
|
||||
if rollback_errors:
|
||||
logger.error(f"Tenant rollback completed with {len(rollback_errors)} errors")
|
||||
else:
|
||||
logger.info(f"Tenant rollback completed successfully for tenant {tenant_id}")
|
||||
|
||||
|
||||
def configure_default_api_keys(db_session: Session) -> None:
|
||||
@@ -353,3 +411,155 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
|
||||
raise Exception(
|
||||
f"Failed to delete tenant on control plane: {error_text}"
|
||||
)
|
||||
|
||||
|
||||
def get_tenant_by_domain_from_control_plane(
|
||||
domain: str,
|
||||
tenant_id: str,
|
||||
) -> TenantByDomainResponse | None:
|
||||
"""
|
||||
Fetches tenant information from the control plane based on the email domain.
|
||||
|
||||
Args:
|
||||
domain: The email domain to search for (e.g., "example.com")
|
||||
|
||||
Returns:
|
||||
A dictionary containing tenant information if found, None otherwise
|
||||
"""
|
||||
token = generate_data_plane_token()
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
try:
|
||||
response = requests.get(
|
||||
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
|
||||
headers=headers,
|
||||
json={"domain": domain, "tenant_id": tenant_id},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(f"Control plane tenant lookup failed: {response.text}")
|
||||
return None
|
||||
|
||||
response_data = response.json()
|
||||
if not response_data:
|
||||
return None
|
||||
|
||||
return TenantByDomainResponse(
|
||||
tenant_id=response_data.get("tenant_id"),
|
||||
number_of_users=response_data.get("number_of_users"),
|
||||
creator_email=response_data.get("creator_email"),
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error fetching tenant by domain: {str(e)}")
|
||||
return None
|
||||
|
||||
|
||||
async def get_available_tenant() -> str | None:
|
||||
"""
|
||||
Get an available pre-provisioned tenant from the NewAvailableTenant table.
|
||||
Returns the tenant_id if one is available, None otherwise.
|
||||
Uses row-level locking to prevent race conditions when multiple processes
|
||||
try to get an available tenant simultaneously.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return None
|
||||
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
db_session.begin()
|
||||
|
||||
# Get the oldest available tenant with FOR UPDATE lock to prevent race conditions
|
||||
available_tenant = (
|
||||
db_session.query(AvailableTenant)
|
||||
.order_by(AvailableTenant.date_created)
|
||||
.with_for_update(skip_locked=True) # Skip locked rows to avoid blocking
|
||||
.first()
|
||||
)
|
||||
|
||||
if available_tenant:
|
||||
tenant_id = available_tenant.tenant_id
|
||||
# Remove the tenant from the available tenants table
|
||||
db_session.delete(available_tenant)
|
||||
db_session.commit()
|
||||
logger.info(f"Using pre-provisioned tenant {tenant_id}")
|
||||
return tenant_id
|
||||
else:
|
||||
db_session.rollback()
|
||||
return None
|
||||
except Exception:
|
||||
logger.exception("Error getting available tenant")
|
||||
db_session.rollback()
|
||||
return None
|
||||
|
||||
|
||||
async def setup_tenant(tenant_id: str) -> None:
|
||||
"""
|
||||
Set up a tenant with all necessary configurations.
|
||||
This is a centralized function that handles all tenant setup logic.
|
||||
"""
|
||||
token = None
|
||||
try:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
|
||||
# Run Alembic migrations
|
||||
await asyncio.to_thread(run_alembic_migrations, tenant_id)
|
||||
|
||||
# Configure the tenant with default settings
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
# Configure default API keys
|
||||
configure_default_api_keys(db_session)
|
||||
|
||||
# Set up Onyx with appropriate settings
|
||||
current_search_settings = (
|
||||
db_session.query(SearchSettings)
|
||||
.filter_by(status=IndexModelStatus.FUTURE)
|
||||
.first()
|
||||
)
|
||||
cohere_enabled = (
|
||||
current_search_settings is not None
|
||||
and current_search_settings.provider_type == EmbeddingProvider.COHERE
|
||||
)
|
||||
setup_onyx(db_session, tenant_id, cohere_enabled=cohere_enabled)
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to set up tenant {tenant_id}")
|
||||
raise e
|
||||
finally:
|
||||
if token is not None:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
async def assign_tenant_to_user(
|
||||
tenant_id: str, email: str, referral_source: str | None = None
|
||||
) -> None:
|
||||
"""
|
||||
Assign a tenant to a user and perform necessary operations.
|
||||
Uses transaction handling to ensure atomicity and includes retry logic
|
||||
for control plane notifications.
|
||||
"""
|
||||
# First, add the user to the tenant in a transaction
|
||||
|
||||
try:
|
||||
add_users_to_tenant([email], tenant_id)
|
||||
|
||||
# Create milestone record in the same transaction context as the tenant assignment
|
||||
with get_session_with_tenant(tenant_id=tenant_id) as db_session:
|
||||
create_milestone_and_report(
|
||||
user=None,
|
||||
distinct_id=tenant_id,
|
||||
event_type=MilestoneRecordType.TENANT_CREATED,
|
||||
properties={
|
||||
"email": email,
|
||||
},
|
||||
db_session=db_session,
|
||||
)
|
||||
except Exception:
|
||||
logger.exception(f"Failed to assign tenant {tenant_id} to user {email}")
|
||||
raise Exception("Failed to assign tenant to user")
|
||||
|
||||
# Notify control plane with retry logic
|
||||
if not DEV_MODE:
|
||||
await notify_control_plane(tenant_id, email, referral_source)
|
||||
|
||||
@@ -74,3 +74,21 @@ def drop_schema(tenant_id: str) -> None:
|
||||
text("DROP SCHEMA IF EXISTS %(schema_name)s CASCADE"),
|
||||
{"schema_name": tenant_id},
|
||||
)
|
||||
|
||||
|
||||
def get_current_alembic_version(tenant_id: str) -> str:
|
||||
"""Get the current Alembic version for a tenant."""
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from sqlalchemy import text
|
||||
|
||||
engine = get_sqlalchemy_engine()
|
||||
|
||||
# Set the search path to the tenant's schema
|
||||
with engine.connect() as connection:
|
||||
connection.execute(text(f'SET search_path TO "{tenant_id}"'))
|
||||
|
||||
# Get the current version from the alembic_version table
|
||||
context = MigrationContext.configure(connection)
|
||||
current_rev = context.get_current_revision()
|
||||
|
||||
return current_rev or "head"
|
||||
|
||||
67
backend/ee/onyx/server/tenants/team_membership_api.py
Normal file
67
backend/ee/onyx/server/tenants/team_membership_api.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
|
||||
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
|
||||
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.db.auth import get_user_count
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.users import delete_user_from_db
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/leave-team")
|
||||
async def leave_organization(
|
||||
user_email: UserByEmail,
|
||||
current_user: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
if current_user is None or current_user.email != user_email.user_email:
|
||||
raise HTTPException(
|
||||
status_code=403, detail="You can only leave the organization as yourself"
|
||||
)
|
||||
|
||||
user_to_delete = get_user_by_email(user_email.user_email, db_session)
|
||||
if user_to_delete is None:
|
||||
raise HTTPException(status_code=404, detail="User not found")
|
||||
|
||||
num_admin_users = await get_user_count(only_admin_users=True)
|
||||
|
||||
should_delete_tenant = num_admin_users == 1
|
||||
|
||||
if should_delete_tenant:
|
||||
logger.info(
|
||||
"Last admin user is leaving the organization. Deleting tenant from control plane."
|
||||
)
|
||||
try:
|
||||
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
|
||||
logger.debug("User deleted from control plane")
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(
|
||||
status_code=500,
|
||||
detail=f"Failed to remove user from control plane: {str(e)}",
|
||||
)
|
||||
|
||||
db_session.expunge(user_to_delete)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
|
||||
if should_delete_tenant:
|
||||
remove_all_users_from_tenant(tenant_id)
|
||||
else:
|
||||
remove_users_from_tenant([user_to_delete.email], tenant_id)
|
||||
39
backend/ee/onyx/server/tenants/tenant_management_api.py
Normal file
39
backend/ee/onyx/server/tenants/tenant_management_api.py
Normal file
@@ -0,0 +1,39 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
|
||||
from ee.onyx.server.tenants.models import TenantByDomainResponse
|
||||
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
|
||||
"gmail",
|
||||
"outlook",
|
||||
"yahoo",
|
||||
"hotmail",
|
||||
"icloud",
|
||||
"msn",
|
||||
"hotmail",
|
||||
"hotmail.co.uk",
|
||||
]
|
||||
|
||||
|
||||
@router.get("/existing-team-by-domain")
|
||||
def get_existing_tenant_by_domain(
|
||||
user: User | None = Depends(current_user),
|
||||
) -> TenantByDomainResponse | None:
|
||||
if not user:
|
||||
return None
|
||||
domain = user.email.split("@")[1]
|
||||
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
|
||||
return None
|
||||
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
return get_tenant_by_domain_from_control_plane(domain, tenant_id)
|
||||
90
backend/ee/onyx/server/tenants/user_invitations_api.py
Normal file
90
backend/ee/onyx/server/tenants/user_invitations_api.py
Normal file
@@ -0,0 +1,90 @@
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
|
||||
from ee.onyx.server.tenants.models import ApproveUserRequest
|
||||
from ee.onyx.server.tenants.models import PendingUserSnapshot
|
||||
from ee.onyx.server.tenants.models import RequestInviteRequest
|
||||
from ee.onyx.server.tenants.user_mapping import accept_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import approve_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import deny_user_invite
|
||||
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
from shared_configs.contextvars import get_current_tenant_id
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
router = APIRouter(prefix="/tenants")
|
||||
|
||||
|
||||
@router.post("/users/invite/request")
|
||||
async def request_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
if user is None:
|
||||
raise HTTPException(status_code=401, detail="User not authenticated")
|
||||
try:
|
||||
invite_self_to_tenant(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
|
||||
)
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@router.get("/users/pending")
|
||||
def list_pending_users(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> list[PendingUserSnapshot]:
|
||||
pending_emails = get_pending_users()
|
||||
return [PendingUserSnapshot(email=email) for email in pending_emails]
|
||||
|
||||
|
||||
@router.post("/users/invite/approve")
|
||||
async def approve_user(
|
||||
approve_user_request: ApproveUserRequest,
|
||||
_: User | None = Depends(current_admin_user),
|
||||
) -> None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
approve_user_invite(approve_user_request.email, tenant_id)
|
||||
|
||||
|
||||
@router.post("/users/invite/accept")
|
||||
async def accept_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
accept_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to accept invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to accept invitation")
|
||||
|
||||
|
||||
@router.post("/users/invite/deny")
|
||||
async def deny_invite(
|
||||
invite_request: RequestInviteRequest,
|
||||
user: User | None = Depends(current_user),
|
||||
) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
"""
|
||||
if not user:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
|
||||
try:
|
||||
deny_user_invite(user.email, invite_request.tenant_id)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to deny invite: {str(e)}")
|
||||
raise HTTPException(status_code=500, detail="Failed to deny invitation")
|
||||
@@ -1,27 +1,56 @@
|
||||
import logging
|
||||
|
||||
from fastapi_users import exceptions
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.invited_users import get_invited_users
|
||||
from onyx.auth.invited_users import get_pending_users
|
||||
from onyx.auth.invited_users import write_invited_users
|
||||
from onyx.auth.invited_users import write_pending_users
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.engine import get_session_with_tenant
|
||||
from onyx.db.engine import get_sqlalchemy_engine
|
||||
from onyx.db.models import UserTenantMapping
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.setup import setup_logger
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
|
||||
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_tenant_id_for_email(email: str) -> str:
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
# Implement logic to get tenant_id from the mapping table
|
||||
with Session(get_sqlalchemy_engine()) as db_session:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
|
||||
)
|
||||
tenant_id = result.scalar_one_or_none()
|
||||
try:
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# First try to get an active tenant
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
tenant_id = mapping.tenant_id if mapping else None
|
||||
|
||||
# If no active tenant found, try to get the first inactive one
|
||||
if tenant_id is None:
|
||||
result = db_session.execute(
|
||||
select(UserTenantMapping).where(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
)
|
||||
mapping = result.scalar_one_or_none()
|
||||
if mapping:
|
||||
# Mark this mapping as active
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
tenant_id = mapping.tenant_id
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error getting tenant id for email {email}: {e}")
|
||||
raise exceptions.UserNotExists()
|
||||
if tenant_id is None:
|
||||
raise exceptions.UserNotExists()
|
||||
return tenant_id
|
||||
@@ -38,13 +67,39 @@ def user_owns_a_tenant(email: str) -> bool:
|
||||
|
||||
|
||||
def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
"""
|
||||
Add users to a tenant with proper transaction handling.
|
||||
Checks if users already have a tenant mapping to avoid duplicates.
|
||||
"""
|
||||
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
|
||||
try:
|
||||
# Start a transaction
|
||||
db_session.begin()
|
||||
|
||||
for email in emails:
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
# Check if the user already has a mapping to this tenant
|
||||
existing_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
)
|
||||
.with_for_update()
|
||||
.first()
|
||||
)
|
||||
|
||||
if not existing_mapping:
|
||||
# Only add if mapping doesn't exist
|
||||
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
|
||||
|
||||
# Commit the transaction
|
||||
db_session.commit()
|
||||
logger.info(f"Successfully added users {emails} to tenant {tenant_id}")
|
||||
|
||||
except Exception:
|
||||
logger.exception(f"Failed to add users to tenant {tenant_id}")
|
||||
db_session.commit()
|
||||
db_session.rollback()
|
||||
raise
|
||||
|
||||
|
||||
def remove_users_from_tenant(emails: list[str], tenant_id: str) -> None:
|
||||
@@ -76,3 +131,187 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
|
||||
UserTenantMapping.tenant_id == tenant_id
|
||||
).delete()
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
return
|
||||
write_pending_users(pending_users + [email])
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def approve_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Approve a user invite to a tenant.
|
||||
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete all existing records for this email
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email
|
||||
).delete()
|
||||
|
||||
# Create a new mapping entry for the user in this tenant
|
||||
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
|
||||
db_session.add(new_mapping)
|
||||
db_session.commit()
|
||||
|
||||
# Also remove the user from pending users list
|
||||
# Remove from pending users
|
||||
pending_users = get_pending_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_pending_users(pending_users)
|
||||
|
||||
# Add to invited users
|
||||
invited_users = get_invited_users()
|
||||
if email not in invited_users:
|
||||
invited_users.append(email)
|
||||
write_invited_users(invited_users)
|
||||
|
||||
|
||||
def accept_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Accept an invitation to join a tenant.
|
||||
This activates the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
try:
|
||||
# First check if there's an active mapping for this user and tenant
|
||||
active_mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
# If an active mapping exists, delete it
|
||||
if active_mapping:
|
||||
db_session.delete(active_mapping)
|
||||
logger.info(
|
||||
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
# Find the inactive mapping for this user and tenant
|
||||
mapping = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if mapping:
|
||||
# Set all other mappings for this user to inactive
|
||||
db_session.query(UserTenantMapping).filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
).update({"active": False})
|
||||
|
||||
# Activate this mapping
|
||||
mapping.active = True
|
||||
db_session.commit()
|
||||
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
db_session.rollback()
|
||||
logger.exception(
|
||||
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
|
||||
)
|
||||
raise
|
||||
|
||||
|
||||
def deny_user_invite(email: str, tenant_id: str) -> None:
|
||||
"""
|
||||
Deny an invitation to join a tenant.
|
||||
This removes the user's mapping to the tenant.
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Delete the mapping for this user and tenant
|
||||
result = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.delete()
|
||||
)
|
||||
|
||||
db_session.commit()
|
||||
if result:
|
||||
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
|
||||
else:
|
||||
logger.warning(
|
||||
f"No invitation found for user {email} in tenant {tenant_id}"
|
||||
)
|
||||
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
|
||||
try:
|
||||
pending_users = get_invited_users()
|
||||
if email in pending_users:
|
||||
pending_users.remove(email)
|
||||
write_invited_users(pending_users)
|
||||
finally:
|
||||
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
|
||||
|
||||
|
||||
def get_tenant_count(tenant_id: str) -> int:
|
||||
"""
|
||||
Get the number of active users for this tenant
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Count the number of active users for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
|
||||
return user_count
|
||||
|
||||
|
||||
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
|
||||
"""
|
||||
Get the first tenant invitation for this user
|
||||
"""
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Get the first tenant invitation for this user
|
||||
invitation = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.email == email,
|
||||
UserTenantMapping.active == False, # noqa: E712
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
if invitation:
|
||||
# Get the user count for this tenant
|
||||
user_count = (
|
||||
db_session.query(UserTenantMapping)
|
||||
.filter(
|
||||
UserTenantMapping.tenant_id == invitation.tenant_id,
|
||||
UserTenantMapping.active == True, # noqa: E712
|
||||
)
|
||||
.count()
|
||||
)
|
||||
return TenantSnapshot(
|
||||
tenant_id=invitation.tenant_id, number_of_users=user_count
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
@@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
|
||||
# Cohere allows up to 96 embeddings in a single embedding calling
|
||||
_COHERE_MAX_INPUT_LEN = 96
|
||||
|
||||
# Authentication error string constants
|
||||
_AUTH_ERROR_401 = "401"
|
||||
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
|
||||
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
|
||||
_AUTH_ERROR_PERMISSION = "permission"
|
||||
|
||||
|
||||
def is_authentication_error(error: Exception) -> bool:
|
||||
"""Check if an exception is related to authentication issues.
|
||||
|
||||
Args:
|
||||
error: The exception to check
|
||||
|
||||
Returns:
|
||||
bool: True if the error appears to be authentication-related
|
||||
"""
|
||||
error_str = str(error).lower()
|
||||
return (
|
||||
_AUTH_ERROR_401 in error_str
|
||||
or _AUTH_ERROR_UNAUTHORIZED in error_str
|
||||
or _AUTH_ERROR_INVALID_API_KEY in error_str
|
||||
or _AUTH_ERROR_PERMISSION in error_str
|
||||
)
|
||||
|
||||
|
||||
def format_embedding_error(
|
||||
error: Exception,
|
||||
service_name: str,
|
||||
model: str | None,
|
||||
provider: EmbeddingProvider,
|
||||
status_code: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Format a standardized error string for embedding errors.
|
||||
"""
|
||||
detail = f"Status {status_code}" if status_code else f"{type(error)}"
|
||||
|
||||
return (
|
||||
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {provider} "
|
||||
f"Exception: {error}"
|
||||
)
|
||||
|
||||
|
||||
# Custom exception for authentication errors
|
||||
class AuthenticationError(Exception):
|
||||
"""Raised when authentication fails with a provider."""
|
||||
|
||||
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
|
||||
self.provider = provider
|
||||
self.message = message
|
||||
super().__init__(f"{provider} authentication failed: {message}")
|
||||
|
||||
|
||||
class CloudEmbedding:
|
||||
def __init__(
|
||||
@@ -92,31 +146,17 @@ class CloudEmbedding:
|
||||
)
|
||||
|
||||
final_embeddings: list[Embedding] = []
|
||||
try:
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
except Exception as e:
|
||||
error_string = (
|
||||
f"Exception embedding text with OpenAI - {type(e)}: "
|
||||
f"Model: {model} "
|
||||
f"Provider: {self.provider} "
|
||||
f"Exception: {e}"
|
||||
|
||||
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
|
||||
response = await client.embeddings.create(
|
||||
input=text_batch,
|
||||
model=model,
|
||||
dimensions=reduced_dimension or openai.NOT_GIVEN,
|
||||
)
|
||||
logger.error(error_string)
|
||||
|
||||
# only log text when it's not an authentication error.
|
||||
if not isinstance(e, openai.AuthenticationError):
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
final_embeddings.extend(
|
||||
[embedding.embedding for embedding in response.data]
|
||||
)
|
||||
return final_embeddings
|
||||
|
||||
async def _embed_cohere(
|
||||
self, texts: list[str], model: str | None, embedding_type: str
|
||||
@@ -155,7 +195,6 @@ class CloudEmbedding:
|
||||
input_type=embedding_type,
|
||||
truncation=True,
|
||||
)
|
||||
|
||||
return response.embeddings
|
||||
|
||||
async def _embed_azure(
|
||||
@@ -239,22 +278,51 @@ class CloudEmbedding:
|
||||
deployment_name: str | None = None,
|
||||
reduced_dimension: int | None = None,
|
||||
) -> list[Embedding]:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
try:
|
||||
if self.provider == EmbeddingProvider.OPENAI:
|
||||
return await self._embed_openai(texts, model_name, reduced_dimension)
|
||||
elif self.provider == EmbeddingProvider.AZURE:
|
||||
return await self._embed_azure(texts, f"azure/{deployment_name}")
|
||||
elif self.provider == EmbeddingProvider.LITELLM:
|
||||
return await self._embed_litellm_proxy(texts, model_name)
|
||||
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
|
||||
if self.provider == EmbeddingProvider.COHERE:
|
||||
return await self._embed_cohere(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.VOYAGE:
|
||||
return await self._embed_voyage(texts, model_name, embedding_type)
|
||||
elif self.provider == EmbeddingProvider.GOOGLE:
|
||||
return await self._embed_vertex(texts, model_name, embedding_type)
|
||||
else:
|
||||
raise ValueError(f"Unsupported provider: {self.provider}")
|
||||
except openai.AuthenticationError:
|
||||
raise AuthenticationError(provider="OpenAI")
|
||||
except httpx.HTTPStatusError as e:
|
||||
if e.response.status_code == 401:
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e,
|
||||
str(self.provider),
|
||||
model_name or deployment_name,
|
||||
self.provider,
|
||||
status_code=e.response.status_code,
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
except Exception as e:
|
||||
if is_authentication_error(e):
|
||||
raise AuthenticationError(provider=str(self.provider))
|
||||
|
||||
error_string = format_embedding_error(
|
||||
e, str(self.provider), model_name or deployment_name, self.provider
|
||||
)
|
||||
logger.error(error_string)
|
||||
logger.debug(f"Exception texts: {texts}")
|
||||
|
||||
raise RuntimeError(error_string)
|
||||
|
||||
@staticmethod
|
||||
def create(
|
||||
@@ -569,6 +637,13 @@ async def process_embed_request(
|
||||
gpu_type=gpu_type,
|
||||
)
|
||||
return EmbedResponse(embeddings=embeddings)
|
||||
except AuthenticationError as e:
|
||||
# Handle authentication errors consistently
|
||||
logger.error(f"Authentication error: {e.provider}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=f"Authentication failed: {e.message}",
|
||||
)
|
||||
except RateLimitError as e:
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
|
||||
@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -92,6 +93,7 @@ def check_sub_answer(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
quality_str: str = cast(str, response.content)
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -119,6 +120,7 @@ def generate_sub_answer(
|
||||
for message in fast_llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
@@ -153,8 +155,9 @@ def generate_initial_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -278,6 +281,9 @@ def generate_initial_answer(
|
||||
for message in model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
|
||||
@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
@@ -141,6 +142,7 @@ def decompose_orig_question(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(0, writer),
|
||||
sep_callback=dispatch_subquestion_sep(0, writer),
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import RefinedAnswerImprovement
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -112,6 +113,7 @@ def compare_answers(
|
||||
model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
|
||||
@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
|
||||
@@ -50,13 +50,7 @@ def decide_refinement_need(
|
||||
)
|
||||
]
|
||||
|
||||
if graph_config.behavior.allow_refinement:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
else:
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=False,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
return RequireRefinemenEvalUpdate(
|
||||
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
|
||||
log_messages=log_messages,
|
||||
)
|
||||
|
||||
@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
@@ -96,6 +97,7 @@ def extract_entities_terms(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
|
||||
)
|
||||
|
||||
cleaned_response = (
|
||||
|
||||
@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
|
||||
from onyx.agents.agent_search.shared_graph_utils.operators import (
|
||||
dedup_inference_section_list,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
dispatch_main_answer_stop_info,
|
||||
)
|
||||
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
|
||||
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
|
||||
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
@@ -179,8 +182,9 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
@@ -302,7 +306,11 @@ def generate_validate_refined_answer(
|
||||
|
||||
def stream_refined_answer() -> list[str]:
|
||||
for message in model.stream(
|
||||
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
|
||||
if _should_restrict_tokens(model.config)
|
||||
else None,
|
||||
):
|
||||
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
|
||||
content = message.content
|
||||
@@ -409,6 +417,7 @@ def generate_validate_refined_answer(
|
||||
validation_model.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
refined_answer_quality = binary_string_test_after_answer_separator(
|
||||
text=cast(str, validation_response.content),
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
|
||||
if result.query_info is not None:
|
||||
query_info = result.query_info
|
||||
break
|
||||
return query_info or SearchQueryInfo(
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=1.0,
|
||||
)
|
||||
|
||||
assert query_info is not None, "must have query info"
|
||||
return query_info
|
||||
|
||||
@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
@@ -96,6 +97,7 @@ def expand_queries(
|
||||
model.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
|
||||
@@ -56,8 +56,9 @@ def format_results(
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents,
|
||||
final_context_sections=reranked_documents,
|
||||
get_retrieved_sections=lambda: reranked_documents,
|
||||
get_reranked_sections=lambda: state.retrieved_documents,
|
||||
get_final_context_sections=lambda: reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -91,7 +91,7 @@ def retrieve_documents(
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
pre_rerank_docs = callback_container[0]
|
||||
pre_rerank_docs = callback_container[0] if callback_container else []
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
|
||||
@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
get_langgraph_node_log_string,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
@@ -93,6 +94,7 @@ def verify_documents(
|
||||
fast_llm.invoke,
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
|
||||
)
|
||||
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
@@ -44,7 +44,9 @@ def call_tool(
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
emit_packet(tool_kickoff, writer)
|
||||
|
||||
@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -25,6 +34,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
|
||||
def choose_tool(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
@@ -37,6 +47,31 @@ def choose_tool(
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
embedding_thread: TimeoutThread[Embedding] | None = None
|
||||
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
|
||||
override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
if (
|
||||
not agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and (
|
||||
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
|
||||
)
|
||||
):
|
||||
override_kwargs = SearchToolOverrideKwargs()
|
||||
# Run in a background thread to avoid blocking the main thread
|
||||
embedding_thread = run_in_background(
|
||||
get_query_embedding,
|
||||
agent_config.inputs.search_request.query,
|
||||
agent_config.persistence.db_session,
|
||||
)
|
||||
keyword_thread = run_in_background(
|
||||
query_analysis,
|
||||
agent_config.inputs.search_request.query,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
@@ -47,7 +82,6 @@ def choose_tool(
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
@@ -71,11 +105,22 @@ def choose_tool(
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
if embedding_thread and tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -153,10 +198,22 @@ def choose_tool(
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
if embedding_thread and selected_tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and selected_tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
context_from_inference_section,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
@@ -50,11 +55,13 @@ def basic_use_tool_response(
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = cast(OnyxContexts, yield_item.response).contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, yield_item.response)
|
||||
for section in search_response_summary.top_sections:
|
||||
if section.center_chunk.document_id not in initial_search_results:
|
||||
initial_search_results.append(
|
||||
context_from_inference_section(section)
|
||||
)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
|
||||
tool: Tool
|
||||
tool_args: dict
|
||||
id: str | None
|
||||
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
EMBEDDING_KEY = "embedding"
|
||||
IS_KEYWORD_KEY = "is_keyword"
|
||||
KEYWORDS_KEY = "keywords"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
|
||||
@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
|
||||
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
)
|
||||
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.interfaces import LLMConfig
|
||||
from onyx.prompts.agent_search import (
|
||||
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
|
||||
)
|
||||
@@ -402,6 +404,7 @@ def summarize_history(
|
||||
llm.invoke,
|
||||
history_context_prompt,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
|
||||
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
logger.error("LLM Timeout Error - summarize history")
|
||||
@@ -505,3 +508,9 @@ def get_deduplicated_structured_subquestion_documents(
|
||||
cited_documents=dedup_inference_section_list(cited_docs),
|
||||
context_documents=dedup_inference_section_list(context_docs),
|
||||
)
|
||||
|
||||
|
||||
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
|
||||
return not (
|
||||
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
|
||||
)
|
||||
|
||||
@@ -153,7 +153,8 @@ def send_email(
|
||||
msg = MIMEMultipart("alternative")
|
||||
msg["Subject"] = subject
|
||||
msg["To"] = user_email
|
||||
msg["From"] = mail_from
|
||||
if mail_from:
|
||||
msg["From"] = mail_from
|
||||
msg["Date"] = formatdate(localtime=True)
|
||||
msg["Message-ID"] = make_msgid(domain="onyx.app")
|
||||
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import cast
|
||||
|
||||
from onyx.configs.constants import KV_PENDING_USERS_KEY
|
||||
from onyx.configs.constants import KV_USER_STORE_KEY
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.key_value_store.interface import KvKeyNotFoundError
|
||||
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
|
||||
def get_pending_users() -> list[str]:
|
||||
try:
|
||||
store = get_kv_store()
|
||||
return cast(list, store.load(KV_PENDING_USERS_KEY))
|
||||
except KvKeyNotFoundError:
|
||||
return list()
|
||||
|
||||
|
||||
def write_pending_users(emails: list[str]) -> int:
|
||||
store = get_kv_store()
|
||||
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
|
||||
return len(emails)
|
||||
|
||||
211
backend/onyx/auth/oauth_refresher.py
Normal file
211
backend/onyx/auth/oauth_refresher.py
Normal file
@@ -0,0 +1,211 @@
|
||||
from datetime import datetime
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
|
||||
import httpx
|
||||
from fastapi_users.manager import BaseUserManager
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_ID
|
||||
from onyx.configs.app_configs import OAUTH_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import TRACK_EXTERNAL_IDP_EXPIRY
|
||||
from onyx.db.models import OAuthAccount
|
||||
from onyx.db.models import User
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
# Standard OAuth refresh token endpoints
|
||||
REFRESH_ENDPOINTS = {
|
||||
"google": "https://oauth2.googleapis.com/token",
|
||||
}
|
||||
|
||||
|
||||
# NOTE: Keeping this as a utility function for potential future debugging,
|
||||
# but not using it in production code
|
||||
async def _test_expire_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
expire_in_seconds: int = 10,
|
||||
) -> bool:
|
||||
"""
|
||||
Utility function for testing - Sets an OAuth token to expire in a short time
|
||||
to facilitate testing of the refresh flow.
|
||||
Not used in production code.
|
||||
"""
|
||||
try:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expire_in_seconds)
|
||||
)
|
||||
|
||||
updated_data: Dict[str, Any] = {"expires_at": new_expires_at}
|
||||
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.exception(f"Error setting artificial expiration: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def refresh_oauth_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> bool:
|
||||
"""
|
||||
Attempt to refresh an OAuth token that's about to expire or has expired.
|
||||
Returns True if successful, False otherwise.
|
||||
"""
|
||||
if not oauth_account.refresh_token:
|
||||
logger.warning(
|
||||
f"No refresh token available for {user.email}'s {oauth_account.oauth_name} account"
|
||||
)
|
||||
return False
|
||||
|
||||
provider = oauth_account.oauth_name
|
||||
if provider not in REFRESH_ENDPOINTS:
|
||||
logger.warning(f"Refresh endpoint not configured for provider: {provider}")
|
||||
return False
|
||||
|
||||
try:
|
||||
logger.info(f"Refreshing OAuth token for {user.email}'s {provider} account")
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
response = await client.post(
|
||||
REFRESH_ENDPOINTS[provider],
|
||||
data={
|
||||
"client_id": OAUTH_CLIENT_ID,
|
||||
"client_secret": OAUTH_CLIENT_SECRET,
|
||||
"refresh_token": oauth_account.refresh_token,
|
||||
"grant_type": "refresh_token",
|
||||
},
|
||||
headers={"Content-Type": "application/x-www-form-urlencoded"},
|
||||
)
|
||||
|
||||
if response.status_code != 200:
|
||||
logger.error(
|
||||
f"Failed to refresh OAuth token: Status {response.status_code}"
|
||||
)
|
||||
return False
|
||||
|
||||
token_data = response.json()
|
||||
|
||||
new_access_token = token_data.get("access_token")
|
||||
new_refresh_token = token_data.get(
|
||||
"refresh_token", oauth_account.refresh_token
|
||||
)
|
||||
expires_in = token_data.get("expires_in")
|
||||
|
||||
# Calculate new expiry time if provided
|
||||
new_expires_at: Optional[int] = None
|
||||
if expires_in:
|
||||
new_expires_at = int(
|
||||
(datetime.now(timezone.utc).timestamp() + expires_in)
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
updated_data: Dict[str, Any] = {
|
||||
"access_token": new_access_token,
|
||||
"refresh_token": new_refresh_token,
|
||||
}
|
||||
|
||||
if new_expires_at:
|
||||
updated_data["expires_at"] = new_expires_at
|
||||
|
||||
# Update oidc_expiry in user model if we're tracking it
|
||||
if TRACK_EXTERNAL_IDP_EXPIRY:
|
||||
oidc_expiry = datetime.fromtimestamp(
|
||||
new_expires_at, tz=timezone.utc
|
||||
)
|
||||
await user_manager.user_db.update(
|
||||
user, {"oidc_expiry": oidc_expiry}
|
||||
)
|
||||
|
||||
# Update the OAuth account
|
||||
await user_manager.user_db.update_oauth_account(
|
||||
user, cast(Any, oauth_account), updated_data
|
||||
)
|
||||
|
||||
logger.info(f"Successfully refreshed OAuth token for {user.email}")
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
logger.exception(f"Error refreshing OAuth token: {str(e)}")
|
||||
return False
|
||||
|
||||
|
||||
async def check_and_refresh_oauth_tokens(
|
||||
user: User,
|
||||
db_session: AsyncSession,
|
||||
user_manager: BaseUserManager[User, Any],
|
||||
) -> None:
|
||||
"""
|
||||
Check if any OAuth tokens are expired or about to expire and refresh them.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return
|
||||
|
||||
now_timestamp = datetime.now(timezone.utc).timestamp()
|
||||
|
||||
# Buffer time to refresh tokens before they expire (in seconds)
|
||||
buffer_seconds = 300 # 5 minutes
|
||||
|
||||
for oauth_account in user.oauth_accounts:
|
||||
# Skip accounts without refresh tokens
|
||||
if not oauth_account.refresh_token:
|
||||
continue
|
||||
|
||||
# If token is about to expire, refresh it
|
||||
if (
|
||||
oauth_account.expires_at
|
||||
and oauth_account.expires_at - now_timestamp < buffer_seconds
|
||||
):
|
||||
logger.info(f"OAuth token for {user.email} is about to expire - refreshing")
|
||||
success = await refresh_oauth_token(
|
||||
user, oauth_account, db_session, user_manager
|
||||
)
|
||||
|
||||
if not success:
|
||||
logger.warning(
|
||||
"Failed to refresh OAuth token. User may need to re-authenticate."
|
||||
)
|
||||
|
||||
|
||||
async def check_oauth_account_has_refresh_token(
|
||||
user: User,
|
||||
oauth_account: OAuthAccount,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an OAuth account has a refresh token.
|
||||
Returns True if a refresh token exists, False otherwise.
|
||||
"""
|
||||
return bool(oauth_account.refresh_token)
|
||||
|
||||
|
||||
async def get_oauth_accounts_requiring_refresh_token(user: User) -> List[OAuthAccount]:
|
||||
"""
|
||||
Returns a list of OAuth accounts for a user that are missing refresh tokens.
|
||||
These accounts will need re-authentication to get refresh tokens.
|
||||
"""
|
||||
if not hasattr(user, "oauth_accounts") or not user.oauth_accounts:
|
||||
return []
|
||||
|
||||
accounts_needing_refresh = []
|
||||
for oauth_account in user.oauth_accounts:
|
||||
has_refresh_token = await check_oauth_account_has_refresh_token(
|
||||
user, oauth_account
|
||||
)
|
||||
if not has_refresh_token:
|
||||
accounts_needing_refresh.append(oauth_account)
|
||||
|
||||
return accounts_needing_refresh
|
||||
@@ -5,12 +5,16 @@ import string
|
||||
import uuid
|
||||
from collections.abc import AsyncGenerator
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import List
|
||||
from typing import Optional
|
||||
from typing import Protocol
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -100,6 +104,7 @@ from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.telemetry import create_milestone_and_report
|
||||
from onyx.utils.telemetry import optional_telemetry
|
||||
from onyx.utils.telemetry import RecordType
|
||||
from onyx.utils.url import add_url_params
|
||||
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
|
||||
from onyx.utils.variable_functionality import fetch_versioned_implementation
|
||||
from shared_configs.configs import async_return_default_schema
|
||||
@@ -686,16 +691,20 @@ cookie_transport = CookieTransport(
|
||||
)
|
||||
|
||||
|
||||
def get_redis_strategy() -> RedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
T = TypeVar("T", covariant=True)
|
||||
ID = TypeVar("ID", contravariant=True)
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> DatabaseStrategy:
|
||||
return DatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
# Protocol for strategies that support token refreshing without inheritance.
|
||||
class RefreshableStrategy(Protocol):
|
||||
"""Protocol for authentication strategies that support token refreshing."""
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: Any) -> str:
|
||||
"""
|
||||
Refresh an existing token by extending its lifetime.
|
||||
Returns either the same token with extended expiration or a new token.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
@@ -754,6 +763,75 @@ class TenantAwareRedisStrategy(RedisStrategy[User, uuid.UUID]):
|
||||
redis = await get_async_redis_connection()
|
||||
await redis.delete(f"{self.key_prefix}{token}")
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by extending its expiration time in Redis."""
|
||||
if token is None:
|
||||
# If no token provided, create a new one
|
||||
return await self.write_token(user)
|
||||
|
||||
redis = await get_async_redis_connection()
|
||||
token_key = f"{self.key_prefix}{token}"
|
||||
|
||||
# Check if token exists
|
||||
token_data_str = await redis.get(token_key)
|
||||
if not token_data_str:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Token exists, extend its lifetime
|
||||
token_data = json.loads(token_data_str)
|
||||
await redis.set(
|
||||
token_key,
|
||||
json.dumps(token_data),
|
||||
ex=self.lifetime_seconds,
|
||||
)
|
||||
|
||||
return token
|
||||
|
||||
|
||||
class RefreshableDatabaseStrategy(DatabaseStrategy[User, uuid.UUID, AccessToken]):
|
||||
"""Database strategy with token refreshing capabilities."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
access_token_db: AccessTokenDatabase[AccessToken],
|
||||
lifetime_seconds: Optional[int] = None,
|
||||
):
|
||||
super().__init__(access_token_db, lifetime_seconds)
|
||||
self._access_token_db = access_token_db
|
||||
|
||||
async def refresh_token(self, token: Optional[str], user: User) -> str:
|
||||
"""Refresh a token by updating its expiration time in the database."""
|
||||
if token is None:
|
||||
return await self.write_token(user)
|
||||
|
||||
# Find the token in database
|
||||
access_token = await self._access_token_db.get_by_token(token)
|
||||
|
||||
if access_token is None:
|
||||
# Token not found, create new one
|
||||
return await self.write_token(user)
|
||||
|
||||
# Update expiration time
|
||||
new_expires = datetime.now(timezone.utc) + timedelta(
|
||||
seconds=float(self.lifetime_seconds or SESSION_EXPIRE_TIME_SECONDS)
|
||||
)
|
||||
await self._access_token_db.update(access_token, {"expires": new_expires})
|
||||
|
||||
return token
|
||||
|
||||
|
||||
def get_redis_strategy() -> TenantAwareRedisStrategy:
|
||||
return TenantAwareRedisStrategy()
|
||||
|
||||
|
||||
def get_database_strategy(
|
||||
access_token_db: AccessTokenDatabase[AccessToken] = Depends(get_access_token_db),
|
||||
) -> RefreshableDatabaseStrategy:
|
||||
return RefreshableDatabaseStrategy(
|
||||
access_token_db, lifetime_seconds=SESSION_EXPIRE_TIME_SECONDS
|
||||
)
|
||||
|
||||
|
||||
if AUTH_BACKEND == AuthBackend.REDIS:
|
||||
auth_backend = AuthenticationBackend(
|
||||
@@ -804,6 +882,88 @@ class FastAPIUserWithLogoutRouter(FastAPIUsers[models.UP, models.ID]):
|
||||
|
||||
return router
|
||||
|
||||
def get_refresh_router(
|
||||
self,
|
||||
backend: AuthenticationBackend,
|
||||
requires_verification: bool = REQUIRE_EMAIL_VERIFICATION,
|
||||
) -> APIRouter:
|
||||
"""
|
||||
Provide a router for session token refreshing.
|
||||
"""
|
||||
# Import the oauth_refresher here to avoid circular imports
|
||||
from onyx.auth.oauth_refresher import check_and_refresh_oauth_tokens
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
get_current_user_token = self.authenticator.current_user_token(
|
||||
active=True, verified=requires_verification
|
||||
)
|
||||
|
||||
refresh_responses: OpenAPIResponseType = {
|
||||
**{
|
||||
status.HTTP_401_UNAUTHORIZED: {
|
||||
"description": "Missing token or inactive user."
|
||||
}
|
||||
},
|
||||
**backend.transport.get_openapi_login_responses_success(),
|
||||
}
|
||||
|
||||
@router.post(
|
||||
"/refresh", name=f"auth:{backend.name}.refresh", responses=refresh_responses
|
||||
)
|
||||
async def refresh(
|
||||
user_token: Tuple[models.UP, str] = Depends(get_current_user_token),
|
||||
strategy: Strategy[models.UP, models.ID] = Depends(backend.get_strategy),
|
||||
user_manager: BaseUserManager[models.UP, models.ID] = Depends(
|
||||
get_user_manager
|
||||
),
|
||||
db_session: AsyncSession = Depends(get_async_session),
|
||||
) -> Response:
|
||||
try:
|
||||
user, token = user_token
|
||||
logger.info(f"Processing token refresh request for user {user.email}")
|
||||
|
||||
# Check if user has OAuth accounts that need refreshing
|
||||
await check_and_refresh_oauth_tokens(
|
||||
user=cast(User, user),
|
||||
db_session=db_session,
|
||||
user_manager=cast(Any, user_manager),
|
||||
)
|
||||
|
||||
# Check if strategy supports refreshing
|
||||
supports_refresh = hasattr(strategy, "refresh_token") and callable(
|
||||
getattr(strategy, "refresh_token")
|
||||
)
|
||||
|
||||
if supports_refresh:
|
||||
try:
|
||||
refresh_method = getattr(strategy, "refresh_token")
|
||||
new_token = await refresh_method(token, user)
|
||||
logger.info(
|
||||
f"Successfully refreshed session token for user {user.email}"
|
||||
)
|
||||
return await backend.transport.get_login_response(new_token)
|
||||
except Exception as e:
|
||||
logger.error(f"Error refreshing session token: {str(e)}")
|
||||
# Fallback to logout and login if refresh fails
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
|
||||
# Fallback: logout and login again
|
||||
logger.info(
|
||||
"Strategy doesn't support refresh - using logout/login flow"
|
||||
)
|
||||
await backend.logout(strategy, user, token)
|
||||
return await backend.login(strategy, user)
|
||||
except Exception as e:
|
||||
logger.error(f"Unexpected error in refresh endpoint: {str(e)}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Token refresh failed: {str(e)}",
|
||||
)
|
||||
|
||||
return router
|
||||
|
||||
|
||||
fastapi_users = FastAPIUserWithLogoutRouter[User, uuid.UUID](
|
||||
get_user_manager, [auth_backend]
|
||||
@@ -894,7 +1054,7 @@ async def current_limited_user(
|
||||
return await double_check_user(user)
|
||||
|
||||
|
||||
async def current_chat_accesssible_user(
|
||||
async def current_chat_accessible_user(
|
||||
user: User | None = Depends(optional_user),
|
||||
) -> User | None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
@@ -1037,12 +1197,20 @@ def get_oauth_router(
|
||||
"referral_source": referral_source or "default_referral",
|
||||
}
|
||||
state = generate_state_token(state_data, state_secret)
|
||||
|
||||
# Get the basic authorization URL
|
||||
authorization_url = await oauth_client.get_authorization_url(
|
||||
authorize_redirect_url,
|
||||
state,
|
||||
scopes,
|
||||
)
|
||||
|
||||
# For Google OAuth, add parameters to request refresh tokens
|
||||
if oauth_client.name == "google":
|
||||
authorization_url = add_url_params(
|
||||
authorization_url, {"access_type": "offline", "prompt": "consent"}
|
||||
)
|
||||
|
||||
return OAuth2AuthorizeResponse(authorization_url=authorization_url)
|
||||
|
||||
@router.get(
|
||||
@@ -1095,6 +1263,12 @@ def get_oauth_router(
|
||||
|
||||
next_url = state_data.get("next_url", "/")
|
||||
referral_source = state_data.get("referral_source", None)
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(account_email)
|
||||
except exceptions.UserNotExists:
|
||||
tenant_id = None
|
||||
|
||||
request.state.referral_source = referral_source
|
||||
|
||||
@@ -1126,9 +1300,14 @@ def get_oauth_router(
|
||||
# Login user
|
||||
response = await backend.login(strategy, user)
|
||||
await user_manager.on_after_login(user, request, response)
|
||||
|
||||
# Prepare redirect response
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
if tenant_id is None:
|
||||
# Use URL utility to add parameters
|
||||
redirect_url = add_url_params(next_url, {"new_team": "true"})
|
||||
redirect_response = RedirectResponse(redirect_url, status_code=302)
|
||||
else:
|
||||
# No parameters to add
|
||||
redirect_response = RedirectResponse(next_url, status_code=302)
|
||||
|
||||
# Copy headers and other attributes from 'response' to 'redirect_response'
|
||||
for header_name, header_value in response.headers.items():
|
||||
@@ -1140,6 +1319,7 @@ def get_oauth_router(
|
||||
redirect_response.status_code = response.status_code
|
||||
if hasattr(response, "media_type"):
|
||||
redirect_response.media_type = response.media_type
|
||||
|
||||
return redirect_response
|
||||
|
||||
return router
|
||||
|
||||
@@ -111,5 +111,7 @@ celery_app.autodiscover_tasks(
|
||||
"onyx.background.celery.tasks.vespa",
|
||||
"onyx.background.celery.tasks.connector_deletion",
|
||||
"onyx.background.celery.tasks.doc_permission_syncing",
|
||||
"onyx.background.celery.tasks.indexing",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -92,5 +92,6 @@ def on_setup_logging(
|
||||
celery_app.autodiscover_tasks(
|
||||
[
|
||||
"onyx.background.celery.tasks.monitoring",
|
||||
"onyx.background.celery.tasks.tenant_provisioning",
|
||||
]
|
||||
)
|
||||
|
||||
73
backend/onyx/background/celery/memory_monitoring.py
Normal file
73
backend/onyx/background/celery/memory_monitoring.py
Normal file
@@ -0,0 +1,73 @@
|
||||
# backend/onyx/background/celery/memory_monitoring.py
|
||||
import logging
|
||||
import os
|
||||
from logging.handlers import RotatingFileHandler
|
||||
|
||||
import psutil
|
||||
|
||||
from onyx.utils.logger import is_running_in_container
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
# Regular application logger
|
||||
logger = setup_logger()
|
||||
|
||||
# Only set up memory monitoring in container environment
|
||||
if is_running_in_container():
|
||||
# Set up a dedicated memory monitoring logger
|
||||
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
|
||||
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
|
||||
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
|
||||
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
|
||||
|
||||
# Ensure log directory exists
|
||||
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
|
||||
|
||||
# Create a dedicated logger for memory monitoring
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.setLevel(logging.INFO)
|
||||
|
||||
# Create a rotating file handler
|
||||
memory_handler = RotatingFileHandler(
|
||||
MEMORY_LOG_FILE,
|
||||
maxBytes=MEMORY_LOG_MAX_BYTES,
|
||||
backupCount=MEMORY_LOG_BACKUP_COUNT,
|
||||
)
|
||||
|
||||
# Create a formatter that includes all relevant information
|
||||
memory_formatter = logging.Formatter(
|
||||
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
|
||||
)
|
||||
memory_handler.setFormatter(memory_formatter)
|
||||
memory_logger.addHandler(memory_handler)
|
||||
else:
|
||||
# Create a null logger when not in container
|
||||
memory_logger = logging.getLogger("memory_monitoring")
|
||||
memory_logger.addHandler(logging.NullHandler())
|
||||
|
||||
|
||||
def emit_process_memory(
|
||||
pid: int, process_name: str, additional_metadata: dict[str, str | int]
|
||||
) -> None:
|
||||
# Skip memory monitoring if not in container
|
||||
if not is_running_in_container():
|
||||
return
|
||||
|
||||
try:
|
||||
process = psutil.Process(pid)
|
||||
memory_info = process.memory_info()
|
||||
cpu_percent = process.cpu_percent(interval=0.1)
|
||||
|
||||
# Build metadata string from additional_metadata dictionary
|
||||
metadata_str = " ".join(
|
||||
[f"{key}={value}" for key, value in additional_metadata.items()]
|
||||
)
|
||||
metadata_str = f" {metadata_str}" if metadata_str else ""
|
||||
|
||||
memory_logger.info(
|
||||
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
|
||||
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
|
||||
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
|
||||
f"cpu={cpu_percent:.2f}{metadata_str}"
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("Error monitoring process memory.")
|
||||
@@ -167,6 +167,16 @@ beat_cloud_tasks: list[dict] = [
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
{
|
||||
"name": f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check-available-tenants",
|
||||
"task": OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
"schedule": timedelta(minutes=10),
|
||||
"options": {
|
||||
"queue": OnyxCeleryQueues.MONITORING,
|
||||
"priority": OnyxCeleryPriority.HIGH,
|
||||
"expires": BEAT_EXPIRES_DEFAULT,
|
||||
},
|
||||
},
|
||||
]
|
||||
|
||||
# tasks that only run self hosted
|
||||
|
||||
@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
|
||||
from onyx.background.celery.memory_monitoring import emit_process_memory
|
||||
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
|
||||
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
|
||||
from onyx.background.celery.tasks.indexing.utils import should_index
|
||||
@@ -984,6 +985,9 @@ def connector_indexing_proxy_task(
|
||||
redis_connector = RedisConnector(tenant_id, cc_pair_id)
|
||||
redis_connector_index = redis_connector.new_index(search_settings_id)
|
||||
|
||||
# Track the last time memory info was emitted
|
||||
last_memory_emit_time = 0.0
|
||||
|
||||
try:
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
index_attempt = get_index_attempt(
|
||||
@@ -1024,6 +1028,23 @@ def connector_indexing_proxy_task(
|
||||
job.release()
|
||||
break
|
||||
|
||||
# log the memory usage for tracking down memory leaks / connector-specific memory issues
|
||||
pid = job.process.pid
|
||||
if pid is not None:
|
||||
# Only emit memory info once per minute (60 seconds)
|
||||
current_time = time.monotonic()
|
||||
if current_time - last_memory_emit_time >= 60.0:
|
||||
emit_process_memory(
|
||||
pid,
|
||||
"indexing_worker",
|
||||
{
|
||||
"cc_pair_id": cc_pair_id,
|
||||
"search_settings_id": search_settings_id,
|
||||
"index_attempt_id": index_attempt_id,
|
||||
},
|
||||
)
|
||||
last_memory_emit_time = current_time
|
||||
|
||||
# if a termination signal is detected, break (exit point will clean up)
|
||||
if self.request.id and redis_connector_index.terminating(self.request.id):
|
||||
task_logger.warning(
|
||||
@@ -1170,6 +1191,7 @@ def connector_indexing_proxy_task(
|
||||
return
|
||||
|
||||
|
||||
# primary
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
|
||||
soft_time_limit=300,
|
||||
@@ -1217,6 +1239,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
|
||||
)
|
||||
|
||||
|
||||
# light worker
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
|
||||
bind=True,
|
||||
|
||||
@@ -0,0 +1,199 @@
|
||||
"""
|
||||
Periodic tasks for tenant pre-provisioning.
|
||||
"""
|
||||
import asyncio
|
||||
import datetime
|
||||
import uuid
|
||||
|
||||
from celery import shared_task
|
||||
from celery import Task
|
||||
from redis.lock import Lock as RedisLock
|
||||
|
||||
from ee.onyx.server.tenants.provisioning import setup_tenant
|
||||
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
|
||||
from ee.onyx.server.tenants.schema_management import get_current_alembic_version
|
||||
from onyx.background.celery.apps.app_base import task_logger
|
||||
from onyx.configs.app_configs import JOB_TIMEOUT
|
||||
from onyx.configs.app_configs import TARGET_AVAILABLE_TENANTS
|
||||
from onyx.configs.constants import OnyxCeleryPriority
|
||||
from onyx.configs.constants import OnyxCeleryQueues
|
||||
from onyx.configs.constants import OnyxCeleryTask
|
||||
from onyx.configs.constants import OnyxRedisLocks
|
||||
from onyx.db.engine import get_session_with_shared_schema
|
||||
from onyx.db.models import AvailableTenant
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from shared_configs.configs import MULTI_TENANT
|
||||
from shared_configs.configs import TENANT_ID_PREFIX
|
||||
|
||||
# Default number of pre-provisioned tenants to maintain
|
||||
DEFAULT_TARGET_AVAILABLE_TENANTS = 5
|
||||
|
||||
# Soft time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_SOFT_TIME_LIMIT = 60 * 5 # 5 minutes
|
||||
# Hard time limit for tenant pre-provisioning tasks (in seconds)
|
||||
_TENANT_PROVISIONING_TIME_LIMIT = 60 * 10 # 10 minutes
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.CHECK_AVAILABLE_TENANTS,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
ignore_result=True,
|
||||
soft_time_limit=JOB_TIMEOUT,
|
||||
trail=False,
|
||||
bind=True,
|
||||
)
|
||||
def check_available_tenants(self: Task) -> None:
|
||||
"""
|
||||
Check if we have enough pre-provisioned tenants available.
|
||||
If not, trigger the pre-provisioning of new tenants.
|
||||
"""
|
||||
task_logger.info("STARTING CHECK_AVAILABLE_TENANTS")
|
||||
if not MULTI_TENANT:
|
||||
task_logger.info(
|
||||
"Multi-tenancy is not enabled, skipping tenant pre-provisioning"
|
||||
)
|
||||
return
|
||||
|
||||
r = get_redis_client()
|
||||
lock_check: RedisLock = r.lock(
|
||||
OnyxRedisLocks.CHECK_AVAILABLE_TENANTS_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# These tasks should never overlap
|
||||
if not lock_check.acquire(blocking=False):
|
||||
task_logger.info(
|
||||
"Skipping check_available_tenants task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
try:
|
||||
# Get the current count of available tenants
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
available_tenants_count = db_session.query(AvailableTenant).count()
|
||||
|
||||
# Get the target number of available tenants
|
||||
target_available_tenants = getattr(
|
||||
TARGET_AVAILABLE_TENANTS, "value", DEFAULT_TARGET_AVAILABLE_TENANTS
|
||||
)
|
||||
|
||||
# Calculate how many new tenants we need to provision
|
||||
tenants_to_provision = max(
|
||||
0, target_available_tenants - available_tenants_count
|
||||
)
|
||||
|
||||
task_logger.info(
|
||||
f"Available tenants: {available_tenants_count}, "
|
||||
f"Target: {target_available_tenants}, "
|
||||
f"To provision: {tenants_to_provision}"
|
||||
)
|
||||
|
||||
# Trigger pre-provisioning tasks for each tenant needed
|
||||
for _ in range(tenants_to_provision):
|
||||
from celery import current_app
|
||||
|
||||
current_app.send_task(
|
||||
OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
priority=OnyxCeleryPriority.LOW,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
task_logger.exception("Error in check_available_tenants task")
|
||||
|
||||
finally:
|
||||
lock_check.release()
|
||||
|
||||
|
||||
@shared_task(
|
||||
name=OnyxCeleryTask.PRE_PROVISION_TENANT,
|
||||
ignore_result=True,
|
||||
soft_time_limit=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
time_limit=_TENANT_PROVISIONING_TIME_LIMIT,
|
||||
queue=OnyxCeleryQueues.MONITORING,
|
||||
bind=True,
|
||||
)
|
||||
def pre_provision_tenant(self: Task) -> None:
|
||||
"""
|
||||
Pre-provision a new tenant and store it in the NewAvailableTenant table.
|
||||
This function fully sets up the tenant with all necessary configurations,
|
||||
so it's ready to be assigned to a user immediately.
|
||||
"""
|
||||
# The MULTI_TENANT check is now done at the caller level (check_available_tenants)
|
||||
# rather than inside this function
|
||||
|
||||
r = get_redis_client()
|
||||
lock_provision: RedisLock = r.lock(
|
||||
OnyxRedisLocks.PRE_PROVISION_TENANT_LOCK,
|
||||
timeout=_TENANT_PROVISIONING_SOFT_TIME_LIMIT,
|
||||
)
|
||||
|
||||
# Allow multiple pre-provisioning tasks to run, but ensure they don't overlap
|
||||
if not lock_provision.acquire(blocking=False):
|
||||
task_logger.debug(
|
||||
"Skipping pre_provision_tenant task because it is already running"
|
||||
)
|
||||
return
|
||||
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
# Generate a new tenant ID
|
||||
tenant_id = TENANT_ID_PREFIX + str(uuid.uuid4())
|
||||
task_logger.info(f"Pre-provisioning tenant: {tenant_id}")
|
||||
|
||||
# Create the schema for the new tenant
|
||||
schema_created = create_schema_if_not_exists(tenant_id)
|
||||
if schema_created:
|
||||
task_logger.debug(f"Created schema for tenant: {tenant_id}")
|
||||
else:
|
||||
task_logger.debug(f"Schema already exists for tenant: {tenant_id}")
|
||||
|
||||
# Set up the tenant with all necessary configurations
|
||||
task_logger.debug(f"Setting up tenant configuration: {tenant_id}")
|
||||
asyncio.run(setup_tenant(tenant_id))
|
||||
task_logger.debug(f"Tenant configuration completed: {tenant_id}")
|
||||
|
||||
# Get the current Alembic version
|
||||
alembic_version = get_current_alembic_version(tenant_id)
|
||||
task_logger.debug(
|
||||
f"Tenant {tenant_id} using Alembic version: {alembic_version}"
|
||||
)
|
||||
|
||||
# Store the pre-provisioned tenant in the database
|
||||
task_logger.debug(f"Storing pre-provisioned tenant in database: {tenant_id}")
|
||||
with get_session_with_shared_schema() as db_session:
|
||||
# Use a transaction to ensure atomicity
|
||||
db_session.begin()
|
||||
try:
|
||||
new_tenant = AvailableTenant(
|
||||
tenant_id=tenant_id,
|
||||
alembic_version=alembic_version,
|
||||
date_created=datetime.datetime.now(),
|
||||
)
|
||||
db_session.add(new_tenant)
|
||||
db_session.commit()
|
||||
task_logger.info(f"Successfully pre-provisioned tenant: {tenant_id}")
|
||||
except Exception:
|
||||
db_session.rollback()
|
||||
task_logger.error(
|
||||
f"Failed to store pre-provisioned tenant: {tenant_id}",
|
||||
exc_info=True,
|
||||
)
|
||||
raise
|
||||
|
||||
except Exception:
|
||||
task_logger.error("Error in pre_provision_tenant task", exc_info=True)
|
||||
# If we have a tenant_id, attempt to rollback any partially completed provisioning
|
||||
if tenant_id:
|
||||
task_logger.info(
|
||||
f"Rolling back failed tenant provisioning for: {tenant_id}"
|
||||
)
|
||||
try:
|
||||
from ee.onyx.server.tenants.provisioning import (
|
||||
rollback_tenant_provisioning,
|
||||
)
|
||||
|
||||
asyncio.run(rollback_tenant_provisioning(tenant_id))
|
||||
except Exception:
|
||||
task_logger.exception(f"Error during rollback for tenant: {tenant_id}")
|
||||
finally:
|
||||
lock_provision.release()
|
||||
@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
# This is Legacy code that is not used anymore.
|
||||
# It is kept here for reference.
|
||||
class LLMResponseHandlerManager:
|
||||
"""
|
||||
This class is responsible for postprocessing the LLM response stream.
|
||||
|
||||
@@ -1,10 +1,13 @@
|
||||
from collections import OrderedDict
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Iterator
|
||||
from collections.abc import Mapping
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
from typing import Literal
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import ConfigDict
|
||||
@@ -44,9 +47,44 @@ class LlmDoc(BaseModel):
|
||||
|
||||
|
||||
class SubQuestionIdentifier(BaseModel):
|
||||
"""None represents references to objects in the original flow. To our understanding,
|
||||
these will not be None in the packets returned from agent search.
|
||||
"""
|
||||
|
||||
level: int | None = None
|
||||
level_question_num: int | None = None
|
||||
|
||||
@staticmethod
|
||||
def make_dict_by_level(
|
||||
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
|
||||
) -> dict[int, list["SubQuestionIdentifier"]]:
|
||||
"""returns a dict of level to object list (sorted by level_question_num)
|
||||
Ordering is asc for readability.
|
||||
"""
|
||||
|
||||
# organize by level, then sort ascending by question_index
|
||||
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
|
||||
|
||||
# group by level
|
||||
for k, obj in original_dict.items():
|
||||
level = k[0]
|
||||
if level not in level_dict:
|
||||
level_dict[level] = []
|
||||
level_dict[level].append(obj)
|
||||
|
||||
# for each level, sort the group
|
||||
for k2, value2 in level_dict.items():
|
||||
# we need to handle the none case due to SubQuestionIdentifier typing
|
||||
# level_question_num as int | None, even though it should never be None here.
|
||||
level_dict[k2] = sorted(
|
||||
value2,
|
||||
key=lambda x: (x.level_question_num is None, x.level_question_num),
|
||||
)
|
||||
|
||||
# sort by level
|
||||
sorted_dict = OrderedDict(sorted(level_dict.items()))
|
||||
return sorted_dict
|
||||
|
||||
|
||||
# First chunk of info for streaming QA
|
||||
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
|
||||
@@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier):
|
||||
|
||||
|
||||
class SubQuestionPiece(SubQuestionIdentifier):
|
||||
"""Refined sub questions generated from the initial user question."""
|
||||
|
||||
sub_question: str
|
||||
|
||||
|
||||
@@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel):
|
||||
refined_answer_improvement: bool
|
||||
|
||||
|
||||
AgentSearchPacket = (
|
||||
AgentSearchPacket = Union[
|
||||
SubQuestionPiece
|
||||
| AgentAnswerPiece
|
||||
| SubQueryPiece
|
||||
| ExtendedToolResponse
|
||||
| RefinedAnswerImprovement
|
||||
)
|
||||
]
|
||||
|
||||
AnswerPacket = (
|
||||
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse
|
||||
|
||||
@@ -90,97 +90,97 @@ class CitationProcessor:
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
if not (1 <= numerical_value <= self.max_citation_num):
|
||||
continue
|
||||
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
citation_order_idx = self.citation_order.index(final_citation_num) + 1
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
citation_order_idx = (
|
||||
self.citation_order.index(final_citation_num) + 1
|
||||
else:
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
|
||||
@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
|
||||
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
|
||||
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
|
||||
@@ -333,4 +333,45 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
|
||||
AGENT_MAX_TOKENS_VALIDATION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
|
||||
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
|
||||
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
|
||||
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
|
||||
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
|
||||
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
|
||||
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
|
||||
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
|
||||
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
|
||||
)
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
@@ -8,6 +8,9 @@ from onyx.configs.constants import AuthType
|
||||
from onyx.configs.constants import DocumentIndexType
|
||||
from onyx.configs.constants import QueryHistoryType
|
||||
from onyx.file_processing.enums import HtmlBasedConnectorTransformLinksStrategy
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
|
||||
#####
|
||||
# App Configs
|
||||
@@ -643,3 +646,28 @@ MOCK_LLM_RESPONSE = (
|
||||
|
||||
|
||||
DEFAULT_IMAGE_ANALYSIS_MAX_SIZE_MB = 20
|
||||
|
||||
# Number of pre-provisioned tenants to maintain
|
||||
TARGET_AVAILABLE_TENANTS = int(os.environ.get("TARGET_AVAILABLE_TENANTS", "5"))
|
||||
|
||||
|
||||
# Image summarization configuration
|
||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_SUMMARIZATION_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
# The user prompt for image summarization - the image filename will be automatically prepended
|
||||
IMAGE_SUMMARIZATION_USER_PROMPT = os.environ.get(
|
||||
"IMAGE_SUMMARIZATION_USER_PROMPT",
|
||||
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT,
|
||||
)
|
||||
|
||||
IMAGE_ANALYSIS_SYSTEM_PROMPT = os.environ.get(
|
||||
"IMAGE_ANALYSIS_SYSTEM_PROMPT",
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT,
|
||||
)
|
||||
|
||||
DISABLE_AUTO_AUTH_REFRESH = (
|
||||
os.environ.get("DISABLE_AUTO_AUTH_REFRESH", "").lower() == "true"
|
||||
)
|
||||
|
||||
@@ -76,6 +76,7 @@ KV_REINDEX_KEY = "needs_reindexing"
|
||||
KV_SEARCH_SETTINGS = "search_settings"
|
||||
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
|
||||
KV_USER_STORE_KEY = "INVITED_USERS"
|
||||
KV_PENDING_USERS_KEY = "PENDING_USERS"
|
||||
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
|
||||
KV_CRED_KEY = "credential_id_{}"
|
||||
KV_GMAIL_CRED_KEY = "gmail_app_credential"
|
||||
@@ -321,6 +322,8 @@ class OnyxRedisLocks:
|
||||
"da_lock:check_connector_external_group_sync_beat"
|
||||
)
|
||||
MONITOR_BACKGROUND_PROCESSES_LOCK = "da_lock:monitor_background_processes"
|
||||
CHECK_AVAILABLE_TENANTS_LOCK = "da_lock:check_available_tenants"
|
||||
PRE_PROVISION_TENANT_LOCK = "da_lock:pre_provision_tenant"
|
||||
|
||||
CONNECTOR_DOC_PERMISSIONS_SYNC_LOCK_PREFIX = (
|
||||
"da_lock:connector_doc_permissions_sync"
|
||||
@@ -383,6 +386,7 @@ class OnyxCeleryTask:
|
||||
CLOUD_MONITOR_CELERY_QUEUES = (
|
||||
f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_monitor_celery_queues"
|
||||
)
|
||||
CHECK_AVAILABLE_TENANTS = f"{ONYX_CLOUD_CELERY_TASK_PREFIX}_check_available_tenants"
|
||||
|
||||
CHECK_FOR_CONNECTOR_DELETION = "check_for_connector_deletion_task"
|
||||
CHECK_FOR_VESPA_SYNC_TASK = "check_for_vespa_sync_task"
|
||||
@@ -399,6 +403,9 @@ class OnyxCeleryTask:
|
||||
MONITOR_BACKGROUND_PROCESSES = "monitor_background_processes"
|
||||
MONITOR_CELERY_QUEUES = "monitor_celery_queues"
|
||||
|
||||
# Tenant pre-provisioning
|
||||
PRE_PROVISION_TENANT = "pre_provision_tenant"
|
||||
|
||||
KOMBU_MESSAGE_CLEANUP_TASK = "kombu_message_cleanup_task"
|
||||
CONNECTOR_PERMISSION_SYNC_GENERATOR_TASK = (
|
||||
"connector_permission_sync_generator_task"
|
||||
|
||||
@@ -66,9 +66,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
|
||||
_SLIM_DOC_BATCH_SIZE = 5000
|
||||
|
||||
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
|
||||
"png",
|
||||
"jpg",
|
||||
"jpeg",
|
||||
"gif",
|
||||
"mp4",
|
||||
"mov",
|
||||
@@ -266,6 +263,7 @@ class ConfluenceConnector(
|
||||
result = process_attachment(
|
||||
self.confluence_client,
|
||||
attachment,
|
||||
page_id,
|
||||
page_title,
|
||||
self.image_analysis_llm,
|
||||
)
|
||||
@@ -305,7 +303,9 @@ class ConfluenceConnector(
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
|
||||
id=build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
),
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
@@ -367,6 +367,7 @@ class ConfluenceConnector(
|
||||
response = convert_attachment_to_content(
|
||||
confluence_client=self.confluence_client,
|
||||
attachment=attachment,
|
||||
page_id=page["id"],
|
||||
page_context=confluence_xml,
|
||||
llm=self.image_analysis_llm,
|
||||
)
|
||||
@@ -376,7 +377,7 @@ class ConfluenceConnector(
|
||||
content_text, file_storage_name = response
|
||||
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
if content_text:
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import io
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
@@ -19,17 +18,11 @@ from requests import HTTPError
|
||||
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_ID
|
||||
from ee.onyx.configs.app_configs import OAUTH_CONFLUENCE_CLOUD_CLIENT_SECRET
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.connectors.confluence.utils import _handle_http_error
|
||||
from onyx.connectors.confluence.utils import confluence_refresh_tokens
|
||||
from onyx.connectors.confluence.utils import get_start_param_from_url
|
||||
from onyx.connectors.confluence.utils import update_param_in_path
|
||||
from onyx.connectors.confluence.utils import validate_attachment_filetype
|
||||
from onyx.connectors.interfaces import CredentialsProviderInterface
|
||||
from onyx.file_processing.extract_file_text import extract_file_text
|
||||
from onyx.file_processing.html_utils import format_document_soup
|
||||
from onyx.redis.redis_pool import get_redis_client
|
||||
from onyx.utils.logger import setup_logger
|
||||
@@ -808,65 +801,6 @@ def _get_user(confluence_client: OnyxConfluence, user_id: str) -> str:
|
||||
return _USER_ID_TO_DISPLAY_NAME_CACHE.get(user_id) or _USER_NOT_FOUND
|
||||
|
||||
|
||||
def attachment_to_content(
|
||||
confluence_client: OnyxConfluence,
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
"""If it returns None, assume that we should skip this attachment."""
|
||||
if not validate_attachment_filetype(attachment):
|
||||
return None
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
logger.info(f"_attachment_to_content - _session.get: link={download_link}")
|
||||
|
||||
# why are we using session.get here? we probably won't retry these ... is that ok?
|
||||
response = confluence_client._session.get(download_link)
|
||||
if response.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with invalid status code {response.status_code}"
|
||||
)
|
||||
return None
|
||||
|
||||
extracted_text = extract_file_text(
|
||||
io.BytesIO(response.content),
|
||||
file_name=attachment["title"],
|
||||
break_on_unprocessable=False,
|
||||
)
|
||||
if len(extracted_text) > CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {download_link} due to char count. "
|
||||
f"char count={len(extracted_text)} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD}"
|
||||
)
|
||||
return None
|
||||
|
||||
return extracted_text
|
||||
|
||||
|
||||
def extract_text_from_confluence_html(
|
||||
confluence_client: OnyxConfluence,
|
||||
confluence_object: dict[str, Any],
|
||||
|
||||
@@ -22,6 +22,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.configs.app_configs import (
|
||||
CONFLUENCE_CONNECTOR_ATTACHMENT_CHAR_COUNT_THRESHOLD,
|
||||
)
|
||||
from onyx.configs.app_configs import CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD
|
||||
from onyx.configs.constants import FileOrigin
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -84,25 +85,35 @@ class AttachmentProcessingResult(BaseModel):
|
||||
error: str | None = None
|
||||
|
||||
|
||||
def _download_attachment(
|
||||
confluence_client: "OnyxConfluence", attachment: dict[str, Any]
|
||||
) -> bytes | None:
|
||||
"""
|
||||
Retrieves the raw bytes of an attachment from Confluence. Returns None on error.
|
||||
"""
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
resp = confluence_client._session.get(download_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {download_link} with status code {resp.status_code}"
|
||||
def _make_attachment_link(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None = None,
|
||||
) -> str | None:
|
||||
download_link = ""
|
||||
|
||||
if "api.atlassian.com" in confluence_client.url:
|
||||
# https://developer.atlassian.com/cloud/confluence/rest/v1/api-group-content---attachments/#api-wiki-rest-api-content-id-child-attachment-attachmentid-download-get
|
||||
if not parent_content_id:
|
||||
logger.warning(
|
||||
"parent_content_id is required to download attachments from Confluence Cloud!"
|
||||
)
|
||||
return None
|
||||
|
||||
download_link = (
|
||||
confluence_client.url
|
||||
+ f"/rest/api/content/{parent_content_id}/child/attachment/{attachment['id']}/download"
|
||||
)
|
||||
return None
|
||||
return resp.content
|
||||
else:
|
||||
download_link = confluence_client.url + attachment["_links"]["download"]
|
||||
|
||||
return download_link
|
||||
|
||||
|
||||
def process_attachment(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
parent_content_id: str | None,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> AttachmentProcessingResult:
|
||||
@@ -122,11 +133,52 @@ def process_attachment(
|
||||
error=f"Unsupported file type: {media_type}",
|
||||
)
|
||||
|
||||
# Download the attachment
|
||||
raw_bytes = _download_attachment(confluence_client, attachment)
|
||||
if raw_bytes is None:
|
||||
attachment_link = _make_attachment_link(
|
||||
confluence_client, attachment, parent_content_id
|
||||
)
|
||||
if not attachment_link:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="Failed to download attachment"
|
||||
text=None, file_name=None, error="Failed to make attachment link"
|
||||
)
|
||||
|
||||
attachment_size = attachment["extensions"]["fileSize"]
|
||||
|
||||
if not media_type.startswith("image/") or not llm:
|
||||
if attachment_size > CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD:
|
||||
logger.warning(
|
||||
f"Skipping {attachment_link} due to size. "
|
||||
f"size={attachment_size} "
|
||||
f"threshold={CONFLUENCE_CONNECTOR_ATTACHMENT_SIZE_THRESHOLD}"
|
||||
)
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment text too long: {attachment_size} chars",
|
||||
)
|
||||
|
||||
logger.info(
|
||||
f"Downloading attachment: "
|
||||
f"title={attachment['title']} "
|
||||
f"length={attachment_size} "
|
||||
f"link={attachment_link}"
|
||||
)
|
||||
|
||||
# Download the attachment
|
||||
resp: requests.Response = confluence_client._session.get(attachment_link)
|
||||
if resp.status_code != 200:
|
||||
logger.warning(
|
||||
f"Failed to fetch {attachment_link} with status code {resp.status_code}"
|
||||
)
|
||||
return AttachmentProcessingResult(
|
||||
text=None,
|
||||
file_name=None,
|
||||
error=f"Attachment download status code is {resp.status_code}",
|
||||
)
|
||||
|
||||
raw_bytes = resp.content
|
||||
if not raw_bytes:
|
||||
return AttachmentProcessingResult(
|
||||
text=None, file_name=None, error="attachment.content is None"
|
||||
)
|
||||
|
||||
# Process image attachments with LLM if available
|
||||
@@ -249,6 +301,7 @@ def _process_text_attachment(
|
||||
def convert_attachment_to_content(
|
||||
confluence_client: "OnyxConfluence",
|
||||
attachment: dict[str, Any],
|
||||
page_id: str,
|
||||
page_context: str,
|
||||
llm: LLM | None,
|
||||
) -> tuple[str | None, str | None] | None:
|
||||
@@ -266,7 +319,9 @@ def convert_attachment_to_content(
|
||||
)
|
||||
return None
|
||||
|
||||
result = process_attachment(confluence_client, attachment, page_context, llm)
|
||||
result = process_attachment(
|
||||
confluence_client, attachment, page_id, page_context, llm
|
||||
)
|
||||
if result.error is not None:
|
||||
logger.warning(
|
||||
f"Attachment {attachment['title']} encountered error: {result.error}"
|
||||
|
||||
@@ -228,10 +228,15 @@ class GitbookConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorMissingCredentialError("GitBook")
|
||||
|
||||
try:
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content")
|
||||
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
|
||||
pages: list[dict[str, Any]] = content.get("pages", [])
|
||||
current_batch: list[Document] = []
|
||||
|
||||
logger.info(f"Found {len(pages)} root pages.")
|
||||
logger.info(
|
||||
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
|
||||
)
|
||||
|
||||
while pages:
|
||||
page = pages.pop(0)
|
||||
|
||||
|
||||
@@ -316,7 +316,9 @@ class GoogleDriveConnector(
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
# default is ~17mins of retries, don't do that here for cases so we don't
|
||||
# waste 17mins everytime we run into a user without access to drive APIs
|
||||
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
from datetime import datetime
|
||||
from enum import Enum
|
||||
from typing import Any
|
||||
@@ -204,6 +205,15 @@ class ConnectorCheckpoint(BaseModel):
|
||||
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
|
||||
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""String representation of the checkpoint, with truncation for large checkpoint content."""
|
||||
MAX_CHECKPOINT_CONTENT_CHARS = 1000
|
||||
|
||||
content_str = json.dumps(self.checkpoint_content)
|
||||
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
|
||||
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
|
||||
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
|
||||
|
||||
|
||||
class DocumentFailure(BaseModel):
|
||||
document_id: str
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import time
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import fields
|
||||
@@ -32,6 +31,7 @@ from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
_NOTION_PAGE_SIZE = 100
|
||||
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
|
||||
|
||||
|
||||
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
"""
|
||||
filtered_pages: list[NotionPage] = []
|
||||
for page in pages:
|
||||
compare_time = time.mktime(
|
||||
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
|
||||
)
|
||||
# Parse ISO 8601 timestamp and convert to UTC epoch time
|
||||
timestamp = page[filter_field].replace(".000Z", "+00:00")
|
||||
compare_time = datetime.fromisoformat(timestamp).timestamp()
|
||||
if compare_time > start and compare_time <= end:
|
||||
filtered_pages += [NotionPage(**page)]
|
||||
return filtered_pages
|
||||
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
|
||||
query_dict = {
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
"page_size": self.batch_size,
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
}
|
||||
while True:
|
||||
db_res = self._search_notion(query_dict)
|
||||
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
|
||||
return
|
||||
|
||||
query_dict = {
|
||||
"page_size": self.batch_size,
|
||||
"page_size": _NOTION_PAGE_SIZE,
|
||||
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
|
||||
"filter": {"property": "object", "value": "page"},
|
||||
}
|
||||
|
||||
@@ -48,10 +48,12 @@ class SalesforceConnector(LoadConnector, PollConnector, SlimConnector):
|
||||
self,
|
||||
credentials: dict[str, Any],
|
||||
) -> dict[str, Any] | None:
|
||||
domain = "test" if credentials.get("is_sandbox") else None
|
||||
self._sf_client = Salesforce(
|
||||
username=credentials["sf_username"],
|
||||
password=credentials["sf_password"],
|
||||
security_token=credentials["sf_security_token"],
|
||||
domain=domain,
|
||||
)
|
||||
return None
|
||||
|
||||
|
||||
@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
"""
|
||||
1. Verify the bot token is valid for the workspace (via auth_test).
|
||||
2. Ensure the bot has enough scope to list channels.
|
||||
3. Check that every channel specified in self.channels exists.
|
||||
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
|
||||
"""
|
||||
if self.client is None:
|
||||
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
|
||||
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
|
||||
f"Slack API returned a failure: {error_msg}"
|
||||
)
|
||||
|
||||
# 3) If channels are specified, verify each is accessible
|
||||
if self.channels:
|
||||
# 3) If channels are specified and regex is not enabled, verify each is accessible
|
||||
if self.channels and not self.channel_regex_enabled:
|
||||
accessible_channels = get_channels(
|
||||
client=self.client,
|
||||
exclude_archived=True,
|
||||
|
||||
@@ -30,6 +30,7 @@ class VisionEnabledConnector:
|
||||
Sets self.image_analysis_llm to the LLM instance or None if disabled.
|
||||
"""
|
||||
self.image_analysis_llm: LLM | None = None
|
||||
|
||||
if get_image_extraction_and_analysis_enabled():
|
||||
try:
|
||||
self.image_analysis_llm = get_default_llm_with_vision()
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.models import BaseChunk
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class SearchQuery(ChunkContext):
|
||||
"Processed Request that is directly passed to the SearchPipeline"
|
||||
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
|
||||
offset: int = 0
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
|
||||
@@ -331,6 +331,14 @@ class SearchPipeline:
|
||||
self._retrieved_sections = expanded_inference_sections
|
||||
return expanded_inference_sections
|
||||
|
||||
@property
|
||||
def retrieved_sections(self) -> list[InferenceSection]:
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
||||
self._retrieved_sections = self._get_sections()
|
||||
return self._retrieved_sections
|
||||
|
||||
@property
|
||||
def reranked_sections(self) -> list[InferenceSection]:
|
||||
"""Reranking is always done at the chunk level since section merging could create arbitrarily
|
||||
@@ -343,7 +351,7 @@ class SearchPipeline:
|
||||
if self._reranked_sections is not None:
|
||||
return self._reranked_sections
|
||||
|
||||
retrieved_sections = self._get_sections()
|
||||
retrieved_sections = self.retrieved_sections
|
||||
if self.retrieved_sections_callback is not None:
|
||||
self.retrieved_sections_callback(retrieved_sections)
|
||||
|
||||
|
||||
@@ -10,6 +10,7 @@ from langchain_core.messages import SystemMessage
|
||||
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
from onyx.configs.app_configs import BLURB_SIZE
|
||||
from onyx.configs.app_configs import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.configs.constants import RETURN_SEPARATOR
|
||||
from onyx.configs.llm_configs import get_search_time_image_analysis_enabled
|
||||
from onyx.configs.model_configs import CROSS_ENCODER_RANGE_MAX
|
||||
@@ -31,7 +32,6 @@ from onyx.file_store.file_store import get_default_file_store
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.natural_language_processing.search_nlp_models import RerankingModel
|
||||
from onyx.prompts.image_analysis import IMAGE_ANALYSIS_SYSTEM_PROMPT
|
||||
from onyx.secondary_llm_flows.chunk_usefulness import llm_batch_eval_sections
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
|
||||
@@ -117,8 +117,12 @@ def retrieval_preprocessing(
|
||||
else None
|
||||
)
|
||||
|
||||
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
|
||||
# latency, and in that case we don't need to run the model again
|
||||
run_query_analysis = (
|
||||
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
|
||||
None
|
||||
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
|
||||
else FunctionCall(query_analysis, (query,), {})
|
||||
)
|
||||
|
||||
functions_to_run = [
|
||||
@@ -143,11 +147,12 @@ def retrieval_preprocessing(
|
||||
|
||||
# The extracted keywords right now are not very reliable, not using for now
|
||||
# Can maybe use for highlighting
|
||||
is_keyword, extracted_keywords = (
|
||||
parallel_results[run_query_analysis.result_id]
|
||||
if run_query_analysis
|
||||
else (False, None)
|
||||
)
|
||||
is_keyword, _extracted_keywords = False, None
|
||||
if search_request.precomputed_is_keyword is not None:
|
||||
is_keyword = search_request.precomputed_is_keyword
|
||||
_extracted_keywords = search_request.precomputed_keywords
|
||||
elif run_query_analysis:
|
||||
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
|
||||
|
||||
all_query_terms = query.split()
|
||||
processed_keywords = (
|
||||
@@ -247,4 +252,5 @@ def retrieval_preprocessing(
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
full_doc=search_request.full_doc,
|
||||
precomputed_query_embedding=search_request.precomputed_query_embedding,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -109,6 +109,20 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
@@ -121,17 +135,10 @@ def doc_index_retrieval(
|
||||
from the large chunks to the referenced chunks,
|
||||
dedupes the chunks, and cleans the chunks.
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
query_embedding = query.precomputed_query_embedding or get_query_embedding(
|
||||
query.query, db_session
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
|
||||
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
@@ -249,7 +256,16 @@ def retrieve_chunks(
|
||||
continue
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.copy(update={"query": rephrase}, deep=True)
|
||||
q_copy = query.model_copy(
|
||||
update={
|
||||
"query": rephrase,
|
||||
# need to recompute for each rephrase
|
||||
# note that `SearchQuery` is a frozen model, so we can't update
|
||||
# it below
|
||||
"precomputed_query_embedding": None,
|
||||
},
|
||||
deep=True,
|
||||
)
|
||||
run_queries.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
|
||||
@@ -13,6 +13,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.db.models import Tool as ToolModel
|
||||
from onyx.db.models import User
|
||||
from onyx.db.models import User__UserGroup
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProvider
|
||||
from onyx.server.manage.embedding.models import CloudEmbeddingProviderCreationRequest
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
@@ -187,6 +188,17 @@ def fetch_default_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_default_vision_provider(db_session: Session) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||
)
|
||||
)
|
||||
if not provider_model:
|
||||
return None
|
||||
return FullLLMProvider.from_model(provider_model)
|
||||
|
||||
|
||||
def fetch_provider(db_session: Session, provider_name: str) -> FullLLMProvider | None:
|
||||
provider_model = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.name == provider_name)
|
||||
@@ -246,3 +258,39 @@ def update_default_provider(provider_id: int, db_session: Session) -> None:
|
||||
|
||||
new_default.is_default_provider = True
|
||||
db_session.commit()
|
||||
|
||||
|
||||
def update_default_vision_provider(
|
||||
provider_id: int, vision_model: str | None, db_session: Session
|
||||
) -> None:
|
||||
new_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(LLMProviderModel.id == provider_id)
|
||||
)
|
||||
if not new_default:
|
||||
raise ValueError(f"LLM Provider with id {provider_id} does not exist")
|
||||
|
||||
# Validate that the specified vision model supports image input
|
||||
model_to_validate = vision_model or new_default.default_model_name
|
||||
if model_to_validate:
|
||||
if not model_supports_image_input(model_to_validate, new_default.provider):
|
||||
raise ValueError(
|
||||
f"Model '{model_to_validate}' for provider '{new_default.provider}' does not support image input"
|
||||
)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Model '{vision_model}' is not a valid model for provider '{new_default.provider}'"
|
||||
)
|
||||
|
||||
existing_default = db_session.scalar(
|
||||
select(LLMProviderModel).where(
|
||||
LLMProviderModel.is_default_vision_provider == True # noqa: E712
|
||||
)
|
||||
)
|
||||
if existing_default:
|
||||
existing_default.is_default_vision_provider = None
|
||||
# required to ensure that the below does not cause a unique constraint violation
|
||||
db_session.flush()
|
||||
|
||||
new_default.is_default_vision_provider = True
|
||||
new_default.default_vision_model = vision_model
|
||||
db_session.commit()
|
||||
|
||||
@@ -1489,6 +1489,10 @@ class LLMProvider(Base):
|
||||
|
||||
# should only be set for a single provider
|
||||
is_default_provider: Mapped[bool | None] = mapped_column(Boolean, unique=True)
|
||||
is_default_vision_provider: Mapped[bool | None] = mapped_column(
|
||||
Boolean, unique=True
|
||||
)
|
||||
default_vision_model: Mapped[str | None] = mapped_column(String, nullable=True)
|
||||
# EE only
|
||||
is_public: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
groups: Mapped[list["UserGroup"]] = relationship(
|
||||
@@ -2295,21 +2299,31 @@ class PublicBase(DeclarativeBase):
|
||||
__abstract__ = True
|
||||
|
||||
|
||||
# Strictly keeps track of the tenant that a given user will authenticate to.
|
||||
class UserTenantMapping(Base):
|
||||
__tablename__ = "user_tenant_mapping"
|
||||
__table_args__ = (
|
||||
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
|
||||
{"schema": "public"},
|
||||
)
|
||||
__table_args__ = ({"schema": "public"},)
|
||||
|
||||
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
|
||||
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
|
||||
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
|
||||
|
||||
@validates("email")
|
||||
def validate_email(self, key: str, value: str) -> str:
|
||||
return value.lower() if value else value
|
||||
|
||||
|
||||
class AvailableTenant(Base):
|
||||
__tablename__ = "available_tenant"
|
||||
"""
|
||||
These entries will only exist ephemerally and are meant to be picked up by new users on registration.
|
||||
"""
|
||||
|
||||
tenant_id: Mapped[str] = mapped_column(String, primary_key=True, nullable=False)
|
||||
alembic_version: Mapped[str] = mapped_column(String, nullable=False)
|
||||
date_created: Mapped[datetime.datetime] = mapped_column(DateTime, nullable=False)
|
||||
|
||||
|
||||
# This is a mapping from tenant IDs to anonymous user paths
|
||||
class TenantAnonymousUserPath(Base):
|
||||
__tablename__ = "tenant_anonymous_user_path"
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import random
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from logging import getLogger
|
||||
|
||||
from onyx.configs.constants import MessageType
|
||||
from onyx.db.chat import create_chat_session
|
||||
@@ -9,6 +10,8 @@ from onyx.db.chat import get_or_create_root_message
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.models import ChatSession
|
||||
|
||||
logger = getLogger(__name__)
|
||||
|
||||
|
||||
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
"""Utility function to seed chat history for testing.
|
||||
@@ -19,12 +22,18 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
the times.
|
||||
"""
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
logger.info(f"Seeding {num_sessions} sessions.")
|
||||
for y in range(0, num_sessions):
|
||||
create_chat_session(db_session, f"pytest_session_{y}", None, None)
|
||||
|
||||
# randomize all session times
|
||||
logger.info(f"Seeding {num_messages} messages per session.")
|
||||
rows = db_session.query(ChatSession).all()
|
||||
for row in rows:
|
||||
for x in range(0, len(rows)):
|
||||
if x % 1024 == 0:
|
||||
logger.info(f"Seeded messages for {x} sessions so far.")
|
||||
|
||||
row = rows[x]
|
||||
row.time_created = datetime.utcnow() - timedelta(
|
||||
days=random.randint(0, days)
|
||||
)
|
||||
@@ -34,20 +43,37 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
|
||||
|
||||
root_message = get_or_create_root_message(row.id, db_session)
|
||||
|
||||
current_message_type = MessageType.USER
|
||||
parent_message = root_message
|
||||
for x in range(0, num_messages):
|
||||
if current_message_type == MessageType.USER:
|
||||
msg = f"pytest_message_user_{x}"
|
||||
else:
|
||||
msg = f"pytest_message_assistant_{x}"
|
||||
|
||||
chat_message = create_new_chat_message(
|
||||
row.id,
|
||||
root_message,
|
||||
f"pytest_message_{x}",
|
||||
parent_message,
|
||||
msg,
|
||||
None,
|
||||
0,
|
||||
MessageType.USER,
|
||||
current_message_type,
|
||||
db_session,
|
||||
)
|
||||
|
||||
chat_message.time_sent = row.time_created + timedelta(
|
||||
minutes=random.randint(0, 10)
|
||||
)
|
||||
db_session.commit()
|
||||
|
||||
db_session.commit()
|
||||
|
||||
current_message_type = (
|
||||
MessageType.ASSISTANT
|
||||
if current_message_type == MessageType.USER
|
||||
else MessageType.USER
|
||||
)
|
||||
parent_message = chat_message
|
||||
|
||||
db_session.commit()
|
||||
|
||||
logger.info(f"Seeded messages for {len(rows)} sessions. Finished.")
|
||||
|
||||
@@ -1,6 +1,5 @@
|
||||
from sqlalchemy import and_
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy import or_
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
@@ -149,11 +148,10 @@ def delete_document_tags_for_documents__no_commit(
|
||||
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
|
||||
db_session.execute(stmt)
|
||||
|
||||
orphan_tags_query = (
|
||||
select(Tag.id)
|
||||
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
|
||||
.group_by(Tag.id)
|
||||
.having(func.count(Document__Tag.document_id) == 0)
|
||||
orphan_tags_query = select(Tag.id).where(
|
||||
~db_session.query(Document__Tag.tag_id)
|
||||
.filter(Document__Tag.tag_id == Tag.id)
|
||||
.exists()
|
||||
)
|
||||
|
||||
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()
|
||||
|
||||
@@ -6,10 +6,10 @@ from langchain_core.messages import HumanMessage
|
||||
from langchain_core.messages import SystemMessage
|
||||
from PIL import Image
|
||||
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.configs.app_configs import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.utils import message_to_string
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_SYSTEM_PROMPT
|
||||
from onyx.prompts.image_analysis import IMAGE_SUMMARIZATION_USER_PROMPT
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -62,7 +62,7 @@ def summarize_image_with_error_handling(
|
||||
image_data: The raw image bytes
|
||||
context_name: Name or title of the image for context
|
||||
system_prompt: System prompt to use for the LLM
|
||||
user_prompt_template: Template for the user prompt, should contain {title} placeholder
|
||||
user_prompt_template: User prompt to use (without title)
|
||||
|
||||
Returns:
|
||||
The image summary text, or None if summarization failed or is disabled
|
||||
@@ -70,7 +70,10 @@ def summarize_image_with_error_handling(
|
||||
if llm is None:
|
||||
return None
|
||||
|
||||
user_prompt = user_prompt_template.format(title=context_name)
|
||||
# Prepend the image filename to the user prompt
|
||||
user_prompt = (
|
||||
f"The image has the file name '{context_name}'.\n{user_prompt_template}"
|
||||
)
|
||||
return summarize_image_pipeline(llm, image_data, user_prompt, system_prompt)
|
||||
|
||||
|
||||
|
||||
@@ -167,7 +167,7 @@ def _convert_delta_to_message_chunk(
|
||||
stop_reason: str | None = None,
|
||||
) -> BaseMessageChunk:
|
||||
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
|
||||
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs = {}
|
||||
if _dict.get("function_call"):
|
||||
@@ -402,6 +402,7 @@ class DefaultMultiLLM(LLM):
|
||||
stream: bool,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
|
||||
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
|
||||
# to a dict representation
|
||||
@@ -429,6 +430,7 @@ class DefaultMultiLLM(LLM):
|
||||
# model params
|
||||
temperature=0,
|
||||
timeout=timeout_override or self._timeout,
|
||||
max_tokens=max_tokens,
|
||||
# For now, we don't support parallel tool calls
|
||||
# NOTE: we can't pass this in if tools are not specified
|
||||
# or else OpenAI throws an error
|
||||
@@ -484,6 +486,7 @@ class DefaultMultiLLM(LLM):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
@@ -497,6 +500,7 @@ class DefaultMultiLLM(LLM):
|
||||
stream=False,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
choice = response.choices[0]
|
||||
@@ -515,6 +519,7 @@ class DefaultMultiLLM(LLM):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
if LOG_DANSWER_MODEL_INTERACTIONS:
|
||||
self.log_model_configs()
|
||||
@@ -539,6 +544,7 @@ class DefaultMultiLLM(LLM):
|
||||
stream=True,
|
||||
structured_response_format=structured_response_format,
|
||||
timeout_override=timeout_override,
|
||||
max_tokens=max_tokens,
|
||||
),
|
||||
)
|
||||
try:
|
||||
|
||||
@@ -82,6 +82,7 @@ class CustomModelServer(LLM):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
return self._execute(prompt)
|
||||
|
||||
@@ -92,5 +93,6 @@ class CustomModelServer(LLM):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
yield self._execute(prompt)
|
||||
|
||||
@@ -5,7 +5,9 @@ from onyx.configs.app_configs import DISABLE_GENERATIVE_AI
|
||||
from onyx.configs.model_configs import GEN_AI_MODEL_FALLBACK_MAX_TOKENS
|
||||
from onyx.configs.model_configs import GEN_AI_TEMPERATURE
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.engine import get_session_with_current_tenant
|
||||
from onyx.db.llm import fetch_default_provider
|
||||
from onyx.db.llm import fetch_default_vision_provider
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.models import Persona
|
||||
@@ -14,6 +16,7 @@ from onyx.llm.exceptions import GenAIDisabledException
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.override_models import LLMOverride
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.utils.headers import build_llm_extra_headers
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.long_term_log import LongTermLogger
|
||||
@@ -94,40 +97,61 @@ def get_default_llm_with_vision(
|
||||
additional_headers: dict[str, str] | None = None,
|
||||
long_term_logger: LongTermLogger | None = None,
|
||||
) -> LLM | None:
|
||||
"""Get an LLM that supports image input, with the following priority:
|
||||
1. Use the designated default vision provider if it exists and supports image input
|
||||
2. Fall back to the first LLM provider that supports image input
|
||||
|
||||
Returns None if no providers exist or if no provider supports images.
|
||||
"""
|
||||
if DISABLE_GENERATIVE_AI:
|
||||
raise GenAIDisabledException()
|
||||
|
||||
with get_session_context_manager() as db_session:
|
||||
llm_providers = fetch_existing_llm_providers(db_session)
|
||||
|
||||
if not llm_providers:
|
||||
return None
|
||||
|
||||
for provider in llm_providers:
|
||||
model_name = provider.default_model_name
|
||||
fast_model_name = (
|
||||
provider.fast_default_model_name or provider.default_model_name
|
||||
def create_vision_llm(provider: FullLLMProvider, model: str) -> LLM:
|
||||
"""Helper to create an LLM if the provider supports image input."""
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
model=model,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
)
|
||||
|
||||
if not model_name or not fast_model_name:
|
||||
continue
|
||||
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
return get_llm(
|
||||
provider=provider.provider,
|
||||
model=model_name,
|
||||
deployment_name=provider.deployment_name,
|
||||
api_key=provider.api_key,
|
||||
api_base=provider.api_base,
|
||||
api_version=provider.api_version,
|
||||
custom_config=provider.custom_config,
|
||||
timeout=timeout,
|
||||
temperature=temperature,
|
||||
additional_headers=additional_headers,
|
||||
long_term_logger=long_term_logger,
|
||||
with get_session_with_current_tenant() as db_session:
|
||||
# Try the default vision provider first
|
||||
default_provider = fetch_default_vision_provider(db_session)
|
||||
if (
|
||||
default_provider
|
||||
and default_provider.default_vision_model
|
||||
and model_supports_image_input(
|
||||
default_provider.default_vision_model, default_provider.provider
|
||||
)
|
||||
):
|
||||
return create_vision_llm(
|
||||
default_provider, default_provider.default_vision_model
|
||||
)
|
||||
|
||||
raise ValueError("No LLM provider found that supports image input")
|
||||
# Fall back to searching all providers
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
|
||||
if not providers:
|
||||
return None
|
||||
|
||||
# Find the first provider that supports image input
|
||||
for provider in providers:
|
||||
if provider.default_vision_model and model_supports_image_input(
|
||||
provider.default_vision_model, provider.provider
|
||||
):
|
||||
return create_vision_llm(
|
||||
FullLLMProvider.from_model(provider), provider.default_vision_model
|
||||
)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_default_llms(
|
||||
|
||||
@@ -91,12 +91,18 @@ class LLM(abc.ABC):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
return self._invoke_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format, timeout_override
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
@abc.abstractmethod
|
||||
@@ -107,6 +113,7 @@ class LLM(abc.ABC):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> BaseMessage:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -117,12 +124,18 @@ class LLM(abc.ABC):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
self._precall(prompt)
|
||||
# TODO add a postcall to log model outputs independent of concrete class
|
||||
# implementation
|
||||
messages = self._stream_implementation(
|
||||
prompt, tools, tool_choice, structured_response_format, timeout_override
|
||||
prompt,
|
||||
tools,
|
||||
tool_choice,
|
||||
structured_response_format,
|
||||
timeout_override,
|
||||
max_tokens,
|
||||
)
|
||||
|
||||
tokens = []
|
||||
@@ -142,5 +155,6 @@ class LLM(abc.ABC):
|
||||
tool_choice: ToolChoiceOptions | None = None,
|
||||
structured_response_format: dict | None = None,
|
||||
timeout_override: int | None = None,
|
||||
max_tokens: int | None = None,
|
||||
) -> Iterator[BaseMessage]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -234,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
|
||||
|
||||
yield
|
||||
|
||||
SqlEngine.reset_engine()
|
||||
|
||||
if AUTH_RATE_LIMITING_ENABLED:
|
||||
await close_auth_limiter()
|
||||
|
||||
@@ -359,7 +361,15 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
|
||||
if AUTH_TYPE == AuthType.GOOGLE_OAUTH:
|
||||
oauth_client = GoogleOAuth2(OAUTH_CLIENT_ID, OAUTH_CLIENT_SECRET)
|
||||
# For Google OAuth, refresh tokens are requested by:
|
||||
# 1. Adding the right scopes
|
||||
# 2. Properly configuring OAuth in Google Cloud Console to allow offline access
|
||||
oauth_client = GoogleOAuth2(
|
||||
OAUTH_CLIENT_ID,
|
||||
OAUTH_CLIENT_SECRET,
|
||||
# Use standard scopes that include profile and email
|
||||
scopes=["openid", "email", "profile"],
|
||||
)
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
create_onyx_oauth_router(
|
||||
@@ -381,6 +391,13 @@ def get_application() -> FastAPI:
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
# Add refresh token endpoint for OAuth as well
|
||||
include_auth_router_with_prefix(
|
||||
application,
|
||||
fastapi_users.get_refresh_router(auth_backend),
|
||||
prefix="/auth",
|
||||
)
|
||||
|
||||
application.add_exception_handler(
|
||||
RequestValidationError, validation_exception_handler
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Used for creating embeddings of images for vector search
|
||||
IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
||||
DEFAULT_IMAGE_SUMMARIZATION_SYSTEM_PROMPT = """
|
||||
You are an assistant for summarizing images for retrieval.
|
||||
Summarize the content of the following image and be as precise as possible.
|
||||
The summary will be embedded and used to retrieve the original image.
|
||||
@@ -7,14 +7,13 @@ Therefore, write a concise summary of the image that is optimized for retrieval.
|
||||
"""
|
||||
|
||||
# Prompt for generating image descriptions with filename context
|
||||
IMAGE_SUMMARIZATION_USER_PROMPT = """
|
||||
The image has the file name '{title}'.
|
||||
DEFAULT_IMAGE_SUMMARIZATION_USER_PROMPT = """
|
||||
Describe precisely and concisely what the image shows.
|
||||
"""
|
||||
|
||||
|
||||
# Used for analyzing images in response to user queries at search time
|
||||
IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
||||
DEFAULT_IMAGE_ANALYSIS_SYSTEM_PROMPT = (
|
||||
"You are an AI assistant specialized in describing images.\n"
|
||||
"You will receive a user question plus an image URL. Provide a concise textual answer.\n"
|
||||
"Focus on aspects of the image that are relevant to the user's question.\n"
|
||||
|
||||
@@ -160,6 +160,20 @@ class RedisPool:
|
||||
def get_replica_client(self, tenant_id: str) -> Redis:
|
||||
return TenantRedis(tenant_id, connection_pool=self._replica_pool)
|
||||
|
||||
def get_raw_client(self) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with direct access to the primary connection pool,
|
||||
without tenant prefixing.
|
||||
"""
|
||||
return redis.Redis(connection_pool=self._pool)
|
||||
|
||||
def get_raw_replica_client(self) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with direct access to the replica connection pool,
|
||||
without tenant prefixing.
|
||||
"""
|
||||
return redis.Redis(connection_pool=self._replica_pool)
|
||||
|
||||
@staticmethod
|
||||
def create_pool(
|
||||
host: str = REDIS_HOST,
|
||||
@@ -224,6 +238,15 @@ def get_redis_client(
|
||||
# This argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
"""
|
||||
Returns a Redis client with tenant-specific key prefixing.
|
||||
|
||||
This ensures proper data isolation between tenants by automatically
|
||||
prefixing all Redis keys with the tenant ID.
|
||||
|
||||
Use this when working with tenant-specific data that should be
|
||||
isolated from other tenants.
|
||||
"""
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -235,6 +258,15 @@ def get_redis_replica_client(
|
||||
# this argument will be deprecated in the future
|
||||
tenant_id: str | None = None,
|
||||
) -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client with tenant-specific key prefixing.
|
||||
|
||||
Similar to get_redis_client(), but connects to a read replica when available.
|
||||
This ensures proper data isolation between tenants by automatically
|
||||
prefixing all Redis keys with the tenant ID.
|
||||
|
||||
Use this for read-heavy operations on tenant-specific data.
|
||||
"""
|
||||
if tenant_id is None:
|
||||
tenant_id = get_current_tenant_id()
|
||||
|
||||
@@ -242,13 +274,57 @@ def get_redis_replica_client(
|
||||
|
||||
|
||||
def get_shared_redis_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis client with a shared namespace prefix.
|
||||
|
||||
Unlike tenant-specific clients, this uses a common prefix for all keys,
|
||||
creating a shared namespace accessible across all tenants.
|
||||
|
||||
Use this for data that should be shared across the application and
|
||||
isn't specific to any individual tenant.
|
||||
"""
|
||||
return redis_pool.get_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
def get_shared_redis_replica_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client with a shared namespace prefix.
|
||||
|
||||
Similar to get_shared_redis_client(), but connects to a read replica when available.
|
||||
Uses a common prefix for all keys, creating a shared namespace.
|
||||
|
||||
Use this for read-heavy operations on data that should be shared
|
||||
across the application.
|
||||
"""
|
||||
return redis_pool.get_replica_client(DEFAULT_REDIS_PREFIX)
|
||||
|
||||
|
||||
def get_raw_redis_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis client that doesn't apply tenant prefixing to keys.
|
||||
|
||||
Use this only when you need to access Redis directly without tenant isolation
|
||||
or any key prefixing. Typically needed for integrating with external systems
|
||||
or libraries that have inflexible key requirements.
|
||||
|
||||
Warning: Be careful with this client as it bypasses tenant isolation.
|
||||
"""
|
||||
return redis_pool.get_raw_client()
|
||||
|
||||
|
||||
def get_raw_redis_replica_client() -> Redis:
|
||||
"""
|
||||
Returns a Redis replica client that doesn't apply tenant prefixing to keys.
|
||||
|
||||
Similar to get_raw_redis_client(), but connects to a read replica when available.
|
||||
Use this for read-heavy operations that need direct Redis access without
|
||||
tenant isolation or key prefixing.
|
||||
|
||||
Warning: Be careful with this client as it bypasses tenant isolation.
|
||||
"""
|
||||
return redis_pool.get_raw_replica_client()
|
||||
|
||||
|
||||
SSL_CERT_REQS_MAP = {
|
||||
"none": ssl.CERT_NONE,
|
||||
"optional": ssl.CERT_OPTIONAL,
|
||||
|
||||
@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
|
||||
from starlette.routing import BaseRoute
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -31,6 +31,7 @@ PUBLIC_ENDPOINT_SPECS = [
|
||||
# just gets the version of Onyx (e.g. 0.3.11)
|
||||
("/version", {"GET"}),
|
||||
# stuff related to basic auth
|
||||
("/auth/refresh", {"POST"}),
|
||||
("/auth/register", {"POST"}),
|
||||
("/auth/login", {"POST"}),
|
||||
("/auth/logout", {"POST"}),
|
||||
@@ -112,7 +113,7 @@ def check_router_auth(
|
||||
or depends_fn == current_curator_or_admin_user
|
||||
or depends_fn == api_key_dep
|
||||
or depends_fn == current_user_with_expired_token
|
||||
or depends_fn == current_chat_accesssible_user
|
||||
or depends_fn == current_chat_accessible_user
|
||||
or depends_fn == control_plane_dep
|
||||
or depends_fn == current_cloud_superuser
|
||||
):
|
||||
|
||||
@@ -17,7 +17,7 @@ from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.background.celery.versioned_apps.primary import app as primary_app
|
||||
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
|
||||
|
||||
@router.get("/connector-status")
|
||||
def get_basic_connector_indexing_status(
|
||||
user: User = Depends(current_chat_accesssible_user),
|
||||
user: User = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[BasicCCPairInfo]:
|
||||
cc_pairs = get_connector_credential_pairs_for_user(
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_limited_user
|
||||
from onyx.auth.users import current_user
|
||||
@@ -390,7 +390,7 @@ def get_image_generation_tool(
|
||||
|
||||
@basic_router.get("")
|
||||
def list_personas(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
include_deleted: bool = False,
|
||||
persona_ids: list[int] = Query(None),
|
||||
|
||||
@@ -7,13 +7,14 @@ from fastapi import Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session
|
||||
from onyx.db.llm import fetch_existing_llm_providers
|
||||
from onyx.db.llm import fetch_existing_llm_providers_for_user
|
||||
from onyx.db.llm import fetch_provider
|
||||
from onyx.db.llm import remove_llm_provider
|
||||
from onyx.db.llm import update_default_provider
|
||||
from onyx.db.llm import update_default_vision_provider
|
||||
from onyx.db.llm import upsert_llm_provider
|
||||
from onyx.db.models import User
|
||||
from onyx.llm.factory import get_default_llms
|
||||
@@ -21,11 +22,13 @@ from onyx.llm.factory import get_llm
|
||||
from onyx.llm.llm_provider_options import fetch_available_well_known_llms
|
||||
from onyx.llm.llm_provider_options import WellKnownLLMProviderDescriptor
|
||||
from onyx.llm.utils import litellm_exception_to_error_msg
|
||||
from onyx.llm.utils import model_supports_image_input
|
||||
from onyx.llm.utils import test_llm
|
||||
from onyx.server.manage.llm.models import FullLLMProvider
|
||||
from onyx.server.manage.llm.models import LLMProviderDescriptor
|
||||
from onyx.server.manage.llm.models import LLMProviderUpsertRequest
|
||||
from onyx.server.manage.llm.models import TestLLMRequest
|
||||
from onyx.server.manage.llm.models import VisionProviderResponse
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
@@ -186,12 +189,68 @@ def set_provider_as_default(
|
||||
update_default_provider(provider_id=provider_id, db_session=db_session)
|
||||
|
||||
|
||||
@admin_router.post("/provider/{provider_id}/default-vision")
|
||||
def set_provider_as_default_vision(
|
||||
provider_id: int,
|
||||
vision_model: str
|
||||
| None = Query(None, description="The default vision model to use"),
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
update_default_vision_provider(
|
||||
provider_id=provider_id, vision_model=vision_model, db_session=db_session
|
||||
)
|
||||
|
||||
|
||||
@admin_router.get("/vision-providers")
|
||||
def get_vision_capable_providers(
|
||||
_: User | None = Depends(current_admin_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[VisionProviderResponse]:
|
||||
"""Return a list of LLM providers and their models that support image input"""
|
||||
|
||||
providers = fetch_existing_llm_providers(db_session)
|
||||
vision_providers = []
|
||||
|
||||
logger.info("Fetching vision-capable providers")
|
||||
|
||||
for provider in providers:
|
||||
vision_models = []
|
||||
|
||||
# Check model names in priority order
|
||||
model_names_to_check = []
|
||||
if provider.model_names:
|
||||
model_names_to_check = provider.model_names
|
||||
elif provider.display_model_names:
|
||||
model_names_to_check = provider.display_model_names
|
||||
elif provider.default_model_name:
|
||||
model_names_to_check = [provider.default_model_name]
|
||||
|
||||
# Check each model for vision capability
|
||||
for model_name in model_names_to_check:
|
||||
if model_supports_image_input(model_name, provider.provider):
|
||||
vision_models.append(model_name)
|
||||
logger.debug(f"Vision model found: {provider.provider}/{model_name}")
|
||||
|
||||
# Only include providers with at least one vision-capable model
|
||||
if vision_models:
|
||||
provider_dict = FullLLMProvider.from_model(provider).model_dump()
|
||||
provider_dict["vision_models"] = vision_models
|
||||
logger.info(
|
||||
f"Vision provider: {provider.provider} with models: {vision_models}"
|
||||
)
|
||||
vision_providers.append(VisionProviderResponse(**provider_dict))
|
||||
|
||||
logger.info(f"Found {len(vision_providers)} vision-capable providers")
|
||||
return vision_providers
|
||||
|
||||
|
||||
"""Endpoints for all"""
|
||||
|
||||
|
||||
@basic_router.get("/provider")
|
||||
def list_llm_provider_basics(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> list[LLMProviderDescriptor]:
|
||||
return [
|
||||
|
||||
@@ -34,6 +34,8 @@ class LLMProviderDescriptor(BaseModel):
|
||||
default_model_name: str
|
||||
fast_default_model_name: str | None
|
||||
is_default_provider: bool | None
|
||||
is_default_vision_provider: bool | None
|
||||
default_vision_model: str | None
|
||||
display_model_names: list[str] | None
|
||||
|
||||
@classmethod
|
||||
@@ -46,11 +48,10 @@ class LLMProviderDescriptor(BaseModel):
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
model_names=(
|
||||
llm_provider_model.model_names
|
||||
or fetch_models_for_provider(llm_provider_model.provider)
|
||||
or [llm_provider_model.default_model_name]
|
||||
),
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
model_names=llm_provider_model.model_names
|
||||
or fetch_models_for_provider(llm_provider_model.provider),
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
)
|
||||
|
||||
@@ -68,6 +69,7 @@ class LLMProvider(BaseModel):
|
||||
groups: list[int] = Field(default_factory=list)
|
||||
display_model_names: list[str] | None = None
|
||||
deployment_name: str | None = None
|
||||
default_vision_model: str | None = None
|
||||
|
||||
|
||||
class LLMProviderUpsertRequest(LLMProvider):
|
||||
@@ -79,6 +81,7 @@ class LLMProviderUpsertRequest(LLMProvider):
|
||||
class FullLLMProvider(LLMProvider):
|
||||
id: int
|
||||
is_default_provider: bool | None = None
|
||||
is_default_vision_provider: bool | None = None
|
||||
model_names: list[str]
|
||||
|
||||
@classmethod
|
||||
@@ -94,6 +97,8 @@ class FullLLMProvider(LLMProvider):
|
||||
default_model_name=llm_provider_model.default_model_name,
|
||||
fast_default_model_name=llm_provider_model.fast_default_model_name,
|
||||
is_default_provider=llm_provider_model.is_default_provider,
|
||||
is_default_vision_provider=llm_provider_model.is_default_vision_provider,
|
||||
default_vision_model=llm_provider_model.default_vision_model,
|
||||
display_model_names=llm_provider_model.display_model_names,
|
||||
model_names=(
|
||||
llm_provider_model.model_names
|
||||
@@ -104,3 +109,9 @@ class FullLLMProvider(LLMProvider):
|
||||
groups=[group.id for group in llm_provider_model.groups],
|
||||
deployment_name=llm_provider_model.deployment_name,
|
||||
)
|
||||
|
||||
|
||||
class VisionProviderResponse(FullLLMProvider):
|
||||
"""Response model for vision providers endpoint, including vision-specific fields."""
|
||||
|
||||
vision_models: list[str]
|
||||
|
||||
@@ -53,6 +53,16 @@ class UserPreferences(BaseModel):
|
||||
temperature_override_enabled: bool | None = None
|
||||
|
||||
|
||||
class TenantSnapshot(BaseModel):
|
||||
tenant_id: str
|
||||
number_of_users: int
|
||||
|
||||
|
||||
class TenantInfo(BaseModel):
|
||||
invitation: TenantSnapshot | None = None
|
||||
new_tenant: TenantSnapshot | None = None
|
||||
|
||||
|
||||
class UserInfo(BaseModel):
|
||||
id: str
|
||||
email: str
|
||||
@@ -65,9 +75,10 @@ class UserInfo(BaseModel):
|
||||
current_token_created_at: datetime | None = None
|
||||
current_token_expiry_length: int | None = None
|
||||
is_cloud_superuser: bool = False
|
||||
organization_name: str | None = None
|
||||
team_name: str | None = None
|
||||
is_anonymous_user: bool | None = None
|
||||
password_configured: bool | None = None
|
||||
tenant_info: TenantInfo | None = None
|
||||
|
||||
@classmethod
|
||||
def from_model(
|
||||
@@ -76,8 +87,9 @@ class UserInfo(BaseModel):
|
||||
current_token_created_at: datetime | None = None,
|
||||
expiry_length: int | None = None,
|
||||
is_cloud_superuser: bool = False,
|
||||
organization_name: str | None = None,
|
||||
team_name: str | None = None,
|
||||
is_anonymous_user: bool | None = None,
|
||||
tenant_info: TenantInfo | None = None,
|
||||
) -> "UserInfo":
|
||||
return cls(
|
||||
id=str(user.id),
|
||||
@@ -99,7 +111,7 @@ class UserInfo(BaseModel):
|
||||
temperature_override_enabled=user.temperature_override_enabled,
|
||||
)
|
||||
),
|
||||
organization_name=organization_name,
|
||||
team_name=team_name,
|
||||
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
|
||||
# where they previously had this set + used OIDC, and now they switched to
|
||||
# basic auth are now constantly getting redirected back to the login page
|
||||
@@ -109,6 +121,7 @@ class UserInfo(BaseModel):
|
||||
current_token_expiry_length=expiry_length,
|
||||
is_cloud_superuser=is_cloud_superuser,
|
||||
is_anonymous_user=is_anonymous_user,
|
||||
tenant_info=tenant_info,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
import re
|
||||
from datetime import datetime
|
||||
from datetime import timedelta
|
||||
from datetime import timezone
|
||||
from typing import cast
|
||||
|
||||
import jwt
|
||||
from email_validator import EmailNotValidError
|
||||
@@ -12,13 +14,11 @@ from fastapi import Depends
|
||||
from fastapi import HTTPException
|
||||
from fastapi import Query
|
||||
from fastapi import Request
|
||||
from psycopg2.errors import UniqueViolation
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import Column
|
||||
from sqlalchemy import desc
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy import update
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from ee.onyx.configs.app_configs import SUPER_USERS
|
||||
@@ -33,9 +33,12 @@ from onyx.auth.users import current_admin_user
|
||||
from onyx.auth.users import current_curator_or_admin_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.auth.users import optional_user
|
||||
from onyx.configs.app_configs import AUTH_BACKEND
|
||||
from onyx.configs.app_configs import AUTH_TYPE
|
||||
from onyx.configs.app_configs import AuthBackend
|
||||
from onyx.configs.app_configs import DEV_MODE
|
||||
from onyx.configs.app_configs import ENABLE_EMAIL_INVITES
|
||||
from onyx.configs.app_configs import REDIS_AUTH_KEY_PREFIX
|
||||
from onyx.configs.app_configs import SESSION_EXPIRE_TIME_SECONDS
|
||||
from onyx.configs.app_configs import VALID_EMAIL_DOMAINS
|
||||
from onyx.configs.constants import AuthType
|
||||
@@ -52,9 +55,12 @@ from onyx.db.users import get_total_filtered_users_count
|
||||
from onyx.db.users import get_user_by_email
|
||||
from onyx.db.users import validate_user_role_update
|
||||
from onyx.key_value_store.factory import get_kv_store
|
||||
from onyx.redis.redis_pool import get_raw_redis_client
|
||||
from onyx.server.documents.models import PaginatedReturn
|
||||
from onyx.server.manage.models import AllUsersResponse
|
||||
from onyx.server.manage.models import AutoScrollRequest
|
||||
from onyx.server.manage.models import TenantInfo
|
||||
from onyx.server.manage.models import TenantSnapshot
|
||||
from onyx.server.manage.models import UserByEmail
|
||||
from onyx.server.manage.models import UserInfo
|
||||
from onyx.server.manage.models import UserPreferences
|
||||
@@ -296,13 +302,6 @@ def bulk_invite_users(
|
||||
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
|
||||
)(new_invited_emails, tenant_id)
|
||||
|
||||
except IntegrityError as e:
|
||||
if isinstance(e.orig, UniqueViolation):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail="User has already been invited to a Onyx organization",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
|
||||
|
||||
@@ -425,6 +424,10 @@ async def delete_user(
|
||||
db_session.expunge(user_to_delete)
|
||||
|
||||
try:
|
||||
tenant_id = get_current_tenant_id()
|
||||
fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
|
||||
)([user_email.user_email], tenant_id)
|
||||
delete_user_from_db(user_to_delete, db_session)
|
||||
logger.info(f"Deleted user {user_to_delete.email}")
|
||||
|
||||
@@ -480,7 +483,7 @@ async def get_user_role(user: User = Depends(current_user)) -> UserRoleResponse:
|
||||
return UserRoleResponse(role=user.role)
|
||||
|
||||
|
||||
def get_current_token_expiration_jwt(
|
||||
def get_current_auth_token_expiration_jwt(
|
||||
user: User | None, request: Request
|
||||
) -> datetime | None:
|
||||
if user is None:
|
||||
@@ -509,6 +512,48 @@ def get_current_token_expiration_jwt(
|
||||
return None
|
||||
|
||||
|
||||
def get_current_auth_token_creation_redis(
|
||||
user: User | None, request: Request
|
||||
) -> datetime | None:
|
||||
"""Calculate the token creation time from Redis TTL information.
|
||||
|
||||
This function retrieves the authentication token from cookies,
|
||||
checks its TTL in Redis, and calculates when the token was created.
|
||||
Despite the function name, it returns the token creation time, not the expiration time.
|
||||
"""
|
||||
if user is None:
|
||||
return None
|
||||
try:
|
||||
# Get the token from the request
|
||||
token = request.cookies.get(FASTAPI_USERS_AUTH_COOKIE_NAME)
|
||||
if not token:
|
||||
logger.debug("No auth token cookie found")
|
||||
return None
|
||||
|
||||
# Get the Redis client
|
||||
redis = get_raw_redis_client()
|
||||
redis_key = REDIS_AUTH_KEY_PREFIX + token
|
||||
|
||||
# Get the TTL of the token
|
||||
ttl = cast(int, redis.ttl(redis_key))
|
||||
if ttl <= 0:
|
||||
logger.error("Token has expired or doesn't exist in Redis")
|
||||
return None
|
||||
|
||||
# Calculate the creation time based on TTL and session expiry
|
||||
# Current time minus (total session length minus remaining TTL)
|
||||
current_time = datetime.now(timezone.utc)
|
||||
token_creation_time = current_time - timedelta(
|
||||
seconds=(SESSION_EXPIRE_TIME_SECONDS - ttl)
|
||||
)
|
||||
|
||||
return token_creation_time
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error retrieving token expiration from Redis: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def get_current_token_creation(
|
||||
user: User | None, db_session: Session
|
||||
) -> datetime | None:
|
||||
@@ -536,6 +581,7 @@ def get_current_token_creation(
|
||||
|
||||
@router.get("/me")
|
||||
def verify_user_logged_in(
|
||||
request: Request,
|
||||
user: User | None = Depends(optional_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> UserInfo:
|
||||
@@ -553,26 +599,47 @@ def verify_user_logged_in(
|
||||
if anonymous_user_enabled(tenant_id=tenant_id):
|
||||
store = get_kv_store()
|
||||
return fetch_no_auth_user(store, anonymous_user_enabled=True)
|
||||
|
||||
raise BasicAuthenticationError(detail="User Not Authenticated")
|
||||
|
||||
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
|
||||
raise BasicAuthenticationError(
|
||||
detail="Access denied. User's OIDC token has expired.",
|
||||
)
|
||||
|
||||
token_created_at = (
|
||||
None if MULTI_TENANT else get_current_token_creation(user, db_session)
|
||||
get_current_auth_token_creation_redis(user, request)
|
||||
if AUTH_BACKEND == AuthBackend.REDIS
|
||||
else get_current_token_creation(user, db_session)
|
||||
)
|
||||
organization_name = fetch_ee_implementation_or_noop(
|
||||
|
||||
team_name = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
|
||||
)(user.email)
|
||||
|
||||
new_tenant: TenantSnapshot | None = None
|
||||
tenant_invitation: TenantSnapshot | None = None
|
||||
|
||||
if MULTI_TENANT:
|
||||
if team_name != get_current_tenant_id():
|
||||
user_count = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_count", None
|
||||
)(team_name)
|
||||
new_tenant = TenantSnapshot(tenant_id=team_name, number_of_users=user_count)
|
||||
|
||||
tenant_invitation = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
|
||||
)(user.email)
|
||||
|
||||
user_info = UserInfo.from_model(
|
||||
user,
|
||||
current_token_created_at=token_created_at,
|
||||
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
|
||||
is_cloud_superuser=user.email in SUPER_USERS,
|
||||
organization_name=organization_name,
|
||||
team_name=team_name,
|
||||
tenant_info=TenantInfo(
|
||||
new_tenant=new_tenant,
|
||||
invitation=tenant_invitation,
|
||||
),
|
||||
)
|
||||
|
||||
return user_info
|
||||
|
||||
@@ -49,9 +49,9 @@ class FullUserSnapshot(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
class InvitedUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
|
||||
class DisplayPriorityRequest(BaseModel):
|
||||
display_priority_map: dict[int, int]
|
||||
|
||||
|
||||
class InvitedUserSnapshot(BaseModel):
|
||||
email: str
|
||||
|
||||
@@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.auth.users import current_user
|
||||
from onyx.chat.chat_utils import create_chat_chain
|
||||
from onyx.chat.chat_utils import extract_headers
|
||||
@@ -190,7 +190,7 @@ def update_chat_session_model(
|
||||
def get_chat_session(
|
||||
session_id: UUID,
|
||||
is_shared: bool = False,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> ChatSessionDetailResponse:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -246,7 +246,7 @@ def get_chat_session(
|
||||
@router.post("/create-chat-session")
|
||||
def create_new_chat_session(
|
||||
chat_session_creation_request: ChatSessionCreationRequest,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> CreateChatSessionID:
|
||||
user_id = user.id if user is not None else None
|
||||
@@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
|
||||
def handle_new_chat_message(
|
||||
chat_message_req: CreateChatMessageRequest,
|
||||
request: Request,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
_rate_limit_check: None = Depends(check_token_rate_limits),
|
||||
is_connected_func: Callable[[], bool] = Depends(is_connected),
|
||||
) -> StreamingResponse:
|
||||
@@ -473,7 +473,7 @@ def set_message_as_latest(
|
||||
@router.post("/create-chat-message-feedback")
|
||||
def create_chat_feedback(
|
||||
feedback: ChatFeedbackRequest,
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
db_session: Session = Depends(get_session),
|
||||
) -> None:
|
||||
user_id = user.id if user else None
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy import func
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from onyx.auth.users import current_chat_accesssible_user
|
||||
from onyx.auth.users import current_chat_accessible_user
|
||||
from onyx.db.engine import get_session_context_manager
|
||||
from onyx.db.models import ChatMessage
|
||||
from onyx.db.models import ChatSession
|
||||
@@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
|
||||
|
||||
|
||||
def check_token_rate_limits(
|
||||
user: User | None = Depends(current_chat_accesssible_user),
|
||||
user: User | None = Depends(current_chat_accessible_user),
|
||||
) -> None:
|
||||
# short circuit if no rate limits are set up
|
||||
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time
|
||||
|
||||
@@ -32,15 +32,15 @@ class InCodeToolInfo(TypedDict):
|
||||
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
InCodeToolInfo(
|
||||
cls=SearchTool,
|
||||
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
description="The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
|
||||
in_code_tool_id=SearchTool.__name__,
|
||||
display_name=SearchTool._DISPLAY_NAME,
|
||||
),
|
||||
InCodeToolInfo(
|
||||
cls=ImageGenerationTool,
|
||||
description=(
|
||||
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
|
||||
"The tool will be used when the user asks the assistant to generate an image."
|
||||
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
|
||||
"The action will be used when the user asks the assistant to generate an image."
|
||||
),
|
||||
in_code_tool_id=ImageGenerationTool.__name__,
|
||||
display_name=ImageGenerationTool._DISPLAY_NAME,
|
||||
@@ -51,7 +51,7 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
|
||||
InCodeToolInfo(
|
||||
cls=InternetSearchTool,
|
||||
description=(
|
||||
"The Internet Search Tool allows the assistant "
|
||||
"The Internet Search Action allows the assistant "
|
||||
"to perform internet searches for up-to-date information."
|
||||
),
|
||||
in_code_tool_id=InternetSearchTool.__name__,
|
||||
@@ -98,7 +98,7 @@ def load_builtin_tools(db_session: Session) -> None:
|
||||
for tool_id, tool in list(in_code_tool_id_to_tool.items()):
|
||||
if tool_id not in built_in_ids:
|
||||
db_session.delete(tool)
|
||||
logger.notice(f"Removed tool no longer in built-in list: {tool.name}")
|
||||
logger.notice(f"Removed action no longer in built-in list: {tool.name}")
|
||||
|
||||
db_session.commit()
|
||||
logger.notice("All built-in tools are loaded/verified.")
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
# None indicates that the default value should be used
|
||||
class SearchToolOverrideKwargs(BaseModel):
|
||||
force_no_rerank: bool
|
||||
alternate_db_session: Session | None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
|
||||
skip_query_analysis: bool
|
||||
force_no_rerank: bool | None = None
|
||||
alternate_db_session: Session | None = None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
|
||||
skip_query_analysis: bool | None = None
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user