Compare commits

..

36 Commits

Author SHA1 Message Date
pablonyx
ecbd4eb1ad add basic user invite flow (#4253) 2025-03-11 19:02:51 +00:00
pablonyx
f94d335d12 Do not show modals to non-multitenant users (#4256) 2025-03-11 11:53:13 -07:00
pablonyx
59a388ce0a fix tests 2025-03-11 11:12:35 -07:00
rkuo-danswer
9cd3cbb978 fix versions (#4250)
Co-authored-by: Richard Kuo <rkuo@rkuo.com>
2025-03-10 23:50:07 -07:00
pablonyx
ab1b6b487e descrease model server logspam (#4166) 2025-03-10 18:29:27 +00:00
Chris Weaver
6ead9510a4 Small notion tweaks (#4244)
* Small notion tweaks

* Add comment
2025-03-10 15:51:12 +00:00
Chris Weaver
965f9e98bf Eliminate extremely long log line for large checkpointds (#4236)
* Eliminate extremely long log line for large checkpointds

* address greptile
2025-03-10 15:50:50 +00:00
rkuo-danswer
426883bbf5 Feature/agentic buffered (#4231)
* rename agent test script to prevent pytest autodiscovery

* first cut

* fix log message

* fix up typing

* add a sample test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 15:48:42 +00:00
rkuo-danswer
6ca400ced9 Bugfix/delete document tags slow (#4232)
* Add Missing Date and Message-ID Headers to Ensure Email Delivery

* fix issue Performance issue during connector deletion #4191

* fix ruff

* bump to rebuild PR

---------

Co-authored-by: ThomaciousD <2194608+ThomaciousD@users.noreply.github.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-10 03:07:30 +00:00
Weves
104c4b9f4d small modal improvement 2025-03-09 20:54:53 -07:00
pablonyx
8b5e8bd5b9 k (#4240) 2025-03-10 03:06:13 +00:00
Weves
7f7621d7c0 SMall gitbook tweaks 2025-03-09 14:46:44 -07:00
pablonyx
06dcc28d05 Improved login experience (#4178)
* functional initial auth modal

* k

* k

* k

* looking good

* k

* k

* k

* k

* update

* k

* k

* misc bunch

* improvements

* k

* address comments

* k

* nit

* update

* k
2025-03-09 01:06:20 +00:00
pablonyx
18df63dfd9 Fix local background jobs (#4241) 2025-03-08 14:47:56 -08:00
Chris Weaver
0d3c72acbf Add basic memory logging (#4234)
* Add basic memory logging

* Small tweaks

* Switch to monotonic
2025-03-08 03:49:47 +00:00
rkuo-danswer
9217243e3e Bugfix/query history notes (#4204)
* early work in progress

* rename utility script

* move actual data seeding to a shareable function

* add test

* make the test pass with the fix

* fix comment

* slight improvements and notes to query history and seeding

* update test

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 19:52:30 +00:00
rkuo-danswer
61ccba82a9 light worker needs to discover some indexing tasks (#4209)
* light worker needs to discover some indexing tasks

* fix formatting

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-07 11:52:09 -08:00
Weves
9e8eba23c3 Fix frozen model issue 2025-03-07 09:05:43 -08:00
evan-danswer
0c29743538 use max_tokens to do better rate limit handling (#4224)
* use max_tokens to do better rate limit handling

* fix unti tests

* address greptile comment, thanks greptile
2025-03-06 18:12:05 -08:00
pablonyx
08b2421947 fix 2025-03-06 17:30:31 -08:00
pablonyx
ed518563db minor typing update 2025-03-06 17:02:39 -08:00
pablonyx
a32f7dc936 Fix Connector tests (confluence) (#4221) 2025-03-06 17:00:01 -08:00
rkuo-danswer
798e10c52f revert to always building model server (#4213)
* revert to always building model server

* fix just in case

---------

Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 23:49:45 +00:00
pablonyx
bf4983e35a Ensure consistent UX (#4222)
* ux consistent

* nit

* Update web/src/app/admin/configuration/llm/interfaces.ts

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>

---------

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
2025-03-06 23:13:32 +00:00
evan-danswer
b7da91e3ae improved basic search latency (#4186)
* improved basic search latency

* address PR comments + minor cleanup
2025-03-06 22:22:59 +00:00
Weves
29382656fc Stop trying a million times for the user validity check 2025-03-06 15:35:49 -08:00
pablonyx
7d6db8d500 Comma separated list for Github repos (#4199) 2025-03-06 14:46:57 -08:00
Chris Weaver
a7a374dc81 Confluence fixes (#4220)
* Confluence fixes

* Small tweak

* Address greptile comments
2025-03-06 20:57:07 +00:00
rkuo-danswer
facc8cc2fa add scope needed for permission sync (#4198)
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 20:03:38 +00:00
rkuo-danswer
2c0af0a0ca Feature/helm updates (#4201)
* add ingress for api and web

* helm setup docs

* add letsencrypt. close blocks

* use pathType ImplementationSpecific as Prefix is deprecated

* fix backend labels. configure nginx routes. update annotations

* fix linting

---------

Co-authored-by: Sajjad Anwar <sajjadkm@gmail.com>
Co-authored-by: Richard Kuo (Danswer) <rkuo@onyx.app>
2025-03-06 19:48:20 +00:00
pablonyx
bfbc1cd954 k (#4172) 2025-03-06 18:55:12 +00:00
pablonyx
626da583aa Fix gated tenants (#4177)
* fix

* mypy .
2025-03-06 18:07:15 +00:00
pablonyx
92faca139d Fix extra tenant mystery (#4197)
* fix extra tenant mystery

* nit
2025-03-06 18:06:49 +00:00
pablonyx
cec05c5ee9 Revert "k"
This reverts commit 687122911d.
2025-03-06 09:38:31 -08:00
Richard Kuo (Danswer)
eaf054ef06 oauth router went missing? 2025-03-05 15:50:23 -08:00
pablonyx
a7a1a24658 minor nit 2025-03-05 15:35:02 -08:00
166 changed files with 3860 additions and 1239 deletions

View File

@@ -12,29 +12,40 @@ env:
BUILDKIT_PROGRESS: plain
jobs:
# 1) Preliminary job to check if the changed files are relevant
# Bypassing this for now as the idea of not building is glitching
# releases and builds that depends on everything being tagged in docker
# 1) Preliminary job to check if the changed files are relevant
# check_model_server_changes:
# runs-on: ubuntu-latest
# outputs:
# changed: ${{ steps.check.outputs.changed }}
# steps:
# - name: Checkout code
# uses: actions/checkout@v4
#
# - name: Check if relevant files changed
# id: check
# run: |
# # Default to "false"
# echo "changed=false" >> $GITHUB_OUTPUT
#
# # Compare the previous commit (github.event.before) to the current one (github.sha)
# # If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# # set changed=true
# if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
# | grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
# echo "changed=true" >> $GITHUB_OUTPUT
# fi
check_model_server_changes:
runs-on: ubuntu-latest
outputs:
changed: ${{ steps.check.outputs.changed }}
changed: "true"
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Check if relevant files changed
id: check
run: |
# Default to "false"
echo "changed=false" >> $GITHUB_OUTPUT
# Compare the previous commit (github.event.before) to the current one (github.sha)
# If any file in backend/model_server/** or backend/Dockerfile.model_server is changed,
# set changed=true
if git diff --name-only ${{ github.event.before }} ${{ github.sha }} \
| grep -E '^backend/model_server/|^backend/Dockerfile.model_server'; then
echo "changed=true" >> $GITHUB_OUTPUT
fi
- name: Bypass check and set output
run: echo "changed=true" >> $GITHUB_OUTPUT
build-amd64:
needs: [check_model_server_changes]
if: needs.check_model_server_changes.outputs.changed == 'true'

View File

@@ -1,6 +1,7 @@
name: Connector Tests
on:
merge_group:
pull_request:
branches: [main]
schedule:
@@ -47,11 +48,13 @@ env:
# Gitbook
GITBOOK_SPACE_ID: ${{ secrets.GITBOOK_SPACE_ID }}
GITBOOK_API_KEY: ${{ secrets.GITBOOK_API_KEY }}
# Notion
NOTION_INTEGRATION_TOKEN: ${{ secrets.NOTION_INTEGRATION_TOKEN }}
jobs:
connectors-check:
# See https://runs-on.com/runners/linux/
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
env:
PYTHONPATH: ./backend
@@ -76,7 +79,7 @@ jobs:
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
playwright install chromium
playwright install-deps chromium
- name: Run Tests
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors

View File

@@ -114,3 +114,4 @@ To try the Onyx Enterprise Edition:
## 💡 Contributing
Looking to contribute? Please check out the [Contribution Guide](CONTRIBUTING.md) for more details.

View File

@@ -0,0 +1,125 @@
"""Update GitHub connector repo_name to repositories
Revision ID: 3934b1bc7b62
Revises: b7c2b63c4a03
Create Date: 2025-03-05 10:50:30.516962
"""
from alembic import op
import sqlalchemy as sa
import json
import logging
# revision identifiers, used by Alembic.
revision = "3934b1bc7b62"
down_revision = "b7c2b63c4a03"
branch_labels = None
depends_on = None
logger = logging.getLogger("alembic.runtime.migration")
def upgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
# First get all GitHub connectors
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
# Update each connector's config
updated_count = 0
for connector_id, config in github_connectors:
try:
if not config:
logger.warning(f"Connector {connector_id} has no config, skipping")
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repo_name" not in config:
continue
# Create new config with repositories instead of repo_name
new_config = dict(config)
repo_name_value = new_config.pop("repo_name")
new_config["repositories"] = repo_name_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
)
updated_count += 1
except Exception as e:
logger.error(f"Error updating connector {connector_id}: {str(e)}")
def downgrade() -> None:
# Get all GitHub connectors
conn = op.get_bind()
logger.debug(
"Starting rollback of GitHub connectors from repositories to repo_name"
)
github_connectors = conn.execute(
sa.text(
"""
SELECT id, connector_specific_config
FROM connector
WHERE source = 'GITHUB'
"""
)
).fetchall()
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
# Revert each GitHub connector to use repo_name instead of repositories
reverted_count = 0
for connector_id, config in github_connectors:
try:
if not config:
continue
# Parse the config if it's a string
if isinstance(config, str):
config = json.loads(config)
if "repositories" not in config:
continue
# Create new config with repo_name instead of repositories
new_config = dict(config)
repositories_value = new_config.pop("repositories")
new_config["repo_name"] = repositories_value
# Update the connector with the new config
conn.execute(
sa.text(
"""
UPDATE connector
SET connector_specific_config = :new_config
WHERE id = :connector_id
"""
),
{"new_config": json.dumps(new_config), "connector_id": connector_id},
)
reverted_count += 1
except Exception as e:
logger.error(f"Error reverting connector {connector_id}: {str(e)}")

View File

@@ -6,8 +6,7 @@ Create Date: 2025-02-26 13:07:56.217791
"""
from alembic import op
import time
from sqlalchemy import text
# revision identifiers, used by Alembic.
revision = "3bd4c84fe72f"
@@ -28,357 +27,45 @@ depends_on = None
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
def upgrade():
# --- PART 1: chat_message table ---
# Step 1: Add nullable column (quick, minimal locking)
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# op.execute("DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message")
# op.execute("DROP FUNCTION IF EXISTS update_chat_message_tsv()")
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
# # Drop chat_session tsv trigger if it exists
# op.execute("DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session")
# op.execute("DROP FUNCTION IF EXISTS update_chat_session_tsv()")
# op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS title_tsv")
# raise Exception("Stop here")
time.time()
op.execute("ALTER TABLE chat_message ADD COLUMN IF NOT EXISTS message_tsv tsvector")
# Step 2: Create function and trigger for new/updated rows
def upgrade() -> None:
# Create a GIN index for full-text search on chat_message.message
op.execute(
"""
CREATE OR REPLACE FUNCTION update_chat_message_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.message_tsv = to_tsvector('english', NEW.message);
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
# Create trigger in a separate execute call
# Commit the current transaction before creating concurrent indexes
op.execute("COMMIT")
op.execute(
"""
CREATE TRIGGER chat_message_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_message
FOR EACH ROW EXECUTE FUNCTION update_chat_message_tsv()
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message
USING GIN (message_tsv)
"""
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get connection and count total rows
connection = op.get_bind()
total_count_result = connection.execute(
text("SELECT COUNT(*) FROM chat_message")
).scalar()
total_count = total_count_result if total_count_result is not None else 0
batch_size = 5000
batches = 0
# Calculate total batches needed
total_batches = (
(total_count + batch_size - 1) // batch_size if total_count > 0 else 0
# Also add a stored tsvector column for chat_session.description
op.execute(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv tsvector
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
"""
)
# Process in batches - properly handling UUIDs by using OFFSET/LIMIT approach
for batch_num in range(total_batches):
offset = batch_num * batch_size
# Commit again before creating the second concurrent index
op.execute("COMMIT")
# Execute update for this batch using OFFSET/LIMIT which works with UUIDs
connection.execute(
text(
"""
UPDATE chat_message
SET message_tsv = to_tsvector('english', message)
WHERE id IN (
SELECT id FROM chat_message
WHERE message_tsv IS NULL
ORDER BY id
LIMIT :batch_size OFFSET :offset
)
"""
).bindparams(batch_size=batch_size, offset=offset)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_message SET message_tsv = to_tsvector('english', message)
WHERE message_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
ON chat_message USING GIN (message_tsv)
"""
)
)
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_message_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_message
ADD COLUMN message_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
"""
)
)
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv_gen
ON chat_message USING GIN (message_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_message DROP COLUMN message_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_message RENAME COLUMN message_tsv_gen TO message_tsv;
"""
)
)
# --- PART 2: chat_session table ---
# Step 1: Add nullable column (quick, minimal locking)
time.time()
connection.execute(
text(
"ALTER TABLE chat_session ADD COLUMN IF NOT EXISTS description_tsv tsvector"
)
)
# Step 2: Create function and trigger for new/updated rows - SPLIT INTO SEPARATE CALLS
connection.execute(
text(
"""
CREATE OR REPLACE FUNCTION update_chat_session_tsv()
RETURNS TRIGGER AS $$
BEGIN
NEW.description_tsv = to_tsvector('english', COALESCE(NEW.description, ''));
RETURN NEW;
END;
$$ LANGUAGE plpgsql
"""
)
)
# Create trigger in a separate execute call
connection.execute(
text(
"""
CREATE TRIGGER chat_session_tsv_trigger
BEFORE INSERT OR UPDATE ON chat_session
FOR EACH ROW EXECUTE FUNCTION update_chat_session_tsv()
"""
)
)
# Step 3: Update existing rows in batches using Python
time.time()
# Get the maximum ID to determine batch count
# Cast id to text for MAX function since it's a UUID
max_id_result = connection.execute(
text("SELECT COALESCE(MAX(id::text), '0') FROM chat_session")
).scalar()
max_id_result if max_id_result is not None else "0"
batch_size = 5000
batches = 0
# Get all IDs ordered to process in batches
rows = connection.execute(
text("SELECT id FROM chat_session ORDER BY id")
).fetchall()
total_rows = len(rows)
# Process in batches
for batch_num, batch_start in enumerate(range(0, total_rows, batch_size)):
batch_end = min(batch_start + batch_size, total_rows)
batch_ids = [row[0] for row in rows[batch_start:batch_end]]
if not batch_ids:
continue
# Use IN clause instead of BETWEEN for UUIDs
placeholders = ", ".join([f":id{i}" for i in range(len(batch_ids))])
params = {f"id{i}": id_val for i, id_val in enumerate(batch_ids)}
# Execute update for this batch
connection.execute(
text(
f"""
UPDATE chat_session
SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE id IN ({placeholders})
AND description_tsv IS NULL
"""
).bindparams(**params)
)
# Commit each batch
connection.execute(text("COMMIT"))
# Start a new transaction
connection.execute(text("BEGIN"))
batches += 1
# Final check for any remaining NULL values
connection.execute(
text(
"""
UPDATE chat_session SET description_tsv = to_tsvector('english', COALESCE(description, ''))
WHERE description_tsv IS NULL
"""
)
)
# Create GIN index concurrently
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session USING GIN (description_tsv)
"""
)
)
# After Final check for chat_session
# First drop the trigger as it won't be needed anymore
connection.execute(
text(
"""
DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session;
"""
)
)
connection.execute(
text(
"""
DROP FUNCTION IF EXISTS update_chat_session_tsv();
"""
)
)
# Add new generated column
time.time()
connection.execute(
text(
"""
ALTER TABLE chat_session
ADD COLUMN description_tsv_gen tsvector
GENERATED ALWAYS AS (to_tsvector('english', COALESCE(description, ''))) STORED;
"""
)
)
# Create new index on generated column
connection.execute(text("COMMIT"))
time.time()
connection.execute(
text(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv_gen
ON chat_session USING GIN (description_tsv_gen)
"""
)
)
# Drop old index and column
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;
"""
)
)
connection.execute(text("COMMIT"))
connection.execute(
text(
"""
ALTER TABLE chat_session DROP COLUMN description_tsv;
"""
)
)
# Rename new column to old name
connection.execute(
text(
"""
ALTER TABLE chat_session RENAME COLUMN description_tsv_gen TO description_tsv;
"""
)
op.execute(
"""
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
ON chat_session
USING GIN (description_tsv)
"""
)

View File

@@ -0,0 +1,51 @@
"""new column user tenant mapping
Revision ID: ac842f85f932
Revises: 34e3630c7f32
Create Date: 2025-03-03 13:30:14.802874
"""
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision = "ac842f85f932"
down_revision = "34e3630c7f32"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Add active column with default value of True
op.add_column(
"user_tenant_mapping",
sa.Column(
"active",
sa.Boolean(),
nullable=False,
server_default="true",
),
schema="public",
)
op.drop_constraint("uq_email", "user_tenant_mapping", schema="public")
# Create a unique index for active=true records
# This ensures a user can only be active in one tenant at a time
op.execute(
"CREATE UNIQUE INDEX uq_user_active_email_idx ON public.user_tenant_mapping (email) WHERE active = true"
)
def downgrade() -> None:
# Drop the unique index for active=true records
op.execute("DROP INDEX IF EXISTS uq_user_active_email_idx")
op.create_unique_constraint(
"uq_email", "user_tenant_mapping", ["email"], schema="public"
)
# Remove the active column
op.drop_column("user_tenant_mapping", "active", schema="public")

View File

@@ -27,6 +27,8 @@ def get_empty_chat_messages_entries__paginated(
first element is the most recent timestamp out of the sessions iterated
- this timestamp can be used to paginate forward in time
second element is a list of messages belonging to all the sessions iterated
Only messages of type USER are returned
"""
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=period[0],

View File

@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
)
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
from ee.onyx.server.oauth.api import router as oauth_router
from ee.onyx.server.oauth.api import router as ee_oauth_router
from ee.onyx.server.query_and_chat.chat_backend import (
router as chat_router,
)
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
include_router_with_global_prefix_prepended(application, query_router)
include_router_with_global_prefix_prepended(application, chat_router)
include_router_with_global_prefix_prepended(application, standard_answer_router)
include_router_with_global_prefix_prepended(application, oauth_router)
include_router_with_global_prefix_prepended(application, ee_oauth_router)
# Enterprise-only global settings
include_router_with_global_prefix_prepended(

View File

@@ -80,6 +80,7 @@ class ConfluenceCloudOAuth:
"search:confluence%20"
# granular scope
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
"read:content-details:confluence%20" # for permission sync
"offline_access"
)

View File

@@ -1,10 +1,14 @@
import re
from typing import cast
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.query_and_chat.models import AgentAnswer
from ee.onyx.server.query_and_chat.models import AgentSubQuery
from ee.onyx.server.query_and_chat.models import AgentSubQuestion
from ee.onyx.server.query_and_chat.models import BasicCreateChatMessageRequest
from ee.onyx.server.query_and_chat.models import (
BasicCreateChatMessageWithHistoryRequest,
@@ -14,13 +18,19 @@ from ee.onyx.server.query_and_chat.models import SimpleDoc
from onyx.auth.users import current_user
from onyx.chat.chat_utils import combine_message_thread
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import AllCitations
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import FinalUsedContextDocsResponse
from onyx.chat.models import LlmDoc
from onyx.chat.models import LLMRelevanceFilterResponse
from onyx.chat.models import OnyxAnswerPiece
from onyx.chat.models import QADocsResponse
from onyx.chat.models import RefinedAnswerImprovement
from onyx.chat.models import StreamingError
from onyx.chat.models import SubQueryPiece
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import SubQuestionPiece
from onyx.chat.process_message import ChatPacketStream
from onyx.chat.process_message import stream_chat_message_objects
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
@@ -89,6 +99,12 @@ def _convert_packet_stream_to_response(
final_context_docs: list[LlmDoc] = []
answer = ""
# accumulate stream data with these dicts
agent_sub_questions: dict[tuple[int, int], AgentSubQuestion] = {}
agent_answers: dict[tuple[int, int], AgentAnswer] = {}
agent_sub_queries: dict[tuple[int, int, int], AgentSubQuery] = {}
for packet in packets:
if isinstance(packet, OnyxAnswerPiece) and packet.answer_piece:
answer += packet.answer_piece
@@ -97,6 +113,15 @@ def _convert_packet_stream_to_response(
# TODO: deprecate `simple_search_docs`
response.simple_search_docs = _translate_doc_response_to_simple_doc(packet)
# This is a no-op if agent_sub_questions hasn't already been filled
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if id in agent_sub_questions:
agent_sub_questions[id].document_ids = [
saved_search_doc.document_id
for saved_search_doc in packet.top_documents
]
elif isinstance(packet, StreamingError):
response.error_msg = packet.error
elif isinstance(packet, ChatMessageDetail):
@@ -113,11 +138,104 @@ def _convert_packet_stream_to_response(
citation.citation_num: citation.document_id
for citation in packet.citations
}
# agentic packets
elif isinstance(packet, SubQuestionPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_sub_questions.get(id) is None:
agent_sub_questions[id] = AgentSubQuestion(
level=packet.level,
level_question_num=packet.level_question_num,
sub_question=packet.sub_question,
document_ids=[],
)
else:
agent_sub_questions[id].sub_question += packet.sub_question
elif isinstance(packet, AgentAnswerPiece):
if packet.level is not None and packet.level_question_num is not None:
id = (packet.level, packet.level_question_num)
if agent_answers.get(id) is None:
agent_answers[id] = AgentAnswer(
level=packet.level,
level_question_num=packet.level_question_num,
answer=packet.answer_piece,
answer_type=packet.answer_type,
)
else:
agent_answers[id].answer += packet.answer_piece
elif isinstance(packet, SubQueryPiece):
if packet.level is not None and packet.level_question_num is not None:
sub_query_id = (
packet.level,
packet.level_question_num,
packet.query_id,
)
if agent_sub_queries.get(sub_query_id) is None:
agent_sub_queries[sub_query_id] = AgentSubQuery(
level=packet.level,
level_question_num=packet.level_question_num,
sub_query=packet.sub_query,
query_id=packet.query_id,
)
else:
agent_sub_queries[sub_query_id].sub_query += packet.sub_query
elif isinstance(packet, ExtendedToolResponse):
# we shouldn't get this ... it gets intercepted and translated to QADocsResponse
logger.warning(
"_convert_packet_stream_to_response: Unexpected chat packet type ExtendedToolResponse!"
)
elif isinstance(packet, RefinedAnswerImprovement):
response.agent_refined_answer_improvement = (
packet.refined_answer_improvement
)
else:
logger.warning(
f"_convert_packet_stream_to_response - Unrecognized chat packet: type={type(packet)}"
)
response.final_context_doc_indices = _get_final_context_doc_indices(
final_context_docs, response.top_documents
)
# organize / sort agent metadata for output
if len(agent_sub_questions) > 0:
response.agent_sub_questions = cast(
dict[int, list[AgentSubQuestion]],
SubQuestionIdentifier.make_dict_by_level(agent_sub_questions),
)
if len(agent_answers) > 0:
# return the agent_level_answer from the first level or the last one depending
# on agent_refined_answer_improvement
response.agent_answers = cast(
dict[int, list[AgentAnswer]],
SubQuestionIdentifier.make_dict_by_level(agent_answers),
)
if response.agent_answers:
selected_answer_level = (
0
if not response.agent_refined_answer_improvement
else len(response.agent_answers) - 1
)
level_answers = response.agent_answers[selected_answer_level]
for level_answer in level_answers:
if level_answer.answer_type != "agent_level_answer":
continue
answer = level_answer.answer
break
if len(agent_sub_queries) > 0:
# subqueries are often emitted with trailing whitespace ... clean it up here
# perhaps fix at the source?
for v in agent_sub_queries.values():
v.sub_query = v.sub_query.strip()
response.agent_sub_queries = (
AgentSubQuery.make_dict_by_level_and_question_index(agent_sub_queries)
)
response.answer = answer
if answer:
response.answer_citationless = remove_answer_citations(answer)

View File

@@ -1,3 +1,5 @@
from collections import OrderedDict
from typing import Literal
from uuid import UUID
from pydantic import BaseModel
@@ -9,6 +11,7 @@ from onyx.chat.models import CitationInfo
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PersonaOverrideConfig
from onyx.chat.models import QADocsResponse
from onyx.chat.models import SubQuestionIdentifier
from onyx.chat.models import ThreadMessage
from onyx.configs.constants import DocumentSource
from onyx.context.search.enums import LLMEvaluationType
@@ -88,6 +91,64 @@ class SimpleDoc(BaseModel):
metadata: dict | None
class AgentSubQuestion(SubQuestionIdentifier):
sub_question: str
document_ids: list[str]
class AgentAnswer(SubQuestionIdentifier):
answer: str
answer_type: Literal["agent_sub_answer", "agent_level_answer"]
class AgentSubQuery(SubQuestionIdentifier):
sub_query: str
query_id: int
@staticmethod
def make_dict_by_level_and_question_index(
original_dict: dict[tuple[int, int, int], "AgentSubQuery"]
) -> dict[int, dict[int, list["AgentSubQuery"]]]:
"""Takes a dict of tuple(level, question num, query_id) to sub queries.
returns a dict of level to dict[question num to list of query_id's]
Ordering is asc for readability.
"""
# In this function, when we sort int | None, we deliberately push None to the end
# map entries to the level_question_dict
level_question_dict: dict[int, dict[int, list["AgentSubQuery"]]] = {}
for k1, obj in original_dict.items():
level = k1[0]
question = k1[1]
if level not in level_question_dict:
level_question_dict[level] = {}
if question not in level_question_dict[level]:
level_question_dict[level][question] = []
level_question_dict[level][question].append(obj)
# sort each query_id list and question_index
for key1, obj1 in level_question_dict.items():
for key2, value2 in obj1.items():
# sort the query_id list of each question_index
level_question_dict[key1][key2] = sorted(
value2, key=lambda o: o.query_id
)
# sort the question_index dict of level
level_question_dict[key1] = OrderedDict(
sorted(level_question_dict[key1].items(), key=lambda x: (x is None, x))
)
# sort the top dict of levels
sorted_dict = OrderedDict(
sorted(level_question_dict.items(), key=lambda x: (x is None, x))
)
return sorted_dict
class ChatBasicResponse(BaseModel):
# This is built piece by piece, any of these can be None as the flow could break
answer: str | None = None
@@ -107,6 +168,12 @@ class ChatBasicResponse(BaseModel):
simple_search_docs: list[SimpleDoc] | None = None
llm_chunks_indices: list[int] | None = None
# agentic fields
agent_sub_questions: dict[int, list[AgentSubQuestion]] | None = None
agent_answers: dict[int, list[AgentAnswer]] | None = None
agent_sub_queries: dict[int, dict[int, list[AgentSubQuery]]] | None = None
agent_refined_answer_improvement: bool | None = None
class OneShotQARequest(ChunkContext):
# Supports simplier APIs that don't deal with chat histories or message edits

View File

@@ -48,10 +48,15 @@ def fetch_and_process_chat_session_history(
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
# observed to be slow a scale of 8192 sessions and 4 messages per session
# this is a little slow (5 seconds)
chat_sessions = fetch_chat_sessions_eagerly_by_time(
start=start, end=end, db_session=db_session, limit=limit
)
# this is VERY slow (80 seconds) due to create_chat_chain being called
# for each session. Needs optimizing.
chat_session_snapshots = [
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
@@ -246,6 +251,8 @@ def get_query_history_as_csv(
detail="Query history has been disabled by the administrator.",
)
# this call is very expensive and is timing out via endpoint
# TODO: optimize call and/or generate via background task
complete_chat_session_history = fetch_and_process_chat_session_history(
db_session=db_session,
start=start or datetime.fromtimestamp(0, tz=timezone.utc),

View File

@@ -0,0 +1,45 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from onyx.auth.users import auth_backend
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import User
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import get_user_by_email
from onyx.utils.logger import setup_logger
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response

View File

@@ -0,0 +1,98 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.models import AnonymousUserPath
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import current_admin_user
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.engine import get_session_with_shared_schema
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response

View File

@@ -1,269 +1,24 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Response
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.auth.users import current_cloud_superuser
from ee.onyx.auth.users import generate_anonymous_user_jwt_token
from ee.onyx.configs.app_configs import ANONYMOUS_USER_COOKIE_NAME
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.anonymous_user_path import get_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import (
get_tenant_id_for_anonymous_user_path,
from ee.onyx.server.tenants.admin_api import router as admin_router
from ee.onyx.server.tenants.anonymous_users_api import router as anonymous_users_router
from ee.onyx.server.tenants.billing_api import router as billing_router
from ee.onyx.server.tenants.team_membership_api import router as team_membership_router
from ee.onyx.server.tenants.tenant_management_api import (
router as tenant_management_router,
)
from ee.onyx.server.tenants.user_invitations_api import (
router as user_invitations_router,
)
from ee.onyx.server.tenants.anonymous_user_path import modify_anonymous_user_path
from ee.onyx.server.tenants.anonymous_user_path import validate_anonymous_user_path
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import AnonymousUserPath
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ImpersonateRequest
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import get_tenant_id_for_email
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import anonymous_user_enabled
from onyx.auth.users import auth_backend
from onyx.auth.users import current_admin_user
from onyx.auth.users import get_redis_strategy
from onyx.auth.users import optional_user
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import FASTAPI_USERS_AUTH_COOKIE_NAME
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
# Create a main router to include all sub-routers
# Note: We don't add a prefix here as each router already has the /tenants prefix
router = APIRouter()
@router.get("/anonymous-user-path")
async def get_anonymous_user_path_api(
_: User | None = Depends(current_admin_user),
) -> AnonymousUserPath:
tenant_id = get_current_tenant_id()
if tenant_id is None:
raise HTTPException(status_code=404, detail="Tenant not found")
with get_session_with_shared_schema() as db_session:
current_path = get_anonymous_user_path(tenant_id, db_session)
return AnonymousUserPath(anonymous_user_path=current_path)
@router.post("/anonymous-user-path")
async def set_anonymous_user_path_api(
anonymous_user_path: str,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
try:
validate_anonymous_user_path(anonymous_user_path)
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
with get_session_with_shared_schema() as db_session:
try:
modify_anonymous_user_path(tenant_id, anonymous_user_path, db_session)
except IntegrityError:
raise HTTPException(
status_code=409,
detail="The anonymous user path is already in use. Please choose a different path.",
)
except Exception as e:
logger.exception(f"Failed to modify anonymous user path: {str(e)}")
raise HTTPException(
status_code=500,
detail="An unexpected error occurred while modifying the anonymous user path",
)
@router.post("/anonymous-user")
async def login_as_anonymous_user(
anonymous_user_path: str,
_: User | None = Depends(optional_user),
) -> Response:
with get_session_with_shared_schema() as db_session:
tenant_id = get_tenant_id_for_anonymous_user_path(
anonymous_user_path, db_session
)
if not tenant_id:
raise HTTPException(status_code=404, detail="Tenant not found")
if not anonymous_user_enabled(tenant_id=tenant_id):
raise HTTPException(status_code=403, detail="Anonymous user is not enabled")
token = generate_anonymous_user_jwt_token(tenant_id)
response = Response()
response.delete_cookie(FASTAPI_USERS_AUTH_COOKIE_NAME)
response.set_cookie(
key=ANONYMOUS_USER_COOKIE_NAME,
value=token,
httponly=True,
secure=True,
samesite="strict",
)
return response
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/impersonate")
async def impersonate_user(
impersonate_request: ImpersonateRequest,
_: User = Depends(current_cloud_superuser),
) -> Response:
"""Allows a cloud superuser to impersonate another user by generating an impersonation JWT token"""
tenant_id = get_tenant_id_for_email(impersonate_request.email)
with get_session_with_tenant(tenant_id=tenant_id) as tenant_session:
user_to_impersonate = get_user_by_email(
impersonate_request.email, tenant_session
)
if user_to_impersonate is None:
raise HTTPException(status_code=404, detail="User not found")
token = await get_redis_strategy().write_token(user_to_impersonate)
response = await auth_backend.transport.get_login_response(token)
response.set_cookie(
key="fastapiusersauth",
value=token,
httponly=True,
secure=True,
samesite="lax",
)
return response
@router.post("/leave-organization")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)
# Include all the individual routers
router.include_router(admin_router)
router.include_router(anonymous_users_router)
router.include_router(billing_router)
router.include_router(team_membership_router)
router.include_router(tenant_management_router)
router.include_router(user_invitations_router)

View File

@@ -0,0 +1,96 @@
import stripe
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.auth.users import current_admin_user
from ee.onyx.configs.app_configs import STRIPE_SECRET_KEY
from ee.onyx.server.tenants.access import control_plane_dep
from ee.onyx.server.tenants.billing import fetch_billing_information
from ee.onyx.server.tenants.billing import fetch_stripe_checkout_session
from ee.onyx.server.tenants.billing import fetch_tenant_stripe_information
from ee.onyx.server.tenants.models import BillingInformation
from ee.onyx.server.tenants.models import ProductGatingRequest
from ee.onyx.server.tenants.models import ProductGatingResponse
from ee.onyx.server.tenants.models import SubscriptionSessionResponse
from ee.onyx.server.tenants.models import SubscriptionStatusResponse
from ee.onyx.server.tenants.product_gating import store_product_gating
from onyx.auth.users import User
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
from shared_configs.contextvars import get_current_tenant_id
stripe.api_key = STRIPE_SECRET_KEY
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/product-gating")
def gate_product(
product_gating_request: ProductGatingRequest, _: None = Depends(control_plane_dep)
) -> ProductGatingResponse:
"""
Gating the product means that the product is not available to the tenant.
They will be directed to the billing page.
We gate the product when their subscription has ended.
"""
try:
store_product_gating(
product_gating_request.tenant_id, product_gating_request.application_status
)
return ProductGatingResponse(updated=True, error=None)
except Exception as e:
logger.exception("Failed to gate product")
return ProductGatingResponse(updated=False, error=str(e))
@router.get("/billing-information")
async def billing_information(
_: User = Depends(current_admin_user),
) -> BillingInformation | SubscriptionStatusResponse:
logger.info("Fetching billing information")
tenant_id = get_current_tenant_id()
return fetch_billing_information(tenant_id)
@router.post("/create-customer-portal-session")
async def create_customer_portal_session(
_: User = Depends(current_admin_user),
) -> dict:
tenant_id = get_current_tenant_id()
try:
stripe_info = fetch_tenant_stripe_information(tenant_id)
stripe_customer_id = stripe_info.get("stripe_customer_id")
if not stripe_customer_id:
raise HTTPException(status_code=400, detail="Stripe customer ID not found")
logger.info(stripe_customer_id)
portal_session = stripe.billing_portal.Session.create(
customer=stripe_customer_id,
return_url=f"{WEB_DOMAIN}/admin/billing",
)
logger.info(portal_session)
return {"url": portal_session.url}
except Exception as e:
logger.exception("Failed to create customer portal session")
raise HTTPException(status_code=500, detail=str(e))
@router.post("/create-subscription-session")
async def create_subscription_session(
_: User = Depends(current_admin_user),
) -> SubscriptionSessionResponse:
try:
tenant_id = CURRENT_TENANT_ID_CONTEXTVAR.get()
if not tenant_id:
raise HTTPException(status_code=400, detail="Tenant ID not found")
session_id = fetch_stripe_checkout_session(tenant_id)
return SubscriptionSessionResponse(sessionId=session_id)
except Exception as e:
logger.exception("Failed to create resubscription session")
raise HTTPException(status_code=500, detail=str(e))

View File

@@ -67,3 +67,30 @@ class ProductGatingResponse(BaseModel):
class SubscriptionSessionResponse(BaseModel):
sessionId: str
class TenantByDomainResponse(BaseModel):
tenant_id: str
number_of_users: int
creator_email: str
class TenantByDomainRequest(BaseModel):
email: str
class RequestInviteRequest(BaseModel):
tenant_id: str
class RequestInviteResponse(BaseModel):
success: bool
message: str
class PendingUserSnapshot(BaseModel):
email: str
class ApproveUserRequest(BaseModel):
email: str

View File

@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
def get_gated_tenants() -> set[str]:
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}

View File

@@ -4,6 +4,7 @@ import uuid
import aiohttp # Async HTTP client
import httpx
import requests
from fastapi import HTTPException
from fastapi import Request
from sqlalchemy import select
@@ -14,6 +15,7 @@ from ee.onyx.configs.app_configs import COHERE_DEFAULT_API_KEY
from ee.onyx.configs.app_configs import HUBSPOT_TRACKING_URL
from ee.onyx.configs.app_configs import OPENAI_DEFAULT_API_KEY
from ee.onyx.server.tenants.access import generate_data_plane_token
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.models import TenantCreationPayload
from ee.onyx.server.tenants.models import TenantDeletionPayload
from ee.onyx.server.tenants.schema_management import create_schema_if_not_exists
@@ -55,7 +57,11 @@ logger = logging.getLogger(__name__)
async def get_or_provision_tenant(
email: str, referral_source: str | None = None, request: Request | None = None
) -> str:
"""Get existing tenant ID for an email or create a new tenant if none exists."""
"""
Get existing tenant ID for an email or create a new tenant if none exists.
This function should only be called after we have verified we want this user's tenant to exist.
It returns the tenant ID associated with the email, creating a new tenant if necessary.
"""
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
@@ -349,3 +355,47 @@ async def delete_user_from_control_plane(tenant_id: str, email: str) -> None:
raise Exception(
f"Failed to delete tenant on control plane: {error_text}"
)
def get_tenant_by_domain_from_control_plane(
domain: str,
tenant_id: str,
) -> TenantByDomainResponse | None:
"""
Fetches tenant information from the control plane based on the email domain.
Args:
domain: The email domain to search for (e.g., "example.com")
Returns:
A dictionary containing tenant information if found, None otherwise
"""
token = generate_data_plane_token()
headers = {
"Authorization": f"Bearer {token}",
"Content-Type": "application/json",
}
try:
response = requests.get(
f"{CONTROL_PLANE_API_BASE_URL}/tenant-by-domain",
headers=headers,
json={"domain": domain, "tenant_id": tenant_id},
)
if response.status_code != 200:
logger.error(f"Control plane tenant lookup failed: {response.text}")
return None
response_data = response.json()
if not response_data:
return None
return TenantByDomainResponse(
tenant_id=response_data.get("tenant_id"),
number_of_users=response_data.get("number_of_users"),
creator_email=response_data.get("creator_email"),
)
except Exception as e:
logger.error(f"Error fetching tenant by domain: {str(e)}")
return None

View File

@@ -0,0 +1,67 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from sqlalchemy.orm import Session
from ee.onyx.server.tenants.provisioning import delete_user_from_control_plane
from ee.onyx.server.tenants.user_mapping import remove_all_users_from_tenant
from ee.onyx.server.tenants.user_mapping import remove_users_from_tenant
from onyx.auth.users import current_admin_user
from onyx.auth.users import User
from onyx.db.auth import get_user_count
from onyx.db.engine import get_session
from onyx.db.users import delete_user_from_db
from onyx.db.users import get_user_by_email
from onyx.server.manage.models import UserByEmail
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/leave-team")
async def leave_organization(
user_email: UserByEmail,
current_user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
tenant_id = get_current_tenant_id()
if current_user is None or current_user.email != user_email.user_email:
raise HTTPException(
status_code=403, detail="You can only leave the organization as yourself"
)
user_to_delete = get_user_by_email(user_email.user_email, db_session)
if user_to_delete is None:
raise HTTPException(status_code=404, detail="User not found")
num_admin_users = await get_user_count(only_admin_users=True)
should_delete_tenant = num_admin_users == 1
if should_delete_tenant:
logger.info(
"Last admin user is leaving the organization. Deleting tenant from control plane."
)
try:
await delete_user_from_control_plane(tenant_id, user_to_delete.email)
logger.debug("User deleted from control plane")
except Exception as e:
logger.exception(
f"Failed to delete user from control plane for tenant {tenant_id}: {e}"
)
raise HTTPException(
status_code=500,
detail=f"Failed to remove user from control plane: {str(e)}",
)
db_session.expunge(user_to_delete)
delete_user_from_db(user_to_delete, db_session)
if should_delete_tenant:
remove_all_users_from_tenant(tenant_id)
else:
remove_users_from_tenant([user_to_delete.email], tenant_id)

View File

@@ -0,0 +1,39 @@
from fastapi import APIRouter
from fastapi import Depends
from ee.onyx.server.tenants.models import TenantByDomainResponse
from ee.onyx.server.tenants.provisioning import get_tenant_by_domain_from_control_plane
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
FORBIDDEN_COMMON_EMAIL_SUBSTRINGS = [
"gmail",
"outlook",
"yahoo",
"hotmail",
"icloud",
"msn",
"hotmail",
"hotmail.co.uk",
]
@router.get("/existing-team-by-domain")
def get_existing_tenant_by_domain(
user: User | None = Depends(current_user),
) -> TenantByDomainResponse | None:
if not user:
return None
domain = user.email.split("@")[1]
if any(substring in domain for substring in FORBIDDEN_COMMON_EMAIL_SUBSTRINGS):
return None
tenant_id = get_current_tenant_id()
return get_tenant_by_domain_from_control_plane(domain, tenant_id)

View File

@@ -0,0 +1,90 @@
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from ee.onyx.server.tenants.models import ApproveUserRequest
from ee.onyx.server.tenants.models import PendingUserSnapshot
from ee.onyx.server.tenants.models import RequestInviteRequest
from ee.onyx.server.tenants.user_mapping import accept_user_invite
from ee.onyx.server.tenants.user_mapping import approve_user_invite
from ee.onyx.server.tenants.user_mapping import deny_user_invite
from ee.onyx.server.tenants.user_mapping import invite_self_to_tenant
from onyx.auth.invited_users import get_pending_users
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_user
from onyx.auth.users import User
from onyx.utils.logger import setup_logger
from shared_configs.contextvars import get_current_tenant_id
logger = setup_logger()
router = APIRouter(prefix="/tenants")
@router.post("/users/invite/request")
async def request_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_admin_user),
) -> None:
if user is None:
raise HTTPException(status_code=401, detail="User not authenticated")
try:
invite_self_to_tenant(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(
f"Failed to invite self to tenant {invite_request.tenant_id}: {e}"
)
raise HTTPException(status_code=500, detail=str(e))
@router.get("/users/pending")
def list_pending_users(
_: User | None = Depends(current_admin_user),
) -> list[PendingUserSnapshot]:
pending_emails = get_pending_users()
return [PendingUserSnapshot(email=email) for email in pending_emails]
@router.post("/users/invite/approve")
async def approve_user(
approve_user_request: ApproveUserRequest,
_: User | None = Depends(current_admin_user),
) -> None:
tenant_id = get_current_tenant_id()
approve_user_invite(approve_user_request.email, tenant_id)
@router.post("/users/invite/accept")
async def accept_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Accept an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
accept_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to accept invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to accept invitation")
@router.post("/users/invite/deny")
async def deny_invite(
invite_request: RequestInviteRequest,
user: User | None = Depends(current_user),
) -> None:
"""
Deny an invitation to join a tenant.
"""
if not user:
raise HTTPException(status_code=401, detail="Not authenticated")
try:
deny_user_invite(user.email, invite_request.tenant_id)
except Exception as e:
logger.exception(f"Failed to deny invite: {str(e)}")
raise HTTPException(status_code=500, detail="Failed to deny invitation")

View File

@@ -1,27 +1,56 @@
import logging
from fastapi_users import exceptions
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.invited_users import get_invited_users
from onyx.auth.invited_users import get_pending_users
from onyx.auth.invited_users import write_invited_users
from onyx.auth.invited_users import write_pending_users
from onyx.db.engine import get_session_with_shared_schema
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import get_sqlalchemy_engine
from onyx.db.models import UserTenantMapping
from onyx.server.manage.models import TenantSnapshot
from onyx.setup import setup_logger
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import POSTGRES_DEFAULT_SCHEMA
from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = logging.getLogger(__name__)
logger = setup_logger()
def get_tenant_id_for_email(email: str) -> str:
if not MULTI_TENANT:
return POSTGRES_DEFAULT_SCHEMA
# Implement logic to get tenant_id from the mapping table
with Session(get_sqlalchemy_engine()) as db_session:
result = db_session.execute(
select(UserTenantMapping.tenant_id).where(UserTenantMapping.email == email)
)
tenant_id = result.scalar_one_or_none()
try:
with get_session_with_shared_schema() as db_session:
# First try to get an active tenant
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
tenant_id = mapping.tenant_id if mapping else None
# If no active tenant found, try to get the first inactive one
if tenant_id is None:
result = db_session.execute(
select(UserTenantMapping).where(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
)
mapping = result.scalar_one_or_none()
if mapping:
# Mark this mapping as active
mapping.active = True
db_session.commit()
tenant_id = mapping.tenant_id
except Exception as e:
logger.exception(f"Error getting tenant id for email {email}: {e}")
raise exceptions.UserNotExists()
if tenant_id is None:
raise exceptions.UserNotExists()
return tenant_id
@@ -41,7 +70,9 @@ def add_users_to_tenant(emails: list[str], tenant_id: str) -> None:
with get_session_with_tenant(tenant_id=POSTGRES_DEFAULT_SCHEMA) as db_session:
try:
for email in emails:
db_session.add(UserTenantMapping(email=email, tenant_id=tenant_id))
db_session.add(
UserTenantMapping(email=email, tenant_id=tenant_id, active=False)
)
except Exception:
logger.exception(f"Failed to add users to tenant {tenant_id}")
db_session.commit()
@@ -76,3 +107,187 @@ def remove_all_users_from_tenant(tenant_id: str) -> None:
UserTenantMapping.tenant_id == tenant_id
).delete()
db_session.commit()
def invite_self_to_tenant(email: str, tenant_id: str) -> None:
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_pending_users()
if email in pending_users:
return
write_pending_users(pending_users + [email])
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def approve_user_invite(email: str, tenant_id: str) -> None:
"""
Approve a user invite to a tenant.
This will delete all existing records for this email and create a new mapping entry for the user in this tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete all existing records for this email
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email
).delete()
# Create a new mapping entry for the user in this tenant
new_mapping = UserTenantMapping(email=email, tenant_id=tenant_id, active=True)
db_session.add(new_mapping)
db_session.commit()
# Also remove the user from pending users list
# Remove from pending users
pending_users = get_pending_users()
if email in pending_users:
pending_users.remove(email)
write_pending_users(pending_users)
# Add to invited users
invited_users = get_invited_users()
if email not in invited_users:
invited_users.append(email)
write_invited_users(invited_users)
def accept_user_invite(email: str, tenant_id: str) -> None:
"""
Accept an invitation to join a tenant.
This activates the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
try:
# First check if there's an active mapping for this user and tenant
active_mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
)
.first()
)
# If an active mapping exists, delete it
if active_mapping:
db_session.delete(active_mapping)
logger.info(
f"Deleted existing active mapping for user {email} in tenant {tenant_id}"
)
# Find the inactive mapping for this user and tenant
mapping = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if mapping:
# Set all other mappings for this user to inactive
db_session.query(UserTenantMapping).filter(
UserTenantMapping.email == email,
UserTenantMapping.active == True, # noqa: E712
).update({"active": False})
# Activate this mapping
mapping.active = True
db_session.commit()
logger.info(f"User {email} accepted invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
except Exception as e:
db_session.rollback()
logger.exception(
f"Failed to accept invitation for user {email} to tenant {tenant_id}: {str(e)}"
)
raise
def deny_user_invite(email: str, tenant_id: str) -> None:
"""
Deny an invitation to join a tenant.
This removes the user's mapping to the tenant.
"""
with get_session_with_shared_schema() as db_session:
# Delete the mapping for this user and tenant
result = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == False, # noqa: E712
)
.delete()
)
db_session.commit()
if result:
logger.info(f"User {email} denied invitation to tenant {tenant_id}")
else:
logger.warning(
f"No invitation found for user {email} in tenant {tenant_id}"
)
token = CURRENT_TENANT_ID_CONTEXTVAR.set(tenant_id)
try:
pending_users = get_invited_users()
if email in pending_users:
pending_users.remove(email)
write_invited_users(pending_users)
finally:
CURRENT_TENANT_ID_CONTEXTVAR.reset(token)
def get_tenant_count(tenant_id: str) -> int:
"""
Get the number of active users for this tenant
"""
with get_session_with_shared_schema() as db_session:
# Count the number of active users for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return user_count
def get_tenant_invitation(email: str) -> TenantSnapshot | None:
"""
Get the first tenant invitation for this user
"""
with get_session_with_shared_schema() as db_session:
# Get the first tenant invitation for this user
invitation = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.email == email,
UserTenantMapping.active == False, # noqa: E712
)
.first()
)
if invitation:
# Get the user count for this tenant
user_count = (
db_session.query(UserTenantMapping)
.filter(
UserTenantMapping.tenant_id == invitation.tenant_id,
UserTenantMapping.active == True, # noqa: E712
)
.count()
)
return TenantSnapshot(
tenant_id=invitation.tenant_id, number_of_users=user_count
)
return None

View File

@@ -62,6 +62,60 @@ _OPENAI_MAX_INPUT_LEN = 2048
# Cohere allows up to 96 embeddings in a single embedding calling
_COHERE_MAX_INPUT_LEN = 96
# Authentication error string constants
_AUTH_ERROR_401 = "401"
_AUTH_ERROR_UNAUTHORIZED = "unauthorized"
_AUTH_ERROR_INVALID_API_KEY = "invalid api key"
_AUTH_ERROR_PERMISSION = "permission"
def is_authentication_error(error: Exception) -> bool:
"""Check if an exception is related to authentication issues.
Args:
error: The exception to check
Returns:
bool: True if the error appears to be authentication-related
"""
error_str = str(error).lower()
return (
_AUTH_ERROR_401 in error_str
or _AUTH_ERROR_UNAUTHORIZED in error_str
or _AUTH_ERROR_INVALID_API_KEY in error_str
or _AUTH_ERROR_PERMISSION in error_str
)
def format_embedding_error(
error: Exception,
service_name: str,
model: str | None,
provider: EmbeddingProvider,
status_code: int | None = None,
) -> str:
"""
Format a standardized error string for embedding errors.
"""
detail = f"Status {status_code}" if status_code else f"{type(error)}"
return (
f"{'HTTP error' if status_code else 'Exception'} embedding text with {service_name} - {detail}: "
f"Model: {model} "
f"Provider: {provider} "
f"Exception: {error}"
)
# Custom exception for authentication errors
class AuthenticationError(Exception):
"""Raised when authentication fails with a provider."""
def __init__(self, provider: str, message: str = "API key is invalid or expired"):
self.provider = provider
self.message = message
super().__init__(f"{provider} authentication failed: {message}")
class CloudEmbedding:
def __init__(
@@ -92,31 +146,17 @@ class CloudEmbedding:
)
final_embeddings: list[Embedding] = []
try:
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
except Exception as e:
error_string = (
f"Exception embedding text with OpenAI - {type(e)}: "
f"Model: {model} "
f"Provider: {self.provider} "
f"Exception: {e}"
for text_batch in batch_list(texts, _OPENAI_MAX_INPUT_LEN):
response = await client.embeddings.create(
input=text_batch,
model=model,
dimensions=reduced_dimension or openai.NOT_GIVEN,
)
logger.error(error_string)
# only log text when it's not an authentication error.
if not isinstance(e, openai.AuthenticationError):
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
final_embeddings.extend(
[embedding.embedding for embedding in response.data]
)
return final_embeddings
async def _embed_cohere(
self, texts: list[str], model: str | None, embedding_type: str
@@ -155,7 +195,6 @@ class CloudEmbedding:
input_type=embedding_type,
truncation=True,
)
return response.embeddings
async def _embed_azure(
@@ -239,22 +278,51 @@ class CloudEmbedding:
deployment_name: str | None = None,
reduced_dimension: int | None = None,
) -> list[Embedding]:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
try:
if self.provider == EmbeddingProvider.OPENAI:
return await self._embed_openai(texts, model_name, reduced_dimension)
elif self.provider == EmbeddingProvider.AZURE:
return await self._embed_azure(texts, f"azure/{deployment_name}")
elif self.provider == EmbeddingProvider.LITELLM:
return await self._embed_litellm_proxy(texts, model_name)
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
embedding_type = EmbeddingModelTextType.get_type(self.provider, text_type)
if self.provider == EmbeddingProvider.COHERE:
return await self._embed_cohere(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.VOYAGE:
return await self._embed_voyage(texts, model_name, embedding_type)
elif self.provider == EmbeddingProvider.GOOGLE:
return await self._embed_vertex(texts, model_name, embedding_type)
else:
raise ValueError(f"Unsupported provider: {self.provider}")
except openai.AuthenticationError:
raise AuthenticationError(provider="OpenAI")
except httpx.HTTPStatusError as e:
if e.response.status_code == 401:
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e,
str(self.provider),
model_name or deployment_name,
self.provider,
status_code=e.response.status_code,
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
except Exception as e:
if is_authentication_error(e):
raise AuthenticationError(provider=str(self.provider))
error_string = format_embedding_error(
e, str(self.provider), model_name or deployment_name, self.provider
)
logger.error(error_string)
logger.debug(f"Exception texts: {texts}")
raise RuntimeError(error_string)
@staticmethod
def create(
@@ -569,6 +637,13 @@ async def process_embed_request(
gpu_type=gpu_type,
)
return EmbedResponse(embeddings=embeddings)
except AuthenticationError as e:
# Handle authentication errors consistently
logger.error(f"Authentication error: {e.provider}")
raise HTTPException(
status_code=401,
detail=f"Authentication failed: {e.message}",
)
except RateLimitError as e:
raise HTTPException(
status_code=429,

View File

@@ -31,6 +31,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
@@ -92,6 +93,7 @@ def check_sub_answer(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
quality_str: str = cast(str, response.content)

View File

@@ -46,6 +46,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -119,6 +120,7 @@ def generate_sub_answer(
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -62,6 +63,7 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
@@ -153,8 +155,9 @@ def generate_initial_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -278,6 +281,9 @@ def generate_initial_answer(
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content

View File

@@ -34,6 +34,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
@@ -141,6 +142,7 @@ def decompose_orig_question(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
@@ -112,6 +113,7 @@ def compare_answers(
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
except (LLMTimeoutError, TimeoutError):

View File

@@ -43,6 +43,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUESTION_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
@@ -144,6 +145,7 @@ def create_refined_sub_questions(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),

View File

@@ -50,13 +50,7 @@ def decide_refinement_need(
)
]
if graph_config.behavior.allow_refinement:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=decision,
log_messages=log_messages,
)
else:
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=False,
log_messages=log_messages,
)
return RequireRefinemenEvalUpdate(
require_refined_answer_eval=graph_config.behavior.allow_refinement and decision,
log_messages=log_messages,
)

View File

@@ -21,6 +21,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import format_docs
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
@@ -96,6 +97,7 @@ def extract_entities_terms(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
max_tokens=AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (

View File

@@ -46,6 +46,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import RefinedAgentStats
from onyx.agents.agent_search.shared_graph_utils.operators import (
dedup_inference_section_list,
)
from onyx.agents.agent_search.shared_graph_utils.utils import _should_restrict_tokens
from onyx.agents.agent_search.shared_graph_utils.utils import (
dispatch_main_answer_stop_info,
)
@@ -68,6 +69,8 @@ from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_ANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
@@ -179,8 +182,9 @@ def generate_validate_refined_answer(
)
for tool_response in yield_search_responses(
query=question,
reranked_sections=answer_generation_documents.streaming_documents,
final_context_sections=answer_generation_documents.context_documents,
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
get_final_context_sections=lambda: answer_generation_documents.context_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,
@@ -302,7 +306,11 @@ def generate_validate_refined_answer(
def stream_refined_answer() -> list[str]:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
max_tokens=AGENT_MAX_TOKENS_ANSWER_GENERATION
if _should_restrict_tokens(model.config)
else None,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -409,6 +417,7 @@ def generate_validate_refined_answer(
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),

View File

@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.context.search.models import IndexFilters
from onyx.tools.models import SearchQueryInfo
from onyx.utils.logger import setup_logger
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
if result.query_info is not None:
query_info = result.query_info
break
return query_info or SearchQueryInfo(
predicted_search=None,
final_filters=IndexFilters(access_control_list=None),
recency_bias_multiplier=1.0,
)
assert query_info is not None, "must have query info"
return query_info

View File

@@ -33,6 +33,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_SUBQUERY_GENERATION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
)
@@ -96,6 +97,7 @@ def expand_queries(
model.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
max_tokens=AGENT_MAX_TOKENS_SUBQUERY_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)

View File

@@ -56,8 +56,9 @@ def format_results(
relevance_list = relevance_from_docs(reranked_documents)
for tool_response in yield_search_responses(
query=state.question,
reranked_sections=state.retrieved_documents,
final_context_sections=reranked_documents,
get_retrieved_sections=lambda: reranked_documents,
get_reranked_sections=lambda: state.retrieved_documents,
get_final_context_sections=lambda: reranked_documents,
search_query_info=query_info,
get_section_relevance=lambda: relevance_list,
search_tool=graph_config.tooling.search_tool,

View File

@@ -91,7 +91,7 @@ def retrieve_documents(
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
if AGENT_RETRIEVAL_STATS:
pre_rerank_docs = callback_container[0]
pre_rerank_docs = callback_container[0] if callback_container else []
fit_scores = get_fit_scores(
pre_rerank_docs,
retrieved_docs,

View File

@@ -25,6 +25,7 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_VALIDATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
@@ -93,6 +94,7 @@ def verify_documents(
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
max_tokens=AGENT_MAX_TOKENS_VALIDATION,
)
assert isinstance(response.content, str)

View File

@@ -44,7 +44,9 @@ def call_tool(
tool = tool_choice.tool
tool_args = tool_choice.tool_args
tool_id = tool_choice.id
tool_runner = ToolRunner(tool, tool_args)
tool_runner = ToolRunner(
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
)
tool_kickoff = tool_runner.kickoff()
emit_packet(tool_kickoff, writer)

View File

@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
from onyx.chat.tool_handling.tool_response_handler import (
get_tool_call_for_non_tool_calling_llm_impl,
)
from onyx.context.search.preprocessing.preprocessing import query_analysis
from onyx.context.search.retrieval.search_runner import get_query_embedding
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_in_background
from onyx.utils.threadpool_concurrency import TimeoutThread
from onyx.utils.threadpool_concurrency import wait_on_background
from onyx.utils.timing import log_function_time
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -25,6 +34,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
@log_function_time(print_only=True)
def choose_tool(
state: ToolChoiceState,
config: RunnableConfig,
@@ -37,6 +47,31 @@ def choose_tool(
should_stream_answer = state.should_stream_answer
agent_config = cast(GraphConfig, config["metadata"]["config"])
force_use_tool = agent_config.tooling.force_use_tool
embedding_thread: TimeoutThread[Embedding] | None = None
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
override_kwargs: SearchToolOverrideKwargs | None = None
if (
not agent_config.behavior.use_agentic_search
and agent_config.tooling.search_tool is not None
and (
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
)
):
override_kwargs = SearchToolOverrideKwargs()
# Run in a background thread to avoid blocking the main thread
embedding_thread = run_in_background(
get_query_embedding,
agent_config.inputs.search_request.query,
agent_config.persistence.db_session,
)
keyword_thread = run_in_background(
query_analysis,
agent_config.inputs.search_request.query,
)
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
@@ -47,7 +82,6 @@ def choose_tool(
tools = [
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
]
force_use_tool = agent_config.tooling.force_use_tool
tool, tool_args = None, None
if force_use_tool.force_use and force_use_tool.args is not None:
@@ -71,11 +105,22 @@ def choose_tool(
# If we have a tool and tool args, we are ready to request a tool call.
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
if tool and tool_args:
if embedding_thread and tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=tool,
tool_args=tool_args,
id=str(uuid4()),
search_tool_override_kwargs=override_kwargs,
),
)
@@ -153,10 +198,22 @@ def choose_tool(
logger.debug(f"Selected tool: {selected_tool.name}")
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
if embedding_thread and selected_tool.name == SearchTool._NAME:
# Wait for the embedding thread to finish
embedding = wait_on_background(embedding_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_query_embedding = embedding
if keyword_thread and selected_tool.name == SearchTool._NAME:
is_keyword, keywords = wait_on_background(keyword_thread)
assert override_kwargs is not None, "must have override kwargs"
override_kwargs.precomputed_is_keyword = is_keyword
override_kwargs.precomputed_keywords = keywords
return ToolChoiceUpdate(
tool_choice=ToolChoice(
tool=selected_tool,
tool_args=selected_tool_call_request["args"],
id=selected_tool_call_request["id"],
search_tool_override_kwargs=override_kwargs,
),
)

View File

@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.basic.utils import process_llm_stream
from onyx.agents.agent_search.models import GraphConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContexts
from onyx.tools.tool_implementations.search.search_tool import (
SEARCH_DOC_CONTENT_ID,
SEARCH_RESPONSE_SUMMARY_ID,
)
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from onyx.utils.logger import setup_logger
from onyx.utils.timing import log_function_time
logger = setup_logger()
@log_function_time(print_only=True)
def basic_use_tool_response(
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
) -> BasicOutput:
@@ -50,11 +55,13 @@ def basic_use_tool_response(
for yield_item in tool_call_responses:
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
final_search_results = cast(list[LlmDoc], yield_item.response)
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
search_contexts = cast(OnyxContexts, yield_item.response).contexts
for doc in search_contexts:
if doc.document_id not in initial_search_results:
initial_search_results.append(doc)
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
search_response_summary = cast(SearchResponseSummary, yield_item.response)
for section in search_response_summary.top_sections:
if section.center_chunk.document_id not in initial_search_results:
initial_search_results.append(
context_from_inference_section(section)
)
new_tool_call_chunk = AIMessageChunk(content="")
if not agent_config.behavior.skip_gen_ai_answer_generation:

View File

@@ -2,6 +2,7 @@ from pydantic import BaseModel
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
from onyx.tools.message import ToolCallSummary
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolCallFinalResult
from onyx.tools.models import ToolCallKickoff
from onyx.tools.models import ToolResponse
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
tool: Tool
tool_args: dict
id: str | None
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
AGENT_ANSWER_SEPARATOR = "Answer:"
EMBEDDING_KEY = "embedding"
IS_KEYWORD_KEY = "is_keyword"
KEYWORDS_KEY = "keywords"
class AgentLLMErrorType(str, Enum):
TIMEOUT = "timeout"
RATE_LIMIT = "rate_limit"

View File

@@ -42,6 +42,7 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import AGENT_MAX_TOKENS_HISTORY_SUMMARY
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
)
@@ -61,6 +62,7 @@ from onyx.db.persona import Persona
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.llm.interfaces import LLM
from onyx.llm.interfaces import LLMConfig
from onyx.prompts.agent_search import (
ASSISTANT_SYSTEM_PROMPT_DEFAULT,
)
@@ -402,6 +404,7 @@ def summarize_history(
llm.invoke,
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
max_tokens=AGENT_MAX_TOKENS_HISTORY_SUMMARY,
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - summarize history")
@@ -505,3 +508,9 @@ def get_deduplicated_structured_subquestion_documents(
cited_documents=dedup_inference_section_list(cited_docs),
context_documents=dedup_inference_section_list(context_docs),
)
def _should_restrict_tokens(llm_config: LLMConfig) -> bool:
return not (
llm_config.model_provider == "openai" and llm_config.model_name.startswith("o")
)

View File

@@ -153,7 +153,8 @@ def send_email(
msg = MIMEMultipart("alternative")
msg["Subject"] = subject
msg["To"] = user_email
msg["From"] = mail_from
if mail_from:
msg["From"] = mail_from
msg["Date"] = formatdate(localtime=True)
msg["Message-ID"] = make_msgid(domain="onyx.app")

View File

@@ -1,5 +1,6 @@
from typing import cast
from onyx.configs.constants import KV_PENDING_USERS_KEY
from onyx.configs.constants import KV_USER_STORE_KEY
from onyx.key_value_store.factory import get_kv_store
from onyx.key_value_store.interface import KvKeyNotFoundError
@@ -18,3 +19,17 @@ def write_invited_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_USER_STORE_KEY, cast(JSON_ro, emails))
return len(emails)
def get_pending_users() -> list[str]:
try:
store = get_kv_store()
return cast(list, store.load(KV_PENDING_USERS_KEY))
except KvKeyNotFoundError:
return list()
def write_pending_users(emails: list[str]) -> int:
store = get_kv_store()
store.store(KV_PENDING_USERS_KEY, cast(JSON_ro, emails))
return len(emails)

View File

@@ -100,6 +100,7 @@ from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
from onyx.utils.telemetry import RecordType
from onyx.utils.url import add_url_params
from onyx.utils.variable_functionality import fetch_ee_implementation_or_noop
from onyx.utils.variable_functionality import fetch_versioned_implementation
from shared_configs.configs import async_return_default_schema
@@ -587,14 +588,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
) -> Optional[User]:
email = credentials.username
# Get tenant_id from mapping table
tenant_id = await fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_or_provision_tenant",
async_return_default_schema,
)(
email=email,
)
tenant_id: str | None = None
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.provisioning",
"get_tenant_id_for_email",
None,
)(
email=email,
)
except Exception as e:
logger.warning(
f"User attempted to login with invalid credentials: {str(e)}"
)
if not tenant_id:
# User not found in mapping
self.password_helper.hash(credentials.password)
@@ -888,7 +895,7 @@ async def current_limited_user(
return await double_check_user(user)
async def current_chat_accesssible_user(
async def current_chat_accessible_user(
user: User | None = Depends(optional_user),
) -> User | None:
tenant_id = get_current_tenant_id()
@@ -1089,6 +1096,12 @@ def get_oauth_router(
next_url = state_data.get("next_url", "/")
referral_source = state_data.get("referral_source", None)
try:
tenant_id = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(account_email)
except exceptions.UserNotExists:
tenant_id = None
request.state.referral_source = referral_source
@@ -1120,9 +1133,14 @@ def get_oauth_router(
# Login user
response = await backend.login(strategy, user)
await user_manager.on_after_login(user, request, response)
# Prepare redirect response
redirect_response = RedirectResponse(next_url, status_code=302)
if tenant_id is None:
# Use URL utility to add parameters
redirect_url = add_url_params(next_url, {"new_team": "true"})
redirect_response = RedirectResponse(redirect_url, status_code=302)
else:
# No parameters to add
redirect_response = RedirectResponse(next_url, status_code=302)
# Copy headers and other attributes from 'response' to 'redirect_response'
for header_name, header_value in response.headers.items():
@@ -1134,6 +1152,7 @@ def get_oauth_router(
redirect_response.status_code = response.status_code
if hasattr(response, "media_type"):
redirect_response.media_type = response.media_type
return redirect_response
return router

View File

@@ -111,5 +111,6 @@ celery_app.autodiscover_tasks(
"onyx.background.celery.tasks.vespa",
"onyx.background.celery.tasks.connector_deletion",
"onyx.background.celery.tasks.doc_permission_syncing",
"onyx.background.celery.tasks.indexing",
]
)

View File

@@ -0,0 +1,73 @@
# backend/onyx/background/celery/memory_monitoring.py
import logging
import os
from logging.handlers import RotatingFileHandler
import psutil
from onyx.utils.logger import is_running_in_container
from onyx.utils.logger import setup_logger
# Regular application logger
logger = setup_logger()
# Only set up memory monitoring in container environment
if is_running_in_container():
# Set up a dedicated memory monitoring logger
MEMORY_LOG_DIR = "/var/log/persisted-logs/memory"
MEMORY_LOG_FILE = os.path.join(MEMORY_LOG_DIR, "memory_usage.log")
MEMORY_LOG_MAX_BYTES = 10 * 1024 * 1024 # 10MB
MEMORY_LOG_BACKUP_COUNT = 5 # Keep 5 backup files
# Ensure log directory exists
os.makedirs(MEMORY_LOG_DIR, exist_ok=True)
# Create a dedicated logger for memory monitoring
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.setLevel(logging.INFO)
# Create a rotating file handler
memory_handler = RotatingFileHandler(
MEMORY_LOG_FILE,
maxBytes=MEMORY_LOG_MAX_BYTES,
backupCount=MEMORY_LOG_BACKUP_COUNT,
)
# Create a formatter that includes all relevant information
memory_formatter = logging.Formatter(
"%(asctime)s [%(levelname)s] %(message)s", datefmt="%Y-%m-%d %H:%M:%S"
)
memory_handler.setFormatter(memory_formatter)
memory_logger.addHandler(memory_handler)
else:
# Create a null logger when not in container
memory_logger = logging.getLogger("memory_monitoring")
memory_logger.addHandler(logging.NullHandler())
def emit_process_memory(
pid: int, process_name: str, additional_metadata: dict[str, str | int]
) -> None:
# Skip memory monitoring if not in container
if not is_running_in_container():
return
try:
process = psutil.Process(pid)
memory_info = process.memory_info()
cpu_percent = process.cpu_percent(interval=0.1)
# Build metadata string from additional_metadata dictionary
metadata_str = " ".join(
[f"{key}={value}" for key, value in additional_metadata.items()]
)
metadata_str = f" {metadata_str}" if metadata_str else ""
memory_logger.info(
f"PROCESS_MEMORY process_name={process_name} pid={pid} "
f"rss_mb={memory_info.rss / (1024 * 1024):.2f} "
f"vms_mb={memory_info.vms / (1024 * 1024):.2f} "
f"cpu={cpu_percent:.2f}{metadata_str}"
)
except Exception:
logger.exception("Error monitoring process memory.")

View File

@@ -23,6 +23,7 @@ from sqlalchemy.orm import Session
from onyx.background.celery.apps.app_base import task_logger
from onyx.background.celery.celery_utils import httpx_init_vespa_pool
from onyx.background.celery.memory_monitoring import emit_process_memory
from onyx.background.celery.tasks.indexing.utils import get_unfenced_index_attempt_ids
from onyx.background.celery.tasks.indexing.utils import IndexingCallback
from onyx.background.celery.tasks.indexing.utils import should_index
@@ -984,6 +985,9 @@ def connector_indexing_proxy_task(
redis_connector = RedisConnector(tenant_id, cc_pair_id)
redis_connector_index = redis_connector.new_index(search_settings_id)
# Track the last time memory info was emitted
last_memory_emit_time = 0.0
try:
with get_session_with_current_tenant() as db_session:
index_attempt = get_index_attempt(
@@ -1024,6 +1028,23 @@ def connector_indexing_proxy_task(
job.release()
break
# log the memory usage for tracking down memory leaks / connector-specific memory issues
pid = job.process.pid
if pid is not None:
# Only emit memory info once per minute (60 seconds)
current_time = time.monotonic()
if current_time - last_memory_emit_time >= 60.0:
emit_process_memory(
pid,
"indexing_worker",
{
"cc_pair_id": cc_pair_id,
"search_settings_id": search_settings_id,
"index_attempt_id": index_attempt_id,
},
)
last_memory_emit_time = current_time
# if a termination signal is detected, break (exit point will clean up)
if self.request.id and redis_connector_index.terminating(self.request.id):
task_logger.warning(
@@ -1170,6 +1191,7 @@ def connector_indexing_proxy_task(
return
# primary
@shared_task(
name=OnyxCeleryTask.CHECK_FOR_CHECKPOINT_CLEANUP,
soft_time_limit=300,
@@ -1217,6 +1239,7 @@ def check_for_checkpoint_cleanup(*, tenant_id: str) -> None:
)
# light worker
@shared_task(
name=OnyxCeleryTask.CLEANUP_CHECKPOINT,
bind=True,

View File

@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
# This is Legacy code that is not used anymore.
# It is kept here for reference.
class LLMResponseHandlerManager:
"""
This class is responsible for postprocessing the LLM response stream.

View File

@@ -1,10 +1,13 @@
from collections import OrderedDict
from collections.abc import Callable
from collections.abc import Iterator
from collections.abc import Mapping
from datetime import datetime
from enum import Enum
from typing import Any
from typing import Literal
from typing import TYPE_CHECKING
from typing import Union
from pydantic import BaseModel
from pydantic import ConfigDict
@@ -44,9 +47,44 @@ class LlmDoc(BaseModel):
class SubQuestionIdentifier(BaseModel):
"""None represents references to objects in the original flow. To our understanding,
these will not be None in the packets returned from agent search.
"""
level: int | None = None
level_question_num: int | None = None
@staticmethod
def make_dict_by_level(
original_dict: Mapping[tuple[int, int], "SubQuestionIdentifier"]
) -> dict[int, list["SubQuestionIdentifier"]]:
"""returns a dict of level to object list (sorted by level_question_num)
Ordering is asc for readability.
"""
# organize by level, then sort ascending by question_index
level_dict: dict[int, list[SubQuestionIdentifier]] = {}
# group by level
for k, obj in original_dict.items():
level = k[0]
if level not in level_dict:
level_dict[level] = []
level_dict[level].append(obj)
# for each level, sort the group
for k2, value2 in level_dict.items():
# we need to handle the none case due to SubQuestionIdentifier typing
# level_question_num as int | None, even though it should never be None here.
level_dict[k2] = sorted(
value2,
key=lambda x: (x.level_question_num is None, x.level_question_num),
)
# sort by level
sorted_dict = OrderedDict(sorted(level_dict.items()))
return sorted_dict
# First chunk of info for streaming QA
class QADocsResponse(RetrievalDocs, SubQuestionIdentifier):
@@ -336,6 +374,8 @@ class AgentAnswerPiece(SubQuestionIdentifier):
class SubQuestionPiece(SubQuestionIdentifier):
"""Refined sub questions generated from the initial user question."""
sub_question: str
@@ -347,13 +387,13 @@ class RefinedAnswerImprovement(BaseModel):
refined_answer_improvement: bool
AgentSearchPacket = (
AgentSearchPacket = Union[
SubQuestionPiece
| AgentAnswerPiece
| SubQueryPiece
| ExtendedToolResponse
| RefinedAnswerImprovement
)
]
AnswerPacket = (
AnswerQuestionPossibleReturn | AgentSearchPacket | ToolCallKickoff | ToolResponse

View File

@@ -90,97 +90,97 @@ class CitationProcessor:
next(group for group in citation.groups() if group is not None)
)
if 1 <= numerical_value <= self.max_citation_num:
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
if not (1 <= numerical_value <= self.max_citation_num):
continue
context_llm_doc = self.context_docs[numerical_value - 1]
final_citation_num = self.final_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = self.citation_order.index(final_citation_num) + 1
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
if final_citation_num not in self.citation_order:
self.citation_order.append(final_citation_num)
citation_order_idx = (
self.citation_order.index(final_citation_num) + 1
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# get the value that was displayed to user, should always
# be in the display_doc_order_dict. But check anyways
if context_llm_doc.document_id in self.display_order_mapping:
displayed_citation_num = self.display_order_mapping[
context_llm_doc.document_id
]
else:
displayed_citation_num = final_citation_num
logger.warning(
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
# Skip consecutive citations of the same work
if final_citation_num in self.current_citations:
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
real_start = length_to_add + start
diff = end - start
self.curr_segment = (
self.curr_segment[: length_to_add + start]
+ self.curr_segment[real_start + diff :]
)
length_to_add -= diff
continue
last_citation_end = end + length_to_add
# Handle edge case where LLM outputs citation itself
if self.curr_segment.startswith("[["):
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
if match:
try:
doc_id = int(match.group(1))
context_llm_doc = self.context_docs[doc_id - 1]
yield CitationInfo(
# citation_num is now the number post initial ranking, i.e. as displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
except Exception as e:
logger.warning(
f"Manual LLM citation didn't properly cite documents {e}"
)
else:
logger.warning(
"Manual LLM citation wasn't able to close brackets"
)
continue
link = context_llm_doc.link
self.past_cite_count = len(self.llm_out)
self.current_citations.append(final_citation_num)
if citation_order_idx not in self.cited_inds:
self.cited_inds.add(citation_order_idx)
yield CitationInfo(
# citation number is now the one that was displayed to user
citation_num=displayed_citation_num,
document_id=context_llm_doc.document_id,
)
start, end = citation.span()
if link:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
else:
prev_length = len(self.curr_segment)
self.curr_segment = (
self.curr_segment[: start + length_to_add]
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
+ self.curr_segment[end + length_to_add :]
)
length_to_add += len(self.curr_segment) - prev_length
last_citation_end = end + length_to_add
if last_citation_end > 0:
result += self.curr_segment[:last_citation_end]

View File

@@ -217,20 +217,20 @@ AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 6 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 40 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 10 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
@@ -243,13 +243,13 @@ AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 15 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 45 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
@@ -333,4 +333,45 @@ AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_VALIDATION = 4
AGENT_MAX_TOKENS_VALIDATION = int(
os.environ.get("AGENT_MAX_TOKENS_VALIDATION") or AGENT_DEFAULT_MAX_TOKENS_VALIDATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION = 256
AGENT_MAX_TOKENS_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION = 1024
AGENT_MAX_TOKENS_ANSWER_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_ANSWER_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_ANSWER_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION = 256
AGENT_MAX_TOKENS_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = 1024
AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_MAX_TOKENS_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_MAX_TOKENS_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION = 64
AGENT_MAX_TOKENS_SUBQUERY_GENERATION = int(
os.environ.get("AGENT_MAX_TOKENS_SUBQUERY_GENERATION")
or AGENT_DEFAULT_MAX_TOKENS_SUBQUERY_GENERATION
)
AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY = 128
AGENT_MAX_TOKENS_HISTORY_SUMMARY = int(
os.environ.get("AGENT_MAX_TOKENS_HISTORY_SUMMARY")
or AGENT_DEFAULT_MAX_TOKENS_HISTORY_SUMMARY
)
GRAPH_VERSION_NAME: str = "a"

View File

@@ -76,6 +76,7 @@ KV_REINDEX_KEY = "needs_reindexing"
KV_SEARCH_SETTINGS = "search_settings"
KV_UNSTRUCTURED_API_KEY = "unstructured_api_key"
KV_USER_STORE_KEY = "INVITED_USERS"
KV_PENDING_USERS_KEY = "PENDING_USERS"
KV_NO_AUTH_USER_PREFERENCES_KEY = "no_auth_user_preferences"
KV_CRED_KEY = "credential_id_{}"
KV_GMAIL_CRED_KEY = "gmail_app_credential"

View File

@@ -66,9 +66,6 @@ _RESTRICTIONS_EXPANSION_FIELDS = [
_SLIM_DOC_BATCH_SIZE = 5000
_ATTACHMENT_EXTENSIONS_TO_FILTER_OUT = [
"png",
"jpg",
"jpeg",
"gif",
"mp4",
"mov",
@@ -240,7 +237,7 @@ class ConfluenceConnector(
# Extract basic page information
page_id = page["id"]
page_title = page["title"]
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
page_url = f"{self.wiki_base}{page['_links']['webui']}"
# Get the page content
page_content = extract_text_from_confluence_html(
@@ -305,7 +302,9 @@ class ConfluenceConnector(
# Create the document
return Document(
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
id=build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
),
sections=sections,
source=DocumentSource.CONFLUENCE,
semantic_identifier=page_title,
@@ -376,7 +375,7 @@ class ConfluenceConnector(
content_text, file_storage_name = response
object_url = build_confluence_document_id(
self.wiki_base, page["_links"]["webui"], self.is_cloud
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
)
if content_text:

View File

@@ -228,10 +228,15 @@ class GitbookConnector(LoadConnector, PollConnector):
raise ConnectorMissingCredentialError("GitBook")
try:
content = self.client.get(f"/spaces/{self.space_id}/content")
content = self.client.get(f"/spaces/{self.space_id}/content/pages")
pages: list[dict[str, Any]] = content.get("pages", [])
current_batch: list[Document] = []
logger.info(f"Found {len(pages)} root pages.")
logger.info(
f"First 20 Page Ids: {[page.get('id', 'Unknown') for page in pages[:20]]}"
)
while pages:
page = pages.pop(0)

View File

@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
def __init__(
self,
repo_owner: str,
repo_name: str | None = None,
repositories: str | None = None,
batch_size: int = INDEX_BATCH_SIZE,
state_filter: str = "all",
include_prs: bool = True,
include_issues: bool = False,
) -> None:
self.repo_owner = repo_owner
self.repo_name = repo_name
self.repositories = repositories
self.batch_size = batch_size
self.state_filter = state_filter
self.include_prs = include_prs
@@ -157,11 +157,42 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repo(github_client, attempt_num + 1)
def _get_github_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
"""Get specific repositories based on comma-separated repo_name string."""
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
raise RuntimeError(
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
)
try:
repos = []
# Split repo_name by comma and strip whitespace
repo_names = [
name.strip() for name in (cast(str, self.repositories)).split(",")
]
for repo_name in repo_names:
if repo_name: # Skip empty strings
try:
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
repos.append(repo)
except GithubException as e:
logger.warning(
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
)
return repos
except RateLimitExceededException:
_sleep_after_rate_limit_exception(github_client)
return self._get_github_repos(github_client, attempt_num + 1)
def _get_all_repos(
self, github_client: Github, attempt_num: int = 0
) -> list[Repository.Repository]:
@@ -189,11 +220,17 @@ class GithubConnector(LoadConnector, PollConnector):
if self.github_client is None:
raise ConnectorMissingCredentialError("GitHub")
repos = (
[self._get_github_repo(self.github_client)]
if self.repo_name
else self._get_all_repos(self.github_client)
)
repos = []
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repos = self._get_github_repos(self.github_client)
else:
# Single repository (backward compatibility)
repos = [self._get_github_repo(self.github_client)]
else:
# All repositories
repos = self._get_all_repos(self.github_client)
for repo in repos:
if self.include_prs:
@@ -268,11 +305,48 @@ class GithubConnector(LoadConnector, PollConnector):
)
try:
if self.repo_name:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repo_name}"
)
test_repo.get_contents("")
if self.repositories:
if "," in self.repositories:
# Multiple repositories specified
repo_names = [name.strip() for name in self.repositories.split(",")]
if not repo_names:
raise ConnectorValidationError(
"Invalid connector settings: No valid repository names provided."
)
# Validate at least one repository exists and is accessible
valid_repos = False
validation_errors = []
for repo_name in repo_names:
if not repo_name:
continue
try:
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{repo_name}"
)
test_repo.get_contents("")
valid_repos = True
# If at least one repo is valid, we can proceed
break
except GithubException as e:
validation_errors.append(
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
)
if not valid_repos:
error_msg = (
"None of the specified repositories could be accessed: "
)
error_msg += ", ".join(validation_errors)
raise ConnectorValidationError(error_msg)
else:
# Single repository (backward compatibility)
test_repo = self.github_client.get_repo(
f"{self.repo_owner}/{self.repositories}"
)
test_repo.get_contents("")
else:
# Try to get organization first
try:
@@ -298,10 +372,15 @@ class GithubConnector(LoadConnector, PollConnector):
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
)
elif e.status == 404:
if self.repo_name:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
)
if self.repositories:
if "," in self.repositories:
raise ConnectorValidationError(
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
)
else:
raise ConnectorValidationError(
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
)
else:
raise ConnectorValidationError(
f"GitHub user or organization not found: {self.repo_owner}"
@@ -310,6 +389,7 @@ class GithubConnector(LoadConnector, PollConnector):
raise ConnectorValidationError(
f"Unexpected GitHub error (status={e.status}): {e.data}"
)
except Exception as exc:
raise Exception(
f"Unexpected error during GitHub settings validation: {exc}"
@@ -321,7 +401,7 @@ if __name__ == "__main__":
connector = GithubConnector(
repo_owner=os.environ["REPO_OWNER"],
repo_name=os.environ["REPO_NAME"],
repositories=os.environ["REPOSITORIES"],
)
connector.load_credentials(
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}

View File

@@ -316,7 +316,9 @@ class GoogleDriveConnector(
# validate that the user has access to the drive APIs by performing a simple
# request and checking for a 401
try:
retry_builder()(get_root_folder_id)(drive_service)
# default is ~17mins of retries, don't do that here for cases so we don't
# waste 17mins everytime we run into a user without access to drive APIs
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
except HttpError as e:
if e.status_code == 401:
# fail gracefully, let the other impersonations continue

View File

@@ -1,3 +1,4 @@
import json
from datetime import datetime
from enum import Enum
from typing import Any
@@ -204,6 +205,15 @@ class ConnectorCheckpoint(BaseModel):
def build_dummy_checkpoint(cls) -> "ConnectorCheckpoint":
return ConnectorCheckpoint(checkpoint_content={}, has_more=True)
def __str__(self) -> str:
"""String representation of the checkpoint, with truncation for large checkpoint content."""
MAX_CHECKPOINT_CONTENT_CHARS = 1000
content_str = json.dumps(self.checkpoint_content)
if len(content_str) > MAX_CHECKPOINT_CONTENT_CHARS:
content_str = content_str[: MAX_CHECKPOINT_CONTENT_CHARS - 3] + "..."
return f"ConnectorCheckpoint(checkpoint_content={content_str}, has_more={self.has_more})"
class DocumentFailure(BaseModel):
document_id: str

View File

@@ -1,4 +1,3 @@
import time
from collections.abc import Generator
from dataclasses import dataclass
from dataclasses import fields
@@ -32,6 +31,7 @@ from onyx.utils.logger import setup_logger
logger = setup_logger()
_NOTION_PAGE_SIZE = 100
_NOTION_CALL_TIMEOUT = 30 # 30 seconds
@@ -537,9 +537,9 @@ class NotionConnector(LoadConnector, PollConnector):
"""
filtered_pages: list[NotionPage] = []
for page in pages:
compare_time = time.mktime(
time.strptime(page[filter_field], "%Y-%m-%dT%H:%M:%S.000Z")
)
# Parse ISO 8601 timestamp and convert to UTC epoch time
timestamp = page[filter_field].replace(".000Z", "+00:00")
compare_time = datetime.fromisoformat(timestamp).timestamp()
if compare_time > start and compare_time <= end:
filtered_pages += [NotionPage(**page)]
return filtered_pages
@@ -578,7 +578,7 @@ class NotionConnector(LoadConnector, PollConnector):
query_dict = {
"filter": {"property": "object", "value": "page"},
"page_size": self.batch_size,
"page_size": _NOTION_PAGE_SIZE,
}
while True:
db_res = self._search_notion(query_dict)
@@ -604,7 +604,7 @@ class NotionConnector(LoadConnector, PollConnector):
return
query_dict = {
"page_size": self.batch_size,
"page_size": _NOTION_PAGE_SIZE,
"sort": {"timestamp": "last_edited_time", "direction": "descending"},
"filter": {"property": "object", "value": "page"},
}

View File

@@ -674,7 +674,7 @@ class SlackConnector(SlimConnector, CheckpointConnector):
"""
1. Verify the bot token is valid for the workspace (via auth_test).
2. Ensure the bot has enough scope to list channels.
3. Check that every channel specified in self.channels exists.
3. Check that every channel specified in self.channels exists (only when regex is not enabled).
"""
if self.client is None:
raise ConnectorMissingCredentialError("Slack credentials not loaded.")
@@ -706,8 +706,8 @@ class SlackConnector(SlimConnector, CheckpointConnector):
f"Slack API returned a failure: {error_msg}"
)
# 3) If channels are specified, verify each is accessible
if self.channels:
# 3) If channels are specified and regex is not enabled, verify each is accessible
if self.channels and not self.channel_regex_enabled:
accessible_channels = get_channels(
client=self.client,
exclude_archived=True,

View File

@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
from onyx.indexing.models import BaseChunk
from onyx.indexing.models import IndexingSetting
from shared_configs.enums import RerankerProvider
from shared_configs.model_server_models import Embedding
MAX_METRICS_CONTENT = (
200 # Just need enough characters to identify where in the doc the chunk is
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
model_config = ConfigDict(arbitrary_types_allowed=True)
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class SearchQuery(ChunkContext):
"Processed Request that is directly passed to the SearchPipeline"
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
offset: int = 0
model_config = ConfigDict(frozen=True)
precomputed_query_embedding: Embedding | None = None
class RetrievalDetails(ChunkContext):
# Use LLM to determine whether to do a retrieval or only rely on existing history

View File

@@ -331,6 +331,14 @@ class SearchPipeline:
self._retrieved_sections = expanded_inference_sections
return expanded_inference_sections
@property
def retrieved_sections(self) -> list[InferenceSection]:
if self._retrieved_sections is not None:
return self._retrieved_sections
self._retrieved_sections = self._get_sections()
return self._retrieved_sections
@property
def reranked_sections(self) -> list[InferenceSection]:
"""Reranking is always done at the chunk level since section merging could create arbitrarily
@@ -343,7 +351,7 @@ class SearchPipeline:
if self._reranked_sections is not None:
return self._reranked_sections
retrieved_sections = self._get_sections()
retrieved_sections = self.retrieved_sections
if self.retrieved_sections_callback is not None:
self.retrieved_sections_callback(retrieved_sections)

View File

@@ -117,8 +117,12 @@ def retrieval_preprocessing(
else None
)
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
# latency, and in that case we don't need to run the model again
run_query_analysis = (
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
None
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
else FunctionCall(query_analysis, (query,), {})
)
functions_to_run = [
@@ -143,11 +147,12 @@ def retrieval_preprocessing(
# The extracted keywords right now are not very reliable, not using for now
# Can maybe use for highlighting
is_keyword, extracted_keywords = (
parallel_results[run_query_analysis.result_id]
if run_query_analysis
else (False, None)
)
is_keyword, _extracted_keywords = False, None
if search_request.precomputed_is_keyword is not None:
is_keyword = search_request.precomputed_is_keyword
_extracted_keywords = search_request.precomputed_keywords
elif run_query_analysis:
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
all_query_terms = query.split()
processed_keywords = (
@@ -247,4 +252,5 @@ def retrieval_preprocessing(
chunks_above=chunks_above,
chunks_below=chunks_below,
full_doc=search_request.full_doc,
precomputed_query_embedding=search_request.precomputed_query_embedding,
)

View File

@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
from shared_configs.configs import MODEL_SERVER_HOST
from shared_configs.configs import MODEL_SERVER_PORT
from shared_configs.enums import EmbedTextType
from shared_configs.model_server_models import Embedding
logger = setup_logger()
@@ -109,6 +109,20 @@ def combine_retrieval_results(
return sorted_chunks
def get_query_embedding(query: str, db_session: Session) -> Embedding:
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
)
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
return query_embedding
@log_function_time(print_only=True)
def doc_index_retrieval(
query: SearchQuery,
@@ -121,17 +135,10 @@ def doc_index_retrieval(
from the large chunks to the referenced chunks,
dedupes the chunks, and cleans the chunks.
"""
search_settings = get_current_search_settings(db_session)
model = EmbeddingModel.from_db_model(
search_settings=search_settings,
# The below are globally set, this flow always uses the indexing one
server_host=MODEL_SERVER_HOST,
server_port=MODEL_SERVER_PORT,
query_embedding = query.precomputed_query_embedding or get_query_embedding(
query.query, db_session
)
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
top_chunks = document_index.hybrid_retrieval(
query=query.query,
query_embedding=query_embedding,
@@ -249,7 +256,16 @@ def retrieve_chunks(
continue
simplified_queries.add(simplified_rephrase)
q_copy = query.copy(update={"query": rephrase}, deep=True)
q_copy = query.model_copy(
update={
"query": rephrase,
# need to recompute for each rephrase
# note that `SearchQuery` is a frozen model, so we can't update
# it below
"precomputed_query_embedding": None,
},
deep=True,
)
run_queries.append(
(
doc_index_retrieval,

View File

@@ -2295,15 +2295,14 @@ class PublicBase(DeclarativeBase):
__abstract__ = True
# Strictly keeps track of the tenant that a given user will authenticate to.
class UserTenantMapping(Base):
__tablename__ = "user_tenant_mapping"
__table_args__ = (
UniqueConstraint("email", "tenant_id", name="uq_user_tenant"),
{"schema": "public"},
)
__table_args__ = ({"schema": "public"},)
email: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
tenant_id: Mapped[str] = mapped_column(String, nullable=False)
tenant_id: Mapped[str] = mapped_column(String, nullable=False, primary_key=True)
active: Mapped[bool] = mapped_column(Boolean, nullable=False, default=True)
@validates("email")
def validate_email(self, key: str, value: str) -> str:

View File

@@ -1,6 +1,7 @@
import random
from datetime import datetime
from datetime import timedelta
from logging import getLogger
from onyx.configs.constants import MessageType
from onyx.db.chat import create_chat_session
@@ -9,6 +10,8 @@ from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session_with_current_tenant
from onyx.db.models import ChatSession
logger = getLogger(__name__)
def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
"""Utility function to seed chat history for testing.
@@ -19,12 +22,18 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
the times.
"""
with get_session_with_current_tenant() as db_session:
logger.info(f"Seeding {num_sessions} sessions.")
for y in range(0, num_sessions):
create_chat_session(db_session, f"pytest_session_{y}", None, None)
# randomize all session times
logger.info(f"Seeding {num_messages} messages per session.")
rows = db_session.query(ChatSession).all()
for row in rows:
for x in range(0, len(rows)):
if x % 1024 == 0:
logger.info(f"Seeded messages for {x} sessions so far.")
row = rows[x]
row.time_created = datetime.utcnow() - timedelta(
days=random.randint(0, days)
)
@@ -34,20 +43,37 @@ def seed_chat_history(num_sessions: int, num_messages: int, days: int) -> None:
root_message = get_or_create_root_message(row.id, db_session)
current_message_type = MessageType.USER
parent_message = root_message
for x in range(0, num_messages):
if current_message_type == MessageType.USER:
msg = f"pytest_message_user_{x}"
else:
msg = f"pytest_message_assistant_{x}"
chat_message = create_new_chat_message(
row.id,
root_message,
f"pytest_message_{x}",
parent_message,
msg,
None,
0,
MessageType.USER,
current_message_type,
db_session,
)
chat_message.time_sent = row.time_created + timedelta(
minutes=random.randint(0, 10)
)
db_session.commit()
db_session.commit()
current_message_type = (
MessageType.ASSISTANT
if current_message_type == MessageType.USER
else MessageType.USER
)
parent_message = chat_message
db_session.commit()
logger.info(f"Seeded messages for {len(rows)} sessions. Finished.")

View File

@@ -1,6 +1,5 @@
from sqlalchemy import and_
from sqlalchemy import delete
from sqlalchemy import func
from sqlalchemy import or_
from sqlalchemy import select
from sqlalchemy.orm import Session
@@ -149,11 +148,10 @@ def delete_document_tags_for_documents__no_commit(
stmt = delete(Document__Tag).where(Document__Tag.document_id.in_(document_ids))
db_session.execute(stmt)
orphan_tags_query = (
select(Tag.id)
.outerjoin(Document__Tag, Tag.id == Document__Tag.tag_id)
.group_by(Tag.id)
.having(func.count(Document__Tag.document_id) == 0)
orphan_tags_query = select(Tag.id).where(
~db_session.query(Document__Tag.tag_id)
.filter(Document__Tag.tag_id == Tag.id)
.exists()
)
orphan_tags = db_session.execute(orphan_tags_query).scalars().all()

View File

@@ -464,12 +464,29 @@ def index_doc_batch(
),
)
successful_doc_ids = {record.document_id for record in insertion_records}
if successful_doc_ids != set(updatable_ids):
all_returned_doc_ids = (
{record.document_id for record in insertion_records}
.union(
{
record.failed_document.document_id
for record in vector_db_write_failures
if record.failed_document
}
)
.union(
{
record.failed_document.document_id
for record in embedding_failures
if record.failed_document
}
)
)
if all_returned_doc_ids != set(updatable_ids):
raise RuntimeError(
f"Some documents were not successfully indexed. "
f"Updatable IDs: {updatable_ids}, "
f"Successful IDs: {successful_doc_ids}"
f"Returned IDs: {all_returned_doc_ids}. "
"This should never happen."
)
last_modified_ids = []

View File

@@ -167,7 +167,7 @@ def _convert_delta_to_message_chunk(
stop_reason: str | None = None,
) -> BaseMessageChunk:
"""Adapted from langchain_community.chat_models.litellm._convert_delta_to_message_chunk"""
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else None)
role = _dict.get("role") or (_base_msg_to_role(curr_msg) if curr_msg else "unknown")
content = _dict.get("content") or ""
additional_kwargs = {}
if _dict.get("function_call"):
@@ -402,6 +402,7 @@ class DefaultMultiLLM(LLM):
stream: bool,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> litellm.ModelResponse | litellm.CustomStreamWrapper:
# litellm doesn't accept LangChain BaseMessage objects, so we need to convert them
# to a dict representation
@@ -429,6 +430,7 @@ class DefaultMultiLLM(LLM):
# model params
temperature=0,
timeout=timeout_override or self._timeout,
max_tokens=max_tokens,
# For now, we don't support parallel tool calls
# NOTE: we can't pass this in if tools are not specified
# or else OpenAI throws an error
@@ -484,6 +486,7 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -497,6 +500,7 @@ class DefaultMultiLLM(LLM):
stream=False,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
choice = response.choices[0]
@@ -515,6 +519,7 @@ class DefaultMultiLLM(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
if LOG_DANSWER_MODEL_INTERACTIONS:
self.log_model_configs()
@@ -539,6 +544,7 @@ class DefaultMultiLLM(LLM):
stream=True,
structured_response_format=structured_response_format,
timeout_override=timeout_override,
max_tokens=max_tokens,
),
)
try:

View File

@@ -82,6 +82,7 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
return self._execute(prompt)
@@ -92,5 +93,6 @@ class CustomModelServer(LLM):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
yield self._execute(prompt)

View File

@@ -91,12 +91,18 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
return self._invoke_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
@abc.abstractmethod
@@ -107,6 +113,7 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> BaseMessage:
raise NotImplementedError
@@ -117,12 +124,18 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
self._precall(prompt)
# TODO add a postcall to log model outputs independent of concrete class
# implementation
messages = self._stream_implementation(
prompt, tools, tool_choice, structured_response_format, timeout_override
prompt,
tools,
tool_choice,
structured_response_format,
timeout_override,
max_tokens,
)
tokens = []
@@ -142,5 +155,6 @@ class LLM(abc.ABC):
tool_choice: ToolChoiceOptions | None = None,
structured_response_format: dict | None = None,
timeout_override: int | None = None,
max_tokens: int | None = None,
) -> Iterator[BaseMessage]:
raise NotImplementedError

View File

@@ -51,6 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
from onyx.server.documents.connector import router as connector_router
from onyx.server.documents.credential import router as credential_router
from onyx.server.documents.document import router as document_router
from onyx.server.documents.standard_oauth import router as standard_oauth_router
from onyx.server.features.document_set.api import router as document_set_router
from onyx.server.features.folder.api import router as folder_router
from onyx.server.features.input_prompt.api import (
@@ -233,6 +234,8 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
yield
SqlEngine.reset_engine()
if AUTH_RATE_LIMITING_ENABLED:
await close_auth_limiter()
@@ -322,6 +325,7 @@ def get_application() -> FastAPI:
)
include_router_with_global_prefix_prepended(application, long_term_logs_router)
include_router_with_global_prefix_prepended(application, api_key_router)
include_router_with_global_prefix_prepended(application, standard_oauth_router)
if AUTH_TYPE == AuthType.DISABLED:
# Server logs this during auth setup verification step

View File

@@ -5,7 +5,7 @@ from fastapi.dependencies.models import Dependant
from starlette.routing import BaseRoute
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -112,7 +112,7 @@ def check_router_auth(
or depends_fn == current_curator_or_admin_user
or depends_fn == api_key_dep
or depends_fn == current_user_with_expired_token
or depends_fn == current_chat_accesssible_user
or depends_fn == current_chat_accessible_user
or depends_fn == control_plane_dep
or depends_fn == current_cloud_superuser
):

View File

@@ -17,7 +17,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_user
from onyx.background.celery.versioned_apps.primary import app as primary_app
@@ -1247,7 +1247,7 @@ class BasicCCPairInfo(BaseModel):
@router.get("/connector-status")
def get_basic_connector_indexing_status(
user: User = Depends(current_chat_accesssible_user),
user: User = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[BasicCCPairInfo]:
cc_pairs = get_connector_credential_pairs_for_user(

View File

@@ -11,7 +11,7 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
@@ -390,7 +390,7 @@ def get_image_generation_tool(
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),

View File

@@ -7,7 +7,7 @@ from fastapi import Query
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session
from onyx.db.llm import fetch_existing_llm_providers
from onyx.db.llm import fetch_existing_llm_providers_for_user
@@ -191,7 +191,7 @@ def set_provider_as_default(
@basic_router.get("/provider")
def list_llm_provider_basics(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> list[LLMProviderDescriptor]:
return [

View File

@@ -53,6 +53,16 @@ class UserPreferences(BaseModel):
temperature_override_enabled: bool | None = None
class TenantSnapshot(BaseModel):
tenant_id: str
number_of_users: int
class TenantInfo(BaseModel):
invitation: TenantSnapshot | None = None
new_tenant: TenantSnapshot | None = None
class UserInfo(BaseModel):
id: str
email: str
@@ -65,9 +75,10 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None
current_token_expiry_length: int | None = None
is_cloud_superuser: bool = False
organization_name: str | None = None
team_name: str | None = None
is_anonymous_user: bool | None = None
password_configured: bool | None = None
tenant_info: TenantInfo | None = None
@classmethod
def from_model(
@@ -76,8 +87,9 @@ class UserInfo(BaseModel):
current_token_created_at: datetime | None = None,
expiry_length: int | None = None,
is_cloud_superuser: bool = False,
organization_name: str | None = None,
team_name: str | None = None,
is_anonymous_user: bool | None = None,
tenant_info: TenantInfo | None = None,
) -> "UserInfo":
return cls(
id=str(user.id),
@@ -99,7 +111,7 @@ class UserInfo(BaseModel):
temperature_override_enabled=user.temperature_override_enabled,
)
),
organization_name=organization_name,
team_name=team_name,
# set to None if TRACK_EXTERNAL_IDP_EXPIRY is False so that we avoid cases
# where they previously had this set + used OIDC, and now they switched to
# basic auth are now constantly getting redirected back to the login page
@@ -109,6 +121,7 @@ class UserInfo(BaseModel):
current_token_expiry_length=expiry_length,
is_cloud_superuser=is_cloud_superuser,
is_anonymous_user=is_anonymous_user,
tenant_info=tenant_info,
)

View File

@@ -12,13 +12,11 @@ from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import Request
from psycopg2.errors import UniqueViolation
from pydantic import BaseModel
from sqlalchemy import Column
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy import update
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from ee.onyx.configs.app_configs import SUPER_USERS
@@ -55,6 +53,8 @@ from onyx.key_value_store.factory import get_kv_store
from onyx.server.documents.models import PaginatedReturn
from onyx.server.manage.models import AllUsersResponse
from onyx.server.manage.models import AutoScrollRequest
from onyx.server.manage.models import TenantInfo
from onyx.server.manage.models import TenantSnapshot
from onyx.server.manage.models import UserByEmail
from onyx.server.manage.models import UserInfo
from onyx.server.manage.models import UserPreferences
@@ -296,13 +296,6 @@ def bulk_invite_users(
"onyx.server.tenants.provisioning", "add_users_to_tenant", None
)(new_invited_emails, tenant_id)
except IntegrityError as e:
if isinstance(e.orig, UniqueViolation):
raise HTTPException(
status_code=400,
detail="User has already been invited to a Onyx organization",
)
raise
except Exception as e:
logger.error(f"Failed to add users to tenant {tenant_id}: {str(e)}")
@@ -425,6 +418,10 @@ async def delete_user(
db_session.expunge(user_to_delete)
try:
tenant_id = get_current_tenant_id()
fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "remove_users_from_tenant", None
)([user_email.user_email], tenant_id)
delete_user_from_db(user_to_delete, db_session)
logger.info(f"Deleted user {user_to_delete.email}")
@@ -553,8 +550,8 @@ def verify_user_logged_in(
if anonymous_user_enabled(tenant_id=tenant_id):
store = get_kv_store()
return fetch_no_auth_user(store, anonymous_user_enabled=True)
raise BasicAuthenticationError(detail="User Not Authenticated")
if user.oidc_expiry and user.oidc_expiry < datetime.now(timezone.utc):
raise BasicAuthenticationError(
detail="Access denied. User's OIDC token has expired.",
@@ -563,16 +560,35 @@ def verify_user_logged_in(
token_created_at = (
None if MULTI_TENANT else get_current_token_creation(user, db_session)
)
organization_name = fetch_ee_implementation_or_noop(
team_name = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_id_for_email", None
)(user.email)
new_tenant: TenantSnapshot | None = None
tenant_invitation: TenantSnapshot | None = None
if MULTI_TENANT:
if team_name != get_current_tenant_id():
user_count = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_count", None
)(team_name)
new_tenant = TenantSnapshot(tenant_id=team_name, number_of_users=user_count)
tenant_invitation = fetch_ee_implementation_or_noop(
"onyx.server.tenants.user_mapping", "get_tenant_invitation", None
)(user.email)
user_info = UserInfo.from_model(
user,
current_token_created_at=token_created_at,
expiry_length=SESSION_EXPIRE_TIME_SECONDS,
is_cloud_superuser=user.email in SUPER_USERS,
organization_name=organization_name,
team_name=team_name,
tenant_info=TenantInfo(
new_tenant=new_tenant,
invitation=tenant_invitation,
),
)
return user_info

View File

@@ -49,9 +49,9 @@ class FullUserSnapshot(BaseModel):
)
class InvitedUserSnapshot(BaseModel):
email: str
class DisplayPriorityRequest(BaseModel):
display_priority_map: dict[int, int]
class InvitedUserSnapshot(BaseModel):
email: str

View File

@@ -20,7 +20,7 @@ from fastapi.responses import StreamingResponse
from pydantic import BaseModel
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.auth.users import current_user
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import extract_headers
@@ -190,7 +190,7 @@ def update_chat_session_model(
def get_chat_session(
session_id: UUID,
is_shared: bool = False,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> ChatSessionDetailResponse:
user_id = user.id if user is not None else None
@@ -246,7 +246,7 @@ def get_chat_session(
@router.post("/create-chat-session")
def create_new_chat_session(
chat_session_creation_request: ChatSessionCreationRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> CreateChatSessionID:
user_id = user.id if user is not None else None
@@ -381,7 +381,7 @@ async def is_connected(request: Request) -> Callable[[], bool]:
def handle_new_chat_message(
chat_message_req: CreateChatMessageRequest,
request: Request,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
_rate_limit_check: None = Depends(check_token_rate_limits),
is_connected_func: Callable[[], bool] = Depends(is_connected),
) -> StreamingResponse:
@@ -473,7 +473,7 @@ def set_message_as_latest(
@router.post("/create-chat-message-feedback")
def create_chat_feedback(
feedback: ChatFeedbackRequest,
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
db_session: Session = Depends(get_session),
) -> None:
user_id = user.id if user else None

View File

@@ -11,7 +11,7 @@ from sqlalchemy import func
from sqlalchemy import select
from sqlalchemy.orm import Session
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_chat_accessible_user
from onyx.db.engine import get_session_context_manager
from onyx.db.models import ChatMessage
from onyx.db.models import ChatSession
@@ -29,7 +29,7 @@ TOKEN_BUDGET_UNIT = 1_000
def check_token_rate_limits(
user: User | None = Depends(current_chat_accesssible_user),
user: User | None = Depends(current_chat_accessible_user),
) -> None:
# short circuit if no rate limits are set up
# NOTE: result of `any_rate_limit_exists` is cached, so this call is fast 99% of the time

View File

@@ -32,15 +32,15 @@ class InCodeToolInfo(TypedDict):
BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=SearchTool,
description="The Search Tool allows the Assistant to search through connected knowledge to help build an answer.",
description="The Search Action allows the Assistant to search through connected knowledge to help build an answer.",
in_code_tool_id=SearchTool.__name__,
display_name=SearchTool._DISPLAY_NAME,
),
InCodeToolInfo(
cls=ImageGenerationTool,
description=(
"The Image Generation Tool allows the assistant to use DALL-E 3 to generate images. "
"The tool will be used when the user asks the assistant to generate an image."
"The Image Generation Action allows the assistant to use DALL-E 3 to generate images. "
"The action will be used when the user asks the assistant to generate an image."
),
in_code_tool_id=ImageGenerationTool.__name__,
display_name=ImageGenerationTool._DISPLAY_NAME,
@@ -51,7 +51,7 @@ BUILT_IN_TOOLS: list[InCodeToolInfo] = [
InCodeToolInfo(
cls=InternetSearchTool,
description=(
"The Internet Search Tool allows the assistant "
"The Internet Search Action allows the assistant "
"to perform internet searches for up-to-date information."
),
in_code_tool_id=InternetSearchTool.__name__,
@@ -98,7 +98,7 @@ def load_builtin_tools(db_session: Session) -> None:
for tool_id, tool in list(in_code_tool_id_to_tool.items()):
if tool_id not in built_in_ids:
db_session.delete(tool)
logger.notice(f"Removed tool no longer in built-in list: {tool.name}")
logger.notice(f"Removed action no longer in built-in list: {tool.name}")
db_session.commit()
logger.notice("All built-in tools are loaded/verified.")

View File

@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
from onyx.context.search.enums import SearchType
from onyx.context.search.models import IndexFilters
from onyx.context.search.models import InferenceSection
from shared_configs.model_server_models import Embedding
class ToolResponse(BaseModel):
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
recency_bias_multiplier: float
# None indicates that the default value should be used
class SearchToolOverrideKwargs(BaseModel):
force_no_rerank: bool
alternate_db_session: Session | None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
skip_query_analysis: bool
force_no_rerank: bool | None = None
alternate_db_session: Session | None = None
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
skip_query_analysis: bool | None = None
precomputed_query_embedding: Embedding | None = None
precomputed_is_keyword: bool | None = None
precomputed_keywords: list[str] | None = None
class Config:
arbitrary_types_allowed = True

View File

@@ -3,6 +3,7 @@ from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import TypeVar
from sqlalchemy.orm import Session
@@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
from onyx.chat.models import ContextualPruningConfig
from onyx.chat.models import DocumentPruningConfig
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.chat.models import OnyxContexts
from onyx.chat.models import PromptConfig
from onyx.chat.models import SectionRelevancePiece
@@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
from onyx.tools.models import SearchToolOverrideKwargs
from onyx.tools.models import ToolResponse
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_utils import (
context_from_inference_section,
)
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
from onyx.tools.tool_implementations.search_like_tool_utils import (
build_next_prompt_for_search_like_tool,
@@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs[QUERY_FIELD])
precomputed_query_embedding = None
precomputed_is_keyword = None
precomputed_keywords = None
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None
skip_query_analysis = False
if override_kwargs:
force_no_rerank = override_kwargs.force_no_rerank
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
alternate_db_session = override_kwargs.alternate_db_session
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
skip_query_analysis = override_kwargs.skip_query_analysis
skip_query_analysis = use_alt_not_None(
override_kwargs.skip_query_analysis, False
)
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
precomputed_keywords = override_kwargs.precomputed_keywords
if self.selected_sections:
yield from self._build_response_for_specified_sections(query)
return
@@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
if self.retrieval_options
else None
),
precomputed_query_embedding=precomputed_query_embedding,
precomputed_is_keyword=precomputed_is_keyword,
precomputed_keywords=precomputed_keywords,
),
user=self.user,
llm=self.llm,
@@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
)
yield from yield_search_responses(
query,
search_pipeline.reranked_sections,
search_pipeline.final_context_sections,
lambda: search_pipeline.retrieved_sections,
lambda: search_pipeline.reranked_sections,
lambda: search_pipeline.final_context_sections,
search_query_info,
lambda: search_pipeline.section_relevance,
self,
@@ -383,10 +397,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
# SearchTool passed in to allow for access to SearchTool properties.
# We can't just call SearchTool methods in the graph because we're operating on
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
#
# The various inference sections are passed in as functions to allow for lazy
# evaluation. The SearchPipeline object properties that they correspond to are
# actually functions defined with @property decorators, and passing them into
# this function causes them to get evaluated immediately which is undesirable.
def yield_search_responses(
query: str,
reranked_sections: list[InferenceSection],
final_context_sections: list[InferenceSection],
get_retrieved_sections: Callable[[], list[InferenceSection]],
get_reranked_sections: Callable[[], list[InferenceSection]],
get_final_context_sections: Callable[[], list[InferenceSection]],
search_query_info: SearchQueryInfo,
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
search_tool: SearchTool,
@@ -395,7 +415,7 @@ def yield_search_responses(
id=SEARCH_RESPONSE_SUMMARY_ID,
response=SearchResponseSummary(
rephrased_query=query,
top_sections=final_context_sections,
top_sections=get_retrieved_sections(),
predicted_flow=QueryFlow.QUESTION_ANSWER,
predicted_search=search_query_info.predicted_search,
final_filters=search_query_info.final_filters,
@@ -407,13 +427,8 @@ def yield_search_responses(
id=SEARCH_DOC_CONTENT_ID,
response=OnyxContexts(
contexts=[
OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)
for section in reranked_sections
context_from_inference_section(section)
for section in get_reranked_sections()
]
),
)
@@ -424,6 +439,7 @@ def yield_search_responses(
response=section_relevance,
)
final_context_sections = get_final_context_sections()
pruned_sections = prune_sections(
sections=final_context_sections,
section_relevance_list=section_relevance_list_impl(
@@ -438,3 +454,10 @@ def yield_search_responses(
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
T = TypeVar("T")
def use_alt_not_None(value: T | None, alt: T) -> T:
return value if value is not None else alt

View File

@@ -1,4 +1,5 @@
from onyx.chat.models import LlmDoc
from onyx.chat.models import OnyxContext
from onyx.context.search.models import InferenceSection
from onyx.prompts.prompt_utils import clean_up_source
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
"%B %d, %Y %H:%M"
)
return doc_dict
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
return OnyxContext(
content=section.combined_content,
document_id=section.center_chunk.document_id,
semantic_identifier=section.center_chunk.semantic_identifier,
blurb=section.center_chunk.blurb,
)

View File

@@ -1,6 +1,8 @@
from collections.abc import Callable
from collections.abc import Generator
from typing import Any
from typing import Generic
from typing import TypeVar
from onyx.llm.interfaces import LLM
from onyx.llm.models import PreviousMessage
@@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
class ToolRunner:
def __init__(self, tool: Tool, args: dict[str, Any]):
R = TypeVar("R")
class ToolRunner(Generic[R]):
def __init__(
self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
):
self.tool = tool
self.args = args
self.override_kwargs = override_kwargs
self._tool_responses: list[ToolResponse] | None = None
@@ -27,7 +35,9 @@ class ToolRunner:
return
tool_responses: list[ToolResponse] = []
for tool_response in self.tool.run(**self.args):
for tool_response in self.tool.run(
override_kwargs=self.override_kwargs, **self.args
):
yield tool_response
tool_responses.append(tool_response)

View File

@@ -118,7 +118,7 @@ def run_functions_in_parallel(
return results
class TimeoutThread(threading.Thread):
class TimeoutThread(threading.Thread, Generic[R]):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
@@ -159,3 +159,34 @@ def run_with_timeout(
task.end()
return task.result
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
# difficult to use. It's up to the programmer to call wait_on_background on the thread after
# the code you want to run in parallel is finished. As with all python thread parallelism,
# this is only useful for I/O bound tasks.
def run_in_background(
func: Callable[..., R], *args: Any, **kwargs: Any
) -> TimeoutThread[R]:
"""
Runs a function in a background thread. Returns a TimeoutThread object that can be used
to wait for the function to finish with wait_on_background.
"""
context = contextvars.copy_context()
# Timeout not used in the non-blocking case
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
task.start()
return task
def wait_on_background(task: TimeoutThread[R]) -> R:
"""
Used in conjunction with run_in_background. blocks until the task is finished,
then returns the result of the task.
"""
task.join()
if task.exception is not None:
raise task.exception
return task.result

43
backend/onyx/utils/url.py Normal file
View File

@@ -0,0 +1,43 @@
from urllib.parse import parse_qs
from urllib.parse import urlencode
from urllib.parse import urlparse
from urllib.parse import urlunparse
def add_url_params(url: str, params: dict) -> str:
"""
Add parameters to a URL, handling existing parameters properly.
Args:
url: The original URL
params: Dictionary of parameters to add
Returns:
URL with added parameters
"""
# Parse the URL
parsed_url = urlparse(url)
# Get existing query parameters
query_params = parse_qs(parsed_url.query)
# Update with new parameters
for key, value in params.items():
query_params[key] = [value]
# Build the new query string
new_query = urlencode(query_params, doseq=True)
# Reconstruct the URL with the new query string
new_url = urlunparse(
(
parsed_url.scheme,
parsed_url.netloc,
parsed_url.path,
parsed_url.params,
new_query,
parsed_url.fragment,
)
)
return new_url

View File

@@ -1,4 +1,4 @@
black==23.3.0
black==23.7.0
boto3-stubs[s3]==1.34.133
celery-types==0.19.0
cohere==5.6.1

View File

@@ -54,6 +54,7 @@ class OnyxRedisCommand(Enum):
purge_vespa_syncing = "purge_vespa_syncing"
get_user_token = "get_user_token"
delete_user_token = "delete_user_token"
add_invited_user = "add_invited_user"
def __str__(self) -> str:
return self.value
@@ -163,6 +164,21 @@ def onyx_redis(
return 0
else:
return 2
elif command == OnyxRedisCommand.add_invited_user:
if not user_email:
logger.error("You must specify --user-email with add_invited_user")
return 1
current_invited_users = get_invited_users()
if user_email not in current_invited_users:
current_invited_users.append(user_email)
if dry_run:
logger.info(f"(DRY-RUN) Would add {user_email} to invited users")
else:
write_invited_users(current_invited_users)
logger.info(f"Added {user_email} to invited users")
else:
logger.info(f"{user_email} is already in the invited users list")
return 0
else:
pass
@@ -441,23 +457,6 @@ if __name__ == "__main__":
if args.tenant_id:
CURRENT_TENANT_ID_CONTEXTVAR.set(args.tenant_id)
if args.command == "add_invited_user":
if not args.user_email:
print("Error: --user-email is required for add_invited_user command")
sys.exit(1)
current_invited_users = get_invited_users()
if args.user_email not in current_invited_users:
current_invited_users.append(args.user_email)
if args.dry_run:
print(f"(DRY-RUN) Would add {args.user_email} to invited users")
else:
write_invited_users(current_invited_users)
print(f"Added {args.user_email} to invited users")
else:
print(f"{args.user_email} is already in the invited users list")
sys.exit(0)
exitcode = onyx_redis(
command=args.command,
batch=args.batch,

View File

@@ -108,6 +108,7 @@ command=tail -qF
/var/log/celery_worker_light.log
/var/log/celery_worker_heavy.log
/var/log/celery_worker_indexing.log
/var/log/celery_worker_monitoring.log
/var/log/slack_bot.log
stdout_logfile=/dev/stdout
stdout_logfile_maxbytes = 0 # must be set to 0 when stdout_logfile=/dev/stdout

View File

@@ -36,6 +36,7 @@ def confluence_connector() -> ConfluenceConnector:
"onyx.file_processing.extract_file_text.get_unstructured_api_key",
return_value=None,
)
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_basic(
mock_get_api_key: MagicMock, confluence_connector: ConfluenceConnector
) -> None:
@@ -45,7 +46,7 @@ def test_confluence_connector_basic(
with pytest.raises(StopIteration):
next(doc_batch_generator)
assert len(doc_batch) == 3
assert len(doc_batch) == 2
page_within_a_page_doc: Document | None = None
page_doc: Document | None = None

View File

@@ -28,6 +28,7 @@ def confluence_connector() -> ConfluenceConnector:
# This should never fail because even if the docs in the cloud change,
# the full doc ids retrieved should always be a subset of the slim doc ids
@pytest.mark.skip(reason="Skipping this test")
def test_confluence_connector_permissions(
confluence_connector: ConfluenceConnector,
) -> None:
@@ -41,5 +42,10 @@ def test_confluence_connector_permissions(
for slim_doc_batch in confluence_connector.retrieve_all_slim_documents():
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
# Find IDs that are in full but not in slim
difference = all_full_doc_ids - all_slim_doc_ids
# The set of full doc IDs should be always be a subset of the slim doc IDs
assert all_full_doc_ids.issubset(all_slim_doc_ids)
assert all_full_doc_ids.issubset(
all_slim_doc_ids
), f"Full doc IDs are not a subset of slim doc IDs. Found {len(difference)} IDs in full docs but not in slim docs."

View File

@@ -20,29 +20,32 @@ def gitbook_connector() -> GitbookConnector:
return connector
NUM_PAGES = 3
def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
doc_batch_generator = gitbook_connector.load_from_state()
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert len(doc_batch) > 0
assert len(doc_batch) == NUM_PAGES
# Verify first document structure
doc = doc_batch[0]
main_doc = doc_batch[0]
# Basic document properties
assert doc.id.startswith("gitbook-")
assert doc.semantic_identifier == "Acme Corp Internal Handbook"
assert doc.source == DocumentSource.GITBOOK
assert main_doc.id.startswith("gitbook-")
assert main_doc.semantic_identifier == "Acme Corp Internal Handbook"
assert main_doc.source == DocumentSource.GITBOOK
# Metadata checks
assert "path" in doc.metadata
assert "type" in doc.metadata
assert "kind" in doc.metadata
assert "path" in main_doc.metadata
assert "type" in main_doc.metadata
assert "kind" in main_doc.metadata
# Section checks
assert len(doc.sections) == 1
section = doc.sections[0]
assert len(main_doc.sections) == 1
section = main_doc.sections[0]
# Content specific checks
content = section.text
@@ -74,8 +77,23 @@ def test_gitbook_connector_basic(gitbook_connector: GitbookConnector) -> None:
assert section.link # Should have a URL
nested1 = doc_batch[1]
assert nested1.id.startswith("gitbook-")
assert nested1.semantic_identifier == "Nested1"
assert len(nested1.sections) == 1
# extra newlines at the end, remove them to make test easier
assert nested1.sections[0].text.strip() == "nested1"
assert nested1.source == DocumentSource.GITBOOK
nested2 = doc_batch[2]
assert nested2.id.startswith("gitbook-")
assert nested2.semantic_identifier == "Nested2"
assert len(nested2.sections) == 1
assert nested2.sections[0].text.strip() == "nested2"
assert nested2.source == DocumentSource.GITBOOK
# Time-based polling test
current_time = time.time()
poll_docs = gitbook_connector.poll_source(0, current_time)
poll_batch = next(poll_docs)
assert len(poll_batch) > 0
assert len(poll_batch) == NUM_PAGES

View File

@@ -0,0 +1,128 @@
import os
import time
import pytest
from onyx.configs.constants import DocumentSource
from onyx.connectors.notion.connector import NotionConnector
@pytest.fixture
def notion_connector() -> NotionConnector:
"""Create a NotionConnector with credentials from environment variables"""
connector = NotionConnector()
connector.load_credentials(
{
"notion_integration_token": os.environ["NOTION_INTEGRATION_TOKEN"],
}
)
return connector
def test_notion_connector_basic(notion_connector: NotionConnector) -> None:
"""Test the NotionConnector with a real Notion page.
Uses a Notion workspace under the onyx-test.com domain.
"""
doc_batch_generator = notion_connector.poll_source(0, time.time())
# Get first batch of documents
doc_batch = next(doc_batch_generator)
assert (
len(doc_batch) == 5
), "Expected exactly 5 documents (root, two children, table entry, and table entry child)"
# Find root and child documents by semantic identifier
root_doc = None
child1_doc = None
child2_doc = None
table_entry_doc = None
table_entry_child_doc = None
for doc in doc_batch:
if doc.semantic_identifier == "Root":
root_doc = doc
elif doc.semantic_identifier == "Child1":
child1_doc = doc
elif doc.semantic_identifier == "Child2":
child2_doc = doc
elif doc.semantic_identifier == "table-entry01":
table_entry_doc = doc
elif doc.semantic_identifier == "Child-table-entry01":
table_entry_child_doc = doc
assert root_doc is not None, "Root document not found"
assert child1_doc is not None, "Child1 document not found"
assert child2_doc is not None, "Child2 document not found"
assert table_entry_doc is not None, "Table entry document not found"
assert table_entry_child_doc is not None, "Table entry child document not found"
# Verify root document structure
assert root_doc.id is not None
assert root_doc.source == DocumentSource.NOTION
# Section checks for root
assert len(root_doc.sections) == 1
root_section = root_doc.sections[0]
# Content specific checks for root
assert root_section.text == "\nroot"
assert root_section.link is not None
assert root_section.link.startswith("https://www.notion.so/")
# Verify child1 document structure
assert child1_doc.id is not None
assert child1_doc.source == DocumentSource.NOTION
# Section checks for child1
assert len(child1_doc.sections) == 1
child1_section = child1_doc.sections[0]
# Content specific checks for child1
assert child1_section.text == "\nchild1"
assert child1_section.link is not None
assert child1_section.link.startswith("https://www.notion.so/")
# Verify child2 document structure (includes database)
assert child2_doc.id is not None
assert child2_doc.source == DocumentSource.NOTION
# Section checks for child2
assert len(child2_doc.sections) == 2 # One for content, one for database
child2_section = child2_doc.sections[0]
child2_db_section = child2_doc.sections[1]
# Content specific checks for child2
assert child2_section.text == "\nchild2"
assert child2_section.link is not None
assert child2_section.link.startswith("https://www.notion.so/")
# Database section checks for child2
assert child2_db_section.text.strip() != "" # Should contain some database content
assert child2_db_section.link is not None
assert child2_db_section.link.startswith("https://www.notion.so/")
# Verify table entry document structure
assert table_entry_doc.id is not None
assert table_entry_doc.source == DocumentSource.NOTION
# Section checks for table entry
assert len(table_entry_doc.sections) == 1
table_entry_section = table_entry_doc.sections[0]
# Content specific checks for table entry
assert table_entry_section.text == "\ntable-entry01"
assert table_entry_section.link is not None
assert table_entry_section.link.startswith("https://www.notion.so/")
# Verify table entry child document structure
assert table_entry_child_doc.id is not None
assert table_entry_child_doc.source == DocumentSource.NOTION
# Section checks for table entry child
assert len(table_entry_child_doc.sections) == 1
table_entry_child_section = table_entry_child_doc.sections[0]
# Content specific checks for table entry child
assert table_entry_child_section.text == "\nchild-table-entry01"
assert table_entry_child_section.link is not None
assert table_entry_child_section.link.startswith("https://www.notion.so/")

View File

@@ -25,7 +25,7 @@ from onyx.indexing.models import IndexingSetting
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout
from tests.integration.common_utils.timeout import run_with_timeout_multiproc
logger = setup_logger()
@@ -161,7 +161,7 @@ def reset_postgres(
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout(
run_with_timeout_multiproc(
downgrade_postgres,
TIMEOUT,
kwargs={

View File

@@ -6,7 +6,9 @@ from typing import TypeVar
T = TypeVar("T")
def run_with_timeout(task: Callable[..., T], timeout: int, kwargs: dict[str, Any]) -> T:
def run_with_timeout_multiproc(
task: Callable[..., T], timeout: int, kwargs: dict[str, Any]
) -> T:
# Use multiprocessing to prevent a thread from blocking the main thread
with multiprocessing.Pool(processes=1) as pool:
async_result = pool.apply_async(task, kwds=kwargs)

View File

@@ -2,7 +2,7 @@ FROM python:3.11.7-slim-bookworm
WORKDIR /app
RUN pip install fastapi uvicorn
RUN pip install "pydantic-core>=2.28.0" fastapi uvicorn
COPY ./main.py /app/main.py

Some files were not shown because too many files have changed in this diff Show More