mirror of
https://github.com/onyx-dot-app/onyx.git
synced 2026-02-18 00:05:47 +00:00
Compare commits
18 Commits
batch_proc
...
connector_
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0593d33cbe | ||
|
|
56b108c313 | ||
|
|
e713a0c58a | ||
|
|
d5d124c5db | ||
|
|
3ec1d79034 | ||
|
|
bf4983e35a | ||
|
|
b7da91e3ae | ||
|
|
29382656fc | ||
|
|
7d6db8d500 | ||
|
|
a7a374dc81 | ||
|
|
facc8cc2fa | ||
|
|
2c0af0a0ca | ||
|
|
bfbc1cd954 | ||
|
|
626da583aa | ||
|
|
92faca139d | ||
|
|
cec05c5ee9 | ||
|
|
eaf054ef06 | ||
|
|
a7a1a24658 |
@@ -1,6 +1,7 @@
|
||||
name: Connector Tests
|
||||
|
||||
on:
|
||||
merge_group:
|
||||
pull_request:
|
||||
branches: [main]
|
||||
schedule:
|
||||
@@ -51,7 +52,7 @@ env:
|
||||
jobs:
|
||||
connectors-check:
|
||||
# See https://runs-on.com/runners/linux/
|
||||
runs-on: [runs-on,runner=8cpu-linux-x64,"run-id=${{ github.run_id }}"]
|
||||
runs-on: [runs-on, runner=8cpu-linux-x64, "run-id=${{ github.run_id }}"]
|
||||
|
||||
env:
|
||||
PYTHONPATH: ./backend
|
||||
@@ -76,7 +77,7 @@ jobs:
|
||||
pip install --retries 5 --timeout 30 -r backend/requirements/dev.txt
|
||||
playwright install chromium
|
||||
playwright install-deps chromium
|
||||
|
||||
|
||||
- name: Run Tests
|
||||
shell: script -q -e -c "bash --noprofile --norc -eo pipefail {0}"
|
||||
run: py.test -o junit_family=xunit2 -xv --ff backend/tests/daily/connectors
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Update GitHub connector repo_name to repositories
|
||||
|
||||
Revision ID: 3934b1bc7b62
|
||||
Revises: b7c2b63c4a03
|
||||
Create Date: 2025-03-05 10:50:30.516962
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
import json
|
||||
import logging
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3934b1bc7b62"
|
||||
down_revision = "b7c2b63c4a03"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
logger = logging.getLogger("alembic.runtime.migration")
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
# First get all GitHub connectors
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
# Update each connector's config
|
||||
updated_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
logger.warning(f"Connector {connector_id} has no config, skipping")
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repo_name" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repositories instead of repo_name
|
||||
new_config = dict(config)
|
||||
repo_name_value = new_config.pop("repo_name")
|
||||
new_config["repositories"] = repo_name_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"connector_id": connector_id, "new_config": json.dumps(new_config)},
|
||||
)
|
||||
updated_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error updating connector {connector_id}: {str(e)}")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# Get all GitHub connectors
|
||||
conn = op.get_bind()
|
||||
|
||||
logger.debug(
|
||||
"Starting rollback of GitHub connectors from repositories to repo_name"
|
||||
)
|
||||
|
||||
github_connectors = conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
SELECT id, connector_specific_config
|
||||
FROM connector
|
||||
WHERE source = 'GITHUB'
|
||||
"""
|
||||
)
|
||||
).fetchall()
|
||||
|
||||
logger.debug(f"Found {len(github_connectors)} GitHub connectors to rollback")
|
||||
|
||||
# Revert each GitHub connector to use repo_name instead of repositories
|
||||
reverted_count = 0
|
||||
for connector_id, config in github_connectors:
|
||||
try:
|
||||
if not config:
|
||||
continue
|
||||
|
||||
# Parse the config if it's a string
|
||||
if isinstance(config, str):
|
||||
config = json.loads(config)
|
||||
|
||||
if "repositories" not in config:
|
||||
continue
|
||||
|
||||
# Create new config with repo_name instead of repositories
|
||||
new_config = dict(config)
|
||||
repositories_value = new_config.pop("repositories")
|
||||
new_config["repo_name"] = repositories_value
|
||||
|
||||
# Update the connector with the new config
|
||||
conn.execute(
|
||||
sa.text(
|
||||
"""
|
||||
UPDATE connector
|
||||
SET connector_specific_config = :new_config
|
||||
WHERE id = :connector_id
|
||||
"""
|
||||
),
|
||||
{"new_config": json.dumps(new_config), "connector_id": connector_id},
|
||||
)
|
||||
reverted_count += 1
|
||||
except Exception as e:
|
||||
logger.error(f"Error reverting connector {connector_id}: {str(e)}")
|
||||
@@ -6,8 +6,7 @@ Create Date: 2025-02-26 13:07:56.217791
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import time
|
||||
from sqlalchemy import text
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "3bd4c84fe72f"
|
||||
@@ -28,357 +27,45 @@ depends_on = None
|
||||
# 4. Adds indexes to both chat_message and chat_session tables for comprehensive search
|
||||
|
||||
|
||||
def upgrade():
|
||||
# --- PART 1: chat_message table ---
|
||||
# Step 1: Add nullable column (quick, minimal locking)
|
||||
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
|
||||
# op.execute("DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message")
|
||||
# op.execute("DROP FUNCTION IF EXISTS update_chat_message_tsv()")
|
||||
# op.execute("ALTER TABLE chat_message DROP COLUMN IF EXISTS message_tsv")
|
||||
# # Drop chat_session tsv trigger if it exists
|
||||
# op.execute("DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session")
|
||||
# op.execute("DROP FUNCTION IF EXISTS update_chat_session_tsv()")
|
||||
# op.execute("ALTER TABLE chat_session DROP COLUMN IF EXISTS title_tsv")
|
||||
# raise Exception("Stop here")
|
||||
time.time()
|
||||
op.execute("ALTER TABLE chat_message ADD COLUMN IF NOT EXISTS message_tsv tsvector")
|
||||
|
||||
# Step 2: Create function and trigger for new/updated rows
|
||||
def upgrade() -> None:
|
||||
# Create a GIN index for full-text search on chat_message.message
|
||||
op.execute(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION update_chat_message_tsv()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.message_tsv = to_tsvector('english', NEW.message);
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
ADD COLUMN message_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Create trigger in a separate execute call
|
||||
# Commit the current transaction before creating concurrent indexes
|
||||
op.execute("COMMIT")
|
||||
|
||||
op.execute(
|
||||
"""
|
||||
CREATE TRIGGER chat_message_tsv_trigger
|
||||
BEFORE INSERT OR UPDATE ON chat_message
|
||||
FOR EACH ROW EXECUTE FUNCTION update_chat_message_tsv()
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message
|
||||
USING GIN (message_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
# Step 3: Update existing rows in batches using Python
|
||||
time.time()
|
||||
|
||||
# Get connection and count total rows
|
||||
connection = op.get_bind()
|
||||
total_count_result = connection.execute(
|
||||
text("SELECT COUNT(*) FROM chat_message")
|
||||
).scalar()
|
||||
total_count = total_count_result if total_count_result is not None else 0
|
||||
batch_size = 5000
|
||||
batches = 0
|
||||
|
||||
# Calculate total batches needed
|
||||
total_batches = (
|
||||
(total_count + batch_size - 1) // batch_size if total_count > 0 else 0
|
||||
# Also add a stored tsvector column for chat_session.description
|
||||
op.execute(
|
||||
"""
|
||||
ALTER TABLE chat_session
|
||||
ADD COLUMN description_tsv tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', coalesce(description, ''))) STORED;
|
||||
"""
|
||||
)
|
||||
|
||||
# Process in batches - properly handling UUIDs by using OFFSET/LIMIT approach
|
||||
for batch_num in range(total_batches):
|
||||
offset = batch_num * batch_size
|
||||
# Commit again before creating the second concurrent index
|
||||
op.execute("COMMIT")
|
||||
|
||||
# Execute update for this batch using OFFSET/LIMIT which works with UUIDs
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE chat_message
|
||||
SET message_tsv = to_tsvector('english', message)
|
||||
WHERE id IN (
|
||||
SELECT id FROM chat_message
|
||||
WHERE message_tsv IS NULL
|
||||
ORDER BY id
|
||||
LIMIT :batch_size OFFSET :offset
|
||||
)
|
||||
"""
|
||||
).bindparams(batch_size=batch_size, offset=offset)
|
||||
)
|
||||
|
||||
# Commit each batch
|
||||
connection.execute(text("COMMIT"))
|
||||
# Start a new transaction
|
||||
connection.execute(text("BEGIN"))
|
||||
|
||||
batches += 1
|
||||
|
||||
# Final check for any remaining NULL values
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE chat_message SET message_tsv = to_tsvector('english', message)
|
||||
WHERE message_tsv IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Create GIN index concurrently
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
time.time()
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv
|
||||
ON chat_message USING GIN (message_tsv)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# First drop the trigger as it won't be needed anymore
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP TRIGGER IF EXISTS chat_message_tsv_trigger ON chat_message;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP FUNCTION IF EXISTS update_chat_message_tsv();
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Add new generated column
|
||||
time.time()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_message
|
||||
ADD COLUMN message_tsv_gen tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', message)) STORED;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
time.time()
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_message_tsv_gen
|
||||
ON chat_message USING GIN (message_tsv_gen)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Drop old index and column
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_message_tsv;
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(text("COMMIT"))
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_message DROP COLUMN message_tsv;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Rename new column to old name
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_message RENAME COLUMN message_tsv_gen TO message_tsv;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# --- PART 2: chat_session table ---
|
||||
|
||||
# Step 1: Add nullable column (quick, minimal locking)
|
||||
time.time()
|
||||
connection.execute(
|
||||
text(
|
||||
"ALTER TABLE chat_session ADD COLUMN IF NOT EXISTS description_tsv tsvector"
|
||||
)
|
||||
)
|
||||
|
||||
# Step 2: Create function and trigger for new/updated rows - SPLIT INTO SEPARATE CALLS
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE OR REPLACE FUNCTION update_chat_session_tsv()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.description_tsv = to_tsvector('english', COALESCE(NEW.description, ''));
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Create trigger in a separate execute call
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE TRIGGER chat_session_tsv_trigger
|
||||
BEFORE INSERT OR UPDATE ON chat_session
|
||||
FOR EACH ROW EXECUTE FUNCTION update_chat_session_tsv()
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Step 3: Update existing rows in batches using Python
|
||||
time.time()
|
||||
|
||||
# Get the maximum ID to determine batch count
|
||||
# Cast id to text for MAX function since it's a UUID
|
||||
max_id_result = connection.execute(
|
||||
text("SELECT COALESCE(MAX(id::text), '0') FROM chat_session")
|
||||
).scalar()
|
||||
max_id_result if max_id_result is not None else "0"
|
||||
batch_size = 5000
|
||||
batches = 0
|
||||
|
||||
# Get all IDs ordered to process in batches
|
||||
rows = connection.execute(
|
||||
text("SELECT id FROM chat_session ORDER BY id")
|
||||
).fetchall()
|
||||
total_rows = len(rows)
|
||||
|
||||
# Process in batches
|
||||
for batch_num, batch_start in enumerate(range(0, total_rows, batch_size)):
|
||||
batch_end = min(batch_start + batch_size, total_rows)
|
||||
batch_ids = [row[0] for row in rows[batch_start:batch_end]]
|
||||
|
||||
if not batch_ids:
|
||||
continue
|
||||
|
||||
# Use IN clause instead of BETWEEN for UUIDs
|
||||
placeholders = ", ".join([f":id{i}" for i in range(len(batch_ids))])
|
||||
params = {f"id{i}": id_val for i, id_val in enumerate(batch_ids)}
|
||||
|
||||
# Execute update for this batch
|
||||
connection.execute(
|
||||
text(
|
||||
f"""
|
||||
UPDATE chat_session
|
||||
SET description_tsv = to_tsvector('english', COALESCE(description, ''))
|
||||
WHERE id IN ({placeholders})
|
||||
AND description_tsv IS NULL
|
||||
"""
|
||||
).bindparams(**params)
|
||||
)
|
||||
|
||||
# Commit each batch
|
||||
connection.execute(text("COMMIT"))
|
||||
# Start a new transaction
|
||||
connection.execute(text("BEGIN"))
|
||||
|
||||
batches += 1
|
||||
|
||||
# Final check for any remaining NULL values
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
UPDATE chat_session SET description_tsv = to_tsvector('english', COALESCE(description, ''))
|
||||
WHERE description_tsv IS NULL
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Create GIN index concurrently
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
time.time()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session USING GIN (description_tsv)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# After Final check for chat_session
|
||||
# First drop the trigger as it won't be needed anymore
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP TRIGGER IF EXISTS chat_session_tsv_trigger ON chat_session;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP FUNCTION IF EXISTS update_chat_session_tsv();
|
||||
"""
|
||||
)
|
||||
)
|
||||
# Add new generated column
|
||||
time.time()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_session
|
||||
ADD COLUMN description_tsv_gen tsvector
|
||||
GENERATED ALWAYS AS (to_tsvector('english', COALESCE(description, ''))) STORED;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Create new index on generated column
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
time.time()
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv_gen
|
||||
ON chat_session USING GIN (description_tsv_gen)
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Drop old index and column
|
||||
connection.execute(text("COMMIT"))
|
||||
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
DROP INDEX CONCURRENTLY IF EXISTS idx_chat_session_desc_tsv;
|
||||
"""
|
||||
)
|
||||
)
|
||||
connection.execute(text("COMMIT"))
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_session DROP COLUMN description_tsv;
|
||||
"""
|
||||
)
|
||||
)
|
||||
|
||||
# Rename new column to old name
|
||||
connection.execute(
|
||||
text(
|
||||
"""
|
||||
ALTER TABLE chat_session RENAME COLUMN description_tsv_gen TO description_tsv;
|
||||
"""
|
||||
)
|
||||
op.execute(
|
||||
"""
|
||||
CREATE INDEX CONCURRENTLY IF NOT EXISTS idx_chat_session_desc_tsv
|
||||
ON chat_session
|
||||
USING GIN (description_tsv)
|
||||
"""
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -15,7 +15,7 @@ from ee.onyx.server.enterprise_settings.api import (
|
||||
)
|
||||
from ee.onyx.server.manage.standard_answer import router as standard_answer_router
|
||||
from ee.onyx.server.middleware.tenant_tracking import add_tenant_id_middleware
|
||||
from ee.onyx.server.oauth.api import router as oauth_router
|
||||
from ee.onyx.server.oauth.api import router as ee_oauth_router
|
||||
from ee.onyx.server.query_and_chat.chat_backend import (
|
||||
router as chat_router,
|
||||
)
|
||||
@@ -128,7 +128,7 @@ def get_application() -> FastAPI:
|
||||
include_router_with_global_prefix_prepended(application, query_router)
|
||||
include_router_with_global_prefix_prepended(application, chat_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_answer_router)
|
||||
include_router_with_global_prefix_prepended(application, oauth_router)
|
||||
include_router_with_global_prefix_prepended(application, ee_oauth_router)
|
||||
|
||||
# Enterprise-only global settings
|
||||
include_router_with_global_prefix_prepended(
|
||||
|
||||
@@ -80,6 +80,7 @@ class ConfluenceCloudOAuth:
|
||||
"search:confluence%20"
|
||||
# granular scope
|
||||
"read:attachment:confluence%20" # possibly unneeded unless calling v2 attachments api
|
||||
"read:content-details:confluence%20" # for permission sync
|
||||
"offline_access"
|
||||
)
|
||||
|
||||
|
||||
@@ -48,4 +48,5 @@ def store_product_gating(tenant_id: str, application_status: ApplicationStatus)
|
||||
|
||||
def get_gated_tenants() -> set[str]:
|
||||
redis_client = get_redis_replica_client(tenant_id=ONYX_CLOUD_TENANT_ID)
|
||||
return cast(set[str], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
gated_tenants_bytes = cast(set[bytes], redis_client.smembers(GATED_TENANTS_KEY))
|
||||
return {tenant_id.decode("utf-8") for tenant_id in gated_tenants_bytes}
|
||||
|
||||
@@ -55,7 +55,11 @@ logger = logging.getLogger(__name__)
|
||||
async def get_or_provision_tenant(
|
||||
email: str, referral_source: str | None = None, request: Request | None = None
|
||||
) -> str:
|
||||
"""Get existing tenant ID for an email or create a new tenant if none exists."""
|
||||
"""
|
||||
Get existing tenant ID for an email or create a new tenant if none exists.
|
||||
This function should only be called after we have verified we want this user's tenant to exist.
|
||||
It returns the tenant ID associated with the email, creating a new tenant if necessary.
|
||||
"""
|
||||
if not MULTI_TENANT:
|
||||
return POSTGRES_DEFAULT_SCHEMA
|
||||
|
||||
|
||||
@@ -153,8 +153,9 @@ def generate_initial_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -179,8 +179,9 @@ def generate_validate_refined_answer(
|
||||
)
|
||||
for tool_response in yield_search_responses(
|
||||
query=question,
|
||||
reranked_sections=answer_generation_documents.streaming_documents,
|
||||
final_context_sections=answer_generation_documents.context_documents,
|
||||
get_retrieved_sections=lambda: answer_generation_documents.context_documents,
|
||||
get_reranked_sections=lambda: answer_generation_documents.streaming_documents,
|
||||
get_final_context_sections=lambda: answer_generation_documents.context_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -13,7 +13,6 @@ from onyx.chat.models import StreamStopInfo
|
||||
from onyx.chat.models import StreamStopReason
|
||||
from onyx.chat.models import StreamType
|
||||
from onyx.chat.models import SubQuestionPiece
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.utils.logger import setup_logger
|
||||
|
||||
@@ -144,8 +143,6 @@ def get_query_info(results: list[QueryRetrievalResult]) -> SearchQueryInfo:
|
||||
if result.query_info is not None:
|
||||
query_info = result.query_info
|
||||
break
|
||||
return query_info or SearchQueryInfo(
|
||||
predicted_search=None,
|
||||
final_filters=IndexFilters(access_control_list=None),
|
||||
recency_bias_multiplier=1.0,
|
||||
)
|
||||
|
||||
assert query_info is not None, "must have query info"
|
||||
return query_info
|
||||
|
||||
@@ -56,8 +56,9 @@ def format_results(
|
||||
relevance_list = relevance_from_docs(reranked_documents)
|
||||
for tool_response in yield_search_responses(
|
||||
query=state.question,
|
||||
reranked_sections=state.retrieved_documents,
|
||||
final_context_sections=reranked_documents,
|
||||
get_retrieved_sections=lambda: reranked_documents,
|
||||
get_reranked_sections=lambda: state.retrieved_documents,
|
||||
get_final_context_sections=lambda: reranked_documents,
|
||||
search_query_info=query_info,
|
||||
get_section_relevance=lambda: relevance_list,
|
||||
search_tool=graph_config.tooling.search_tool,
|
||||
|
||||
@@ -91,7 +91,7 @@ def retrieve_documents(
|
||||
retrieved_docs = retrieved_docs[:AGENT_MAX_QUERY_RETRIEVAL_RESULTS]
|
||||
|
||||
if AGENT_RETRIEVAL_STATS:
|
||||
pre_rerank_docs = callback_container[0]
|
||||
pre_rerank_docs = callback_container[0] if callback_container else []
|
||||
fit_scores = get_fit_scores(
|
||||
pre_rerank_docs,
|
||||
retrieved_docs,
|
||||
|
||||
@@ -44,7 +44,9 @@ def call_tool(
|
||||
tool = tool_choice.tool
|
||||
tool_args = tool_choice.tool_args
|
||||
tool_id = tool_choice.id
|
||||
tool_runner = ToolRunner(tool, tool_args)
|
||||
tool_runner = ToolRunner(
|
||||
tool, tool_args, override_kwargs=tool_choice.search_tool_override_kwargs
|
||||
)
|
||||
tool_kickoff = tool_runner.kickoff()
|
||||
|
||||
emit_packet(tool_kickoff, writer)
|
||||
|
||||
@@ -15,8 +15,17 @@ from onyx.chat.tool_handling.tool_response_handler import get_tool_by_name
|
||||
from onyx.chat.tool_handling.tool_response_handler import (
|
||||
get_tool_call_for_non_tool_calling_llm_impl,
|
||||
)
|
||||
from onyx.context.search.preprocessing.preprocessing import query_analysis
|
||||
from onyx.context.search.retrieval.search_runner import get_query_embedding
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import TimeoutThread
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
from onyx.utils.timing import log_function_time
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -25,6 +34,7 @@ logger = setup_logger()
|
||||
# and a function that handles extracting the necessary fields
|
||||
# from the state and config
|
||||
# TODO: fan-out to multiple tool call nodes? Make this configurable?
|
||||
@log_function_time(print_only=True)
|
||||
def choose_tool(
|
||||
state: ToolChoiceState,
|
||||
config: RunnableConfig,
|
||||
@@ -37,6 +47,31 @@ def choose_tool(
|
||||
should_stream_answer = state.should_stream_answer
|
||||
|
||||
agent_config = cast(GraphConfig, config["metadata"]["config"])
|
||||
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
embedding_thread: TimeoutThread[Embedding] | None = None
|
||||
keyword_thread: TimeoutThread[tuple[bool, list[str]]] | None = None
|
||||
override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
if (
|
||||
not agent_config.behavior.use_agentic_search
|
||||
and agent_config.tooling.search_tool is not None
|
||||
and (
|
||||
not force_use_tool.force_use or force_use_tool.tool_name == SearchTool.name
|
||||
)
|
||||
):
|
||||
override_kwargs = SearchToolOverrideKwargs()
|
||||
# Run in a background thread to avoid blocking the main thread
|
||||
embedding_thread = run_in_background(
|
||||
get_query_embedding,
|
||||
agent_config.inputs.search_request.query,
|
||||
agent_config.persistence.db_session,
|
||||
)
|
||||
keyword_thread = run_in_background(
|
||||
query_analysis,
|
||||
agent_config.inputs.search_request.query,
|
||||
)
|
||||
|
||||
using_tool_calling_llm = agent_config.tooling.using_tool_calling_llm
|
||||
prompt_builder = state.prompt_snapshot or agent_config.inputs.prompt_builder
|
||||
|
||||
@@ -47,7 +82,6 @@ def choose_tool(
|
||||
tools = [
|
||||
tool for tool in (agent_config.tooling.tools or []) if tool.name in state.tools
|
||||
]
|
||||
force_use_tool = agent_config.tooling.force_use_tool
|
||||
|
||||
tool, tool_args = None, None
|
||||
if force_use_tool.force_use and force_use_tool.args is not None:
|
||||
@@ -71,11 +105,22 @@ def choose_tool(
|
||||
# If we have a tool and tool args, we are ready to request a tool call.
|
||||
# This only happens if the tool call was forced or we are using a non-tool calling LLM.
|
||||
if tool and tool_args:
|
||||
if embedding_thread and tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=tool,
|
||||
tool_args=tool_args,
|
||||
id=str(uuid4()),
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -153,10 +198,22 @@ def choose_tool(
|
||||
logger.debug(f"Selected tool: {selected_tool.name}")
|
||||
logger.debug(f"Selected tool call request: {selected_tool_call_request}")
|
||||
|
||||
if embedding_thread and selected_tool.name == SearchTool._NAME:
|
||||
# Wait for the embedding thread to finish
|
||||
embedding = wait_on_background(embedding_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_query_embedding = embedding
|
||||
if keyword_thread and selected_tool.name == SearchTool._NAME:
|
||||
is_keyword, keywords = wait_on_background(keyword_thread)
|
||||
assert override_kwargs is not None, "must have override kwargs"
|
||||
override_kwargs.precomputed_is_keyword = is_keyword
|
||||
override_kwargs.precomputed_keywords = keywords
|
||||
|
||||
return ToolChoiceUpdate(
|
||||
tool_choice=ToolChoice(
|
||||
tool=selected_tool,
|
||||
tool_args=selected_tool_call_request["args"],
|
||||
id=selected_tool_call_request["id"],
|
||||
search_tool_override_kwargs=override_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -9,18 +9,23 @@ from onyx.agents.agent_search.basic.states import BasicState
|
||||
from onyx.agents.agent_search.basic.utils import process_llm_stream
|
||||
from onyx.agents.agent_search.models import GraphConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.tools.tool_implementations.search.search_tool import (
|
||||
SEARCH_DOC_CONTENT_ID,
|
||||
SEARCH_RESPONSE_SUMMARY_ID,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
context_from_inference_section,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
FINAL_CONTEXT_DOCUMENTS_ID,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def basic_use_tool_response(
|
||||
state: BasicState, config: RunnableConfig, writer: StreamWriter = lambda _: None
|
||||
) -> BasicOutput:
|
||||
@@ -50,11 +55,13 @@ def basic_use_tool_response(
|
||||
for yield_item in tool_call_responses:
|
||||
if yield_item.id == FINAL_CONTEXT_DOCUMENTS_ID:
|
||||
final_search_results = cast(list[LlmDoc], yield_item.response)
|
||||
elif yield_item.id == SEARCH_DOC_CONTENT_ID:
|
||||
search_contexts = cast(OnyxContexts, yield_item.response).contexts
|
||||
for doc in search_contexts:
|
||||
if doc.document_id not in initial_search_results:
|
||||
initial_search_results.append(doc)
|
||||
elif yield_item.id == SEARCH_RESPONSE_SUMMARY_ID:
|
||||
search_response_summary = cast(SearchResponseSummary, yield_item.response)
|
||||
for section in search_response_summary.top_sections:
|
||||
if section.center_chunk.document_id not in initial_search_results:
|
||||
initial_search_results.append(
|
||||
context_from_inference_section(section)
|
||||
)
|
||||
|
||||
new_tool_call_chunk = AIMessageChunk(content="")
|
||||
if not agent_config.behavior.skip_gen_ai_answer_generation:
|
||||
|
||||
@@ -2,6 +2,7 @@ from pydantic import BaseModel
|
||||
|
||||
from onyx.chat.prompt_builder.answer_prompt_builder import PromptSnapshot
|
||||
from onyx.tools.message import ToolCallSummary
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolCallFinalResult
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.tools.models import ToolResponse
|
||||
@@ -35,6 +36,7 @@ class ToolChoice(BaseModel):
|
||||
tool: Tool
|
||||
tool_args: dict
|
||||
id: str | None
|
||||
search_tool_override_kwargs: SearchToolOverrideKwargs | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -13,6 +13,11 @@ AGENT_NEGATIVE_VALUE_STR = "no"
|
||||
AGENT_ANSWER_SEPARATOR = "Answer:"
|
||||
|
||||
|
||||
EMBEDDING_KEY = "embedding"
|
||||
IS_KEYWORD_KEY = "is_keyword"
|
||||
KEYWORDS_KEY = "keywords"
|
||||
|
||||
|
||||
class AgentLLMErrorType(str, Enum):
|
||||
TIMEOUT = "timeout"
|
||||
RATE_LIMIT = "rate_limit"
|
||||
|
||||
@@ -587,14 +587,20 @@ class UserManager(UUIDIDMixin, BaseUserManager[User, uuid.UUID]):
|
||||
) -> Optional[User]:
|
||||
email = credentials.username
|
||||
|
||||
# Get tenant_id from mapping table
|
||||
tenant_id = await fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_or_provision_tenant",
|
||||
async_return_default_schema,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
tenant_id: str | None = None
|
||||
try:
|
||||
tenant_id = fetch_ee_implementation_or_noop(
|
||||
"onyx.server.tenants.provisioning",
|
||||
"get_tenant_id_for_email",
|
||||
None,
|
||||
)(
|
||||
email=email,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"User attempted to login with invalid credentials: {str(e)}"
|
||||
)
|
||||
|
||||
if not tenant_id:
|
||||
# User not found in mapping
|
||||
self.password_helper.hash(credentials.password)
|
||||
|
||||
@@ -15,6 +15,8 @@ from onyx.chat.stream_processing.answer_response_handler import (
|
||||
from onyx.chat.tool_handling.tool_response_handler import ToolResponseHandler
|
||||
|
||||
|
||||
# This is Legacy code that is not used anymore.
|
||||
# It is kept here for reference.
|
||||
class LLMResponseHandlerManager:
|
||||
"""
|
||||
This class is responsible for postprocessing the LLM response stream.
|
||||
|
||||
@@ -90,97 +90,97 @@ class CitationProcessor:
|
||||
next(group for group in citation.groups() if group is not None)
|
||||
)
|
||||
|
||||
if 1 <= numerical_value <= self.max_citation_num:
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
if not (1 <= numerical_value <= self.max_citation_num):
|
||||
continue
|
||||
|
||||
context_llm_doc = self.context_docs[numerical_value - 1]
|
||||
final_citation_num = self.final_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
citation_order_idx = self.citation_order.index(final_citation_num) + 1
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
|
||||
if final_citation_num not in self.citation_order:
|
||||
self.citation_order.append(final_citation_num)
|
||||
|
||||
citation_order_idx = (
|
||||
self.citation_order.index(final_citation_num) + 1
|
||||
else:
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# get the value that was displayed to user, should always
|
||||
# be in the display_doc_order_dict. But check anyways
|
||||
if context_llm_doc.document_id in self.display_order_mapping:
|
||||
displayed_citation_num = self.display_order_mapping[
|
||||
context_llm_doc.document_id
|
||||
]
|
||||
else:
|
||||
displayed_citation_num = final_citation_num
|
||||
logger.warning(
|
||||
f"Doc {context_llm_doc.document_id} not in display_doc_order_dict. Used LLM citation number instead."
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
# Skip consecutive citations of the same work
|
||||
if final_citation_num in self.current_citations:
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
real_start = length_to_add + start
|
||||
diff = end - start
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: length_to_add + start]
|
||||
+ self.curr_segment[real_start + diff :]
|
||||
)
|
||||
length_to_add -= diff
|
||||
continue
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
# Handle edge case where LLM outputs citation itself
|
||||
if self.curr_segment.startswith("[["):
|
||||
match = re.match(r"\[\[(\d+)\]\]", self.curr_segment)
|
||||
if match:
|
||||
try:
|
||||
doc_id = int(match.group(1))
|
||||
context_llm_doc = self.context_docs[doc_id - 1]
|
||||
yield CitationInfo(
|
||||
# citation_num is now the number post initial ranking, i.e. as displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(
|
||||
f"Manual LLM citation didn't properly cite documents {e}"
|
||||
)
|
||||
else:
|
||||
logger.warning(
|
||||
"Manual LLM citation wasn't able to close brackets"
|
||||
)
|
||||
continue
|
||||
|
||||
link = context_llm_doc.link
|
||||
|
||||
self.past_cite_count = len(self.llm_out)
|
||||
self.current_citations.append(final_citation_num)
|
||||
|
||||
if citation_order_idx not in self.cited_inds:
|
||||
self.cited_inds.add(citation_order_idx)
|
||||
yield CitationInfo(
|
||||
# citation number is now the one that was displayed to user
|
||||
citation_num=displayed_citation_num,
|
||||
document_id=context_llm_doc.document_id,
|
||||
)
|
||||
|
||||
start, end = citation.span()
|
||||
if link:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]({link})" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
else:
|
||||
prev_length = len(self.curr_segment)
|
||||
self.curr_segment = (
|
||||
self.curr_segment[: start + length_to_add]
|
||||
+ f"[[{displayed_citation_num}]]()" # use the value that was displayed to user
|
||||
+ self.curr_segment[end + length_to_add :]
|
||||
)
|
||||
length_to_add += len(self.curr_segment) - prev_length
|
||||
|
||||
last_citation_end = end + length_to_add
|
||||
|
||||
if last_citation_end > 0:
|
||||
result += self.curr_segment[:last_citation_end]
|
||||
|
||||
@@ -240,7 +240,7 @@ class ConfluenceConnector(
|
||||
# Extract basic page information
|
||||
page_id = page["id"]
|
||||
page_title = page["title"]
|
||||
page_url = f"{self.wiki_base}/wiki{page['_links']['webui']}"
|
||||
page_url = f"{self.wiki_base}{page['_links']['webui']}"
|
||||
|
||||
# Get the page content
|
||||
page_content = extract_text_from_confluence_html(
|
||||
@@ -305,7 +305,7 @@ class ConfluenceConnector(
|
||||
|
||||
# Create the document
|
||||
return Document(
|
||||
id=build_confluence_document_id(self.wiki_base, page_id, self.is_cloud),
|
||||
id=build_confluence_document_id(self.wiki_base, page["_links"]["webui"], self.is_cloud),
|
||||
sections=sections,
|
||||
source=DocumentSource.CONFLUENCE,
|
||||
semantic_identifier=page_title,
|
||||
@@ -376,7 +376,7 @@ class ConfluenceConnector(
|
||||
content_text, file_storage_name = response
|
||||
|
||||
object_url = build_confluence_document_id(
|
||||
self.wiki_base, page["_links"]["webui"], self.is_cloud
|
||||
self.wiki_base, attachment["_links"]["webui"], self.is_cloud
|
||||
)
|
||||
|
||||
if content_text:
|
||||
|
||||
@@ -124,14 +124,14 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
def __init__(
|
||||
self,
|
||||
repo_owner: str,
|
||||
repo_name: str | None = None,
|
||||
repositories: str | None = None,
|
||||
batch_size: int = INDEX_BATCH_SIZE,
|
||||
state_filter: str = "all",
|
||||
include_prs: bool = True,
|
||||
include_issues: bool = False,
|
||||
) -> None:
|
||||
self.repo_owner = repo_owner
|
||||
self.repo_name = repo_name
|
||||
self.repositories = repositories
|
||||
self.batch_size = batch_size
|
||||
self.state_filter = state_filter
|
||||
self.include_prs = include_prs
|
||||
@@ -157,11 +157,42 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repo_name}")
|
||||
return github_client.get_repo(f"{self.repo_owner}/{self.repositories}")
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repo(github_client, attempt_num + 1)
|
||||
|
||||
def _get_github_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
"""Get specific repositories based on comma-separated repo_name string."""
|
||||
if attempt_num > _MAX_NUM_RATE_LIMIT_RETRIES:
|
||||
raise RuntimeError(
|
||||
"Re-tried fetching repos too many times. Something is going wrong with fetching objects from Github"
|
||||
)
|
||||
|
||||
try:
|
||||
repos = []
|
||||
# Split repo_name by comma and strip whitespace
|
||||
repo_names = [
|
||||
name.strip() for name in (cast(str, self.repositories)).split(",")
|
||||
]
|
||||
|
||||
for repo_name in repo_names:
|
||||
if repo_name: # Skip empty strings
|
||||
try:
|
||||
repo = github_client.get_repo(f"{self.repo_owner}/{repo_name}")
|
||||
repos.append(repo)
|
||||
except GithubException as e:
|
||||
logger.warning(
|
||||
f"Could not fetch repo {self.repo_owner}/{repo_name}: {e}"
|
||||
)
|
||||
|
||||
return repos
|
||||
except RateLimitExceededException:
|
||||
_sleep_after_rate_limit_exception(github_client)
|
||||
return self._get_github_repos(github_client, attempt_num + 1)
|
||||
|
||||
def _get_all_repos(
|
||||
self, github_client: Github, attempt_num: int = 0
|
||||
) -> list[Repository.Repository]:
|
||||
@@ -189,11 +220,17 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
if self.github_client is None:
|
||||
raise ConnectorMissingCredentialError("GitHub")
|
||||
|
||||
repos = (
|
||||
[self._get_github_repo(self.github_client)]
|
||||
if self.repo_name
|
||||
else self._get_all_repos(self.github_client)
|
||||
)
|
||||
repos = []
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repos = self._get_github_repos(self.github_client)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
repos = [self._get_github_repo(self.github_client)]
|
||||
else:
|
||||
# All repositories
|
||||
repos = self._get_all_repos(self.github_client)
|
||||
|
||||
for repo in repos:
|
||||
if self.include_prs:
|
||||
@@ -268,11 +305,48 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
)
|
||||
|
||||
try:
|
||||
if self.repo_name:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
# Multiple repositories specified
|
||||
repo_names = [name.strip() for name in self.repositories.split(",")]
|
||||
if not repo_names:
|
||||
raise ConnectorValidationError(
|
||||
"Invalid connector settings: No valid repository names provided."
|
||||
)
|
||||
|
||||
# Validate at least one repository exists and is accessible
|
||||
valid_repos = False
|
||||
validation_errors = []
|
||||
|
||||
for repo_name in repo_names:
|
||||
if not repo_name:
|
||||
continue
|
||||
|
||||
try:
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{repo_name}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
valid_repos = True
|
||||
# If at least one repo is valid, we can proceed
|
||||
break
|
||||
except GithubException as e:
|
||||
validation_errors.append(
|
||||
f"Repository '{repo_name}': {e.data.get('message', str(e))}"
|
||||
)
|
||||
|
||||
if not valid_repos:
|
||||
error_msg = (
|
||||
"None of the specified repositories could be accessed: "
|
||||
)
|
||||
error_msg += ", ".join(validation_errors)
|
||||
raise ConnectorValidationError(error_msg)
|
||||
else:
|
||||
# Single repository (backward compatibility)
|
||||
test_repo = self.github_client.get_repo(
|
||||
f"{self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
test_repo.get_contents("")
|
||||
else:
|
||||
# Try to get organization first
|
||||
try:
|
||||
@@ -298,10 +372,15 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
"Your GitHub token does not have sufficient permissions for this repository (HTTP 403)."
|
||||
)
|
||||
elif e.status == 404:
|
||||
if self.repo_name:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repo_name}"
|
||||
)
|
||||
if self.repositories:
|
||||
if "," in self.repositories:
|
||||
raise ConnectorValidationError(
|
||||
f"None of the specified GitHub repositories could be found for owner: {self.repo_owner}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub repository not found with name: {self.repo_owner}/{self.repositories}"
|
||||
)
|
||||
else:
|
||||
raise ConnectorValidationError(
|
||||
f"GitHub user or organization not found: {self.repo_owner}"
|
||||
@@ -310,6 +389,7 @@ class GithubConnector(LoadConnector, PollConnector):
|
||||
raise ConnectorValidationError(
|
||||
f"Unexpected GitHub error (status={e.status}): {e.data}"
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
raise Exception(
|
||||
f"Unexpected error during GitHub settings validation: {exc}"
|
||||
@@ -321,7 +401,7 @@ if __name__ == "__main__":
|
||||
|
||||
connector = GithubConnector(
|
||||
repo_owner=os.environ["REPO_OWNER"],
|
||||
repo_name=os.environ["REPO_NAME"],
|
||||
repositories=os.environ["REPOSITORIES"],
|
||||
)
|
||||
connector.load_credentials(
|
||||
{"github_access_token": os.environ["GITHUB_ACCESS_TOKEN"]}
|
||||
|
||||
@@ -316,7 +316,9 @@ class GoogleDriveConnector(
|
||||
# validate that the user has access to the drive APIs by performing a simple
|
||||
# request and checking for a 401
|
||||
try:
|
||||
retry_builder()(get_root_folder_id)(drive_service)
|
||||
# default is ~17mins of retries, don't do that here for cases so we don't
|
||||
# waste 17mins everytime we run into a user without access to drive APIs
|
||||
retry_builder(tries=3, delay=1)(get_root_folder_id)(drive_service)
|
||||
except HttpError as e:
|
||||
if e.status_code == 401:
|
||||
# fail gracefully, let the other impersonations continue
|
||||
|
||||
@@ -16,7 +16,7 @@ from onyx.db.models import SearchSettings
|
||||
from onyx.indexing.models import BaseChunk
|
||||
from onyx.indexing.models import IndexingSetting
|
||||
from shared_configs.enums import RerankerProvider
|
||||
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
MAX_METRICS_CONTENT = (
|
||||
200 # Just need enough characters to identify where in the doc the chunk is
|
||||
@@ -151,6 +151,10 @@ class SearchRequest(ChunkContext):
|
||||
evaluation_type: LLMEvaluationType = LLMEvaluationType.UNSPECIFIED
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
|
||||
class SearchQuery(ChunkContext):
|
||||
"Processed Request that is directly passed to the SearchPipeline"
|
||||
@@ -175,6 +179,8 @@ class SearchQuery(ChunkContext):
|
||||
offset: int = 0
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
|
||||
|
||||
class RetrievalDetails(ChunkContext):
|
||||
# Use LLM to determine whether to do a retrieval or only rely on existing history
|
||||
|
||||
@@ -331,6 +331,14 @@ class SearchPipeline:
|
||||
self._retrieved_sections = expanded_inference_sections
|
||||
return expanded_inference_sections
|
||||
|
||||
@property
|
||||
def retrieved_sections(self) -> list[InferenceSection]:
|
||||
if self._retrieved_sections is not None:
|
||||
return self._retrieved_sections
|
||||
|
||||
self._retrieved_sections = self._get_sections()
|
||||
return self._retrieved_sections
|
||||
|
||||
@property
|
||||
def reranked_sections(self) -> list[InferenceSection]:
|
||||
"""Reranking is always done at the chunk level since section merging could create arbitrarily
|
||||
@@ -343,7 +351,7 @@ class SearchPipeline:
|
||||
if self._reranked_sections is not None:
|
||||
return self._reranked_sections
|
||||
|
||||
retrieved_sections = self._get_sections()
|
||||
retrieved_sections = self.retrieved_sections
|
||||
if self.retrieved_sections_callback is not None:
|
||||
self.retrieved_sections_callback(retrieved_sections)
|
||||
|
||||
|
||||
@@ -117,8 +117,12 @@ def retrieval_preprocessing(
|
||||
else None
|
||||
)
|
||||
|
||||
# Sometimes this is pre-computed in parallel with other heavy tasks to improve
|
||||
# latency, and in that case we don't need to run the model again
|
||||
run_query_analysis = (
|
||||
None if skip_query_analysis else FunctionCall(query_analysis, (query,), {})
|
||||
None
|
||||
if (skip_query_analysis or search_request.precomputed_is_keyword is not None)
|
||||
else FunctionCall(query_analysis, (query,), {})
|
||||
)
|
||||
|
||||
functions_to_run = [
|
||||
@@ -143,11 +147,12 @@ def retrieval_preprocessing(
|
||||
|
||||
# The extracted keywords right now are not very reliable, not using for now
|
||||
# Can maybe use for highlighting
|
||||
is_keyword, extracted_keywords = (
|
||||
parallel_results[run_query_analysis.result_id]
|
||||
if run_query_analysis
|
||||
else (False, None)
|
||||
)
|
||||
is_keyword, _extracted_keywords = False, None
|
||||
if search_request.precomputed_is_keyword is not None:
|
||||
is_keyword = search_request.precomputed_is_keyword
|
||||
_extracted_keywords = search_request.precomputed_keywords
|
||||
elif run_query_analysis:
|
||||
is_keyword, _extracted_keywords = parallel_results[run_query_analysis.result_id]
|
||||
|
||||
all_query_terms = query.split()
|
||||
processed_keywords = (
|
||||
@@ -247,4 +252,5 @@ def retrieval_preprocessing(
|
||||
chunks_above=chunks_above,
|
||||
chunks_below=chunks_below,
|
||||
full_doc=search_request.full_doc,
|
||||
precomputed_query_embedding=search_request.precomputed_query_embedding,
|
||||
)
|
||||
|
||||
@@ -31,7 +31,7 @@ from onyx.utils.timing import log_function_time
|
||||
from shared_configs.configs import MODEL_SERVER_HOST
|
||||
from shared_configs.configs import MODEL_SERVER_PORT
|
||||
from shared_configs.enums import EmbedTextType
|
||||
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
@@ -109,6 +109,20 @@ def combine_retrieval_results(
|
||||
return sorted_chunks
|
||||
|
||||
|
||||
def get_query_embedding(query: str, db_session: Session) -> Embedding:
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query], text_type=EmbedTextType.QUERY)[0]
|
||||
return query_embedding
|
||||
|
||||
|
||||
@log_function_time(print_only=True)
|
||||
def doc_index_retrieval(
|
||||
query: SearchQuery,
|
||||
@@ -121,17 +135,10 @@ def doc_index_retrieval(
|
||||
from the large chunks to the referenced chunks,
|
||||
dedupes the chunks, and cleans the chunks.
|
||||
"""
|
||||
search_settings = get_current_search_settings(db_session)
|
||||
|
||||
model = EmbeddingModel.from_db_model(
|
||||
search_settings=search_settings,
|
||||
# The below are globally set, this flow always uses the indexing one
|
||||
server_host=MODEL_SERVER_HOST,
|
||||
server_port=MODEL_SERVER_PORT,
|
||||
query_embedding = query.precomputed_query_embedding or get_query_embedding(
|
||||
query.query, db_session
|
||||
)
|
||||
|
||||
query_embedding = model.encode([query.query], text_type=EmbedTextType.QUERY)[0]
|
||||
|
||||
top_chunks = document_index.hybrid_retrieval(
|
||||
query=query.query,
|
||||
query_embedding=query_embedding,
|
||||
@@ -250,6 +257,9 @@ def retrieve_chunks(
|
||||
simplified_queries.add(simplified_rephrase)
|
||||
|
||||
q_copy = query.copy(update={"query": rephrase}, deep=True)
|
||||
q_copy.precomputed_query_embedding = (
|
||||
None # need to recompute for each rephrase
|
||||
)
|
||||
run_queries.append(
|
||||
(
|
||||
doc_index_retrieval,
|
||||
|
||||
@@ -464,12 +464,29 @@ def index_doc_batch(
|
||||
),
|
||||
)
|
||||
|
||||
successful_doc_ids = {record.document_id for record in insertion_records}
|
||||
if successful_doc_ids != set(updatable_ids):
|
||||
all_returned_doc_ids = (
|
||||
{record.document_id for record in insertion_records}
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in vector_db_write_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
.union(
|
||||
{
|
||||
record.failed_document.document_id
|
||||
for record in embedding_failures
|
||||
if record.failed_document
|
||||
}
|
||||
)
|
||||
)
|
||||
if all_returned_doc_ids != set(updatable_ids):
|
||||
raise RuntimeError(
|
||||
f"Some documents were not successfully indexed. "
|
||||
f"Updatable IDs: {updatable_ids}, "
|
||||
f"Successful IDs: {successful_doc_ids}"
|
||||
f"Returned IDs: {all_returned_doc_ids}. "
|
||||
"This should never happen."
|
||||
)
|
||||
|
||||
last_modified_ids = []
|
||||
|
||||
@@ -51,6 +51,7 @@ from onyx.server.documents.cc_pair import router as cc_pair_router
|
||||
from onyx.server.documents.connector import router as connector_router
|
||||
from onyx.server.documents.credential import router as credential_router
|
||||
from onyx.server.documents.document import router as document_router
|
||||
from onyx.server.documents.standard_oauth import router as standard_oauth_router
|
||||
from onyx.server.features.document_set.api import router as document_set_router
|
||||
from onyx.server.features.folder.api import router as folder_router
|
||||
from onyx.server.features.input_prompt.api import (
|
||||
@@ -322,6 +323,7 @@ def get_application() -> FastAPI:
|
||||
)
|
||||
include_router_with_global_prefix_prepended(application, long_term_logs_router)
|
||||
include_router_with_global_prefix_prepended(application, api_key_router)
|
||||
include_router_with_global_prefix_prepended(application, standard_oauth_router)
|
||||
|
||||
if AUTH_TYPE == AuthType.DISABLED:
|
||||
# Server logs this during auth setup verification step
|
||||
|
||||
@@ -9,6 +9,7 @@ from sqlalchemy.orm import Session
|
||||
from onyx.context.search.enums import SearchType
|
||||
from onyx.context.search.models import IndexFilters
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from shared_configs.model_server_models import Embedding
|
||||
|
||||
|
||||
class ToolResponse(BaseModel):
|
||||
@@ -60,11 +61,15 @@ class SearchQueryInfo(BaseModel):
|
||||
recency_bias_multiplier: float
|
||||
|
||||
|
||||
# None indicates that the default value should be used
|
||||
class SearchToolOverrideKwargs(BaseModel):
|
||||
force_no_rerank: bool
|
||||
alternate_db_session: Session | None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None
|
||||
skip_query_analysis: bool
|
||||
force_no_rerank: bool | None = None
|
||||
alternate_db_session: Session | None = None
|
||||
retrieved_sections_callback: Callable[[list[InferenceSection]], None] | None = None
|
||||
skip_query_analysis: bool | None = None
|
||||
precomputed_query_embedding: Embedding | None = None
|
||||
precomputed_is_keyword: bool | None = None
|
||||
precomputed_keywords: list[str] | None = None
|
||||
|
||||
class Config:
|
||||
arbitrary_types_allowed = True
|
||||
|
||||
@@ -3,6 +3,7 @@ from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -11,7 +12,6 @@ from onyx.chat.models import AnswerStyleConfig
|
||||
from onyx.chat.models import ContextualPruningConfig
|
||||
from onyx.chat.models import DocumentPruningConfig
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContext
|
||||
from onyx.chat.models import OnyxContexts
|
||||
from onyx.chat.models import PromptConfig
|
||||
from onyx.chat.models import SectionRelevancePiece
|
||||
@@ -42,6 +42,9 @@ from onyx.tools.models import SearchQueryInfo
|
||||
from onyx.tools.models import SearchToolOverrideKwargs
|
||||
from onyx.tools.models import ToolResponse
|
||||
from onyx.tools.tool import Tool
|
||||
from onyx.tools.tool_implementations.search.search_utils import (
|
||||
context_from_inference_section,
|
||||
)
|
||||
from onyx.tools.tool_implementations.search.search_utils import llm_doc_to_dict
|
||||
from onyx.tools.tool_implementations.search_like_tool_utils import (
|
||||
build_next_prompt_for_search_like_tool,
|
||||
@@ -281,16 +284,23 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
|
||||
) -> Generator[ToolResponse, None, None]:
|
||||
query = cast(str, llm_kwargs[QUERY_FIELD])
|
||||
precomputed_query_embedding = None
|
||||
precomputed_is_keyword = None
|
||||
precomputed_keywords = None
|
||||
force_no_rerank = False
|
||||
alternate_db_session = None
|
||||
retrieved_sections_callback = None
|
||||
skip_query_analysis = False
|
||||
if override_kwargs:
|
||||
force_no_rerank = override_kwargs.force_no_rerank
|
||||
force_no_rerank = use_alt_not_None(override_kwargs.force_no_rerank, False)
|
||||
alternate_db_session = override_kwargs.alternate_db_session
|
||||
retrieved_sections_callback = override_kwargs.retrieved_sections_callback
|
||||
skip_query_analysis = override_kwargs.skip_query_analysis
|
||||
|
||||
skip_query_analysis = use_alt_not_None(
|
||||
override_kwargs.skip_query_analysis, False
|
||||
)
|
||||
precomputed_query_embedding = override_kwargs.precomputed_query_embedding
|
||||
precomputed_is_keyword = override_kwargs.precomputed_is_keyword
|
||||
precomputed_keywords = override_kwargs.precomputed_keywords
|
||||
if self.selected_sections:
|
||||
yield from self._build_response_for_specified_sections(query)
|
||||
return
|
||||
@@ -327,6 +337,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
if self.retrieval_options
|
||||
else None
|
||||
),
|
||||
precomputed_query_embedding=precomputed_query_embedding,
|
||||
precomputed_is_keyword=precomputed_is_keyword,
|
||||
precomputed_keywords=precomputed_keywords,
|
||||
),
|
||||
user=self.user,
|
||||
llm=self.llm,
|
||||
@@ -345,8 +358,9 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
)
|
||||
yield from yield_search_responses(
|
||||
query,
|
||||
search_pipeline.reranked_sections,
|
||||
search_pipeline.final_context_sections,
|
||||
lambda: search_pipeline.retrieved_sections,
|
||||
lambda: search_pipeline.reranked_sections,
|
||||
lambda: search_pipeline.final_context_sections,
|
||||
search_query_info,
|
||||
lambda: search_pipeline.section_relevance,
|
||||
self,
|
||||
@@ -383,10 +397,16 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
|
||||
# SearchTool passed in to allow for access to SearchTool properties.
|
||||
# We can't just call SearchTool methods in the graph because we're operating on
|
||||
# the retrieved docs (reranking, deduping, etc.) after the SearchTool has run.
|
||||
#
|
||||
# The various inference sections are passed in as functions to allow for lazy
|
||||
# evaluation. The SearchPipeline object properties that they correspond to are
|
||||
# actually functions defined with @property decorators, and passing them into
|
||||
# this function causes them to get evaluated immediately which is undesirable.
|
||||
def yield_search_responses(
|
||||
query: str,
|
||||
reranked_sections: list[InferenceSection],
|
||||
final_context_sections: list[InferenceSection],
|
||||
get_retrieved_sections: Callable[[], list[InferenceSection]],
|
||||
get_reranked_sections: Callable[[], list[InferenceSection]],
|
||||
get_final_context_sections: Callable[[], list[InferenceSection]],
|
||||
search_query_info: SearchQueryInfo,
|
||||
get_section_relevance: Callable[[], list[SectionRelevancePiece] | None],
|
||||
search_tool: SearchTool,
|
||||
@@ -395,7 +415,7 @@ def yield_search_responses(
|
||||
id=SEARCH_RESPONSE_SUMMARY_ID,
|
||||
response=SearchResponseSummary(
|
||||
rephrased_query=query,
|
||||
top_sections=final_context_sections,
|
||||
top_sections=get_retrieved_sections(),
|
||||
predicted_flow=QueryFlow.QUESTION_ANSWER,
|
||||
predicted_search=search_query_info.predicted_search,
|
||||
final_filters=search_query_info.final_filters,
|
||||
@@ -407,13 +427,8 @@ def yield_search_responses(
|
||||
id=SEARCH_DOC_CONTENT_ID,
|
||||
response=OnyxContexts(
|
||||
contexts=[
|
||||
OnyxContext(
|
||||
content=section.combined_content,
|
||||
document_id=section.center_chunk.document_id,
|
||||
semantic_identifier=section.center_chunk.semantic_identifier,
|
||||
blurb=section.center_chunk.blurb,
|
||||
)
|
||||
for section in reranked_sections
|
||||
context_from_inference_section(section)
|
||||
for section in get_reranked_sections()
|
||||
]
|
||||
),
|
||||
)
|
||||
@@ -424,6 +439,7 @@ def yield_search_responses(
|
||||
response=section_relevance,
|
||||
)
|
||||
|
||||
final_context_sections = get_final_context_sections()
|
||||
pruned_sections = prune_sections(
|
||||
sections=final_context_sections,
|
||||
section_relevance_list=section_relevance_list_impl(
|
||||
@@ -438,3 +454,10 @@ def yield_search_responses(
|
||||
llm_docs = [llm_doc_from_inference_section(section) for section in pruned_sections]
|
||||
|
||||
yield ToolResponse(id=FINAL_CONTEXT_DOCUMENTS_ID, response=llm_docs)
|
||||
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
def use_alt_not_None(value: T | None, alt: T) -> T:
|
||||
return value if value is not None else alt
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from onyx.chat.models import LlmDoc
|
||||
from onyx.chat.models import OnyxContext
|
||||
from onyx.context.search.models import InferenceSection
|
||||
from onyx.prompts.prompt_utils import clean_up_source
|
||||
|
||||
@@ -29,3 +30,12 @@ def section_to_dict(section: InferenceSection, section_num: int) -> dict:
|
||||
"%B %d, %Y %H:%M"
|
||||
)
|
||||
return doc_dict
|
||||
|
||||
|
||||
def context_from_inference_section(section: InferenceSection) -> OnyxContext:
|
||||
return OnyxContext(
|
||||
content=section.combined_content,
|
||||
document_id=section.center_chunk.document_id,
|
||||
semantic_identifier=section.center_chunk.semantic_identifier,
|
||||
blurb=section.center_chunk.blurb,
|
||||
)
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import Generic
|
||||
from typing import TypeVar
|
||||
|
||||
from onyx.llm.interfaces import LLM
|
||||
from onyx.llm.models import PreviousMessage
|
||||
@@ -11,10 +13,16 @@ from onyx.tools.tool import Tool
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
|
||||
|
||||
class ToolRunner:
|
||||
def __init__(self, tool: Tool, args: dict[str, Any]):
|
||||
R = TypeVar("R")
|
||||
|
||||
|
||||
class ToolRunner(Generic[R]):
|
||||
def __init__(
|
||||
self, tool: Tool[R], args: dict[str, Any], override_kwargs: R | None = None
|
||||
):
|
||||
self.tool = tool
|
||||
self.args = args
|
||||
self.override_kwargs = override_kwargs
|
||||
|
||||
self._tool_responses: list[ToolResponse] | None = None
|
||||
|
||||
@@ -27,7 +35,9 @@ class ToolRunner:
|
||||
return
|
||||
|
||||
tool_responses: list[ToolResponse] = []
|
||||
for tool_response in self.tool.run(**self.args):
|
||||
for tool_response in self.tool.run(
|
||||
override_kwargs=self.override_kwargs, **self.args
|
||||
):
|
||||
yield tool_response
|
||||
tool_responses.append(tool_response)
|
||||
|
||||
|
||||
@@ -118,7 +118,7 @@ def run_functions_in_parallel(
|
||||
return results
|
||||
|
||||
|
||||
class TimeoutThread(threading.Thread):
|
||||
class TimeoutThread(threading.Thread, Generic[R]):
|
||||
def __init__(
|
||||
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
):
|
||||
@@ -159,3 +159,34 @@ def run_with_timeout(
|
||||
task.end()
|
||||
|
||||
return task.result
|
||||
|
||||
|
||||
# NOTE: this function should really only be used when run_functions_tuples_in_parallel is
|
||||
# difficult to use. It's up to the programmer to call wait_on_background on the thread after
|
||||
# the code you want to run in parallel is finished. As with all python thread parallelism,
|
||||
# this is only useful for I/O bound tasks.
|
||||
def run_in_background(
|
||||
func: Callable[..., R], *args: Any, **kwargs: Any
|
||||
) -> TimeoutThread[R]:
|
||||
"""
|
||||
Runs a function in a background thread. Returns a TimeoutThread object that can be used
|
||||
to wait for the function to finish with wait_on_background.
|
||||
"""
|
||||
context = contextvars.copy_context()
|
||||
# Timeout not used in the non-blocking case
|
||||
task = TimeoutThread(-1, context.run, func, *args, **kwargs)
|
||||
task.start()
|
||||
return task
|
||||
|
||||
|
||||
def wait_on_background(task: TimeoutThread[R]) -> R:
|
||||
"""
|
||||
Used in conjunction with run_in_background. blocks until the task is finished,
|
||||
then returns the result of the task.
|
||||
"""
|
||||
task.join()
|
||||
|
||||
if task.exception is not None:
|
||||
raise task.exception
|
||||
|
||||
return task.result
|
||||
|
||||
@@ -45,55 +45,43 @@ def test_confluence_connector_basic(
|
||||
with pytest.raises(StopIteration):
|
||||
next(doc_batch_generator)
|
||||
|
||||
assert len(doc_batch) == 3
|
||||
assert len(doc_batch) == 2
|
||||
|
||||
page_within_a_page_doc: Document | None = None
|
||||
page_doc: Document | None = None
|
||||
txt_doc: Document | None = None
|
||||
|
||||
for doc in doc_batch:
|
||||
if doc.semantic_identifier == "DailyConnectorTestSpace Home":
|
||||
page_doc = doc
|
||||
elif ".txt" in doc.semantic_identifier:
|
||||
txt_doc = doc
|
||||
elif doc.semantic_identifier == "Page Within A Page":
|
||||
page_within_a_page_doc = doc
|
||||
|
||||
assert page_within_a_page_doc is not None
|
||||
assert page_within_a_page_doc.semantic_identifier == "Page Within A Page"
|
||||
assert page_within_a_page_doc.primary_owners
|
||||
assert page_within_a_page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
# Updated to check for display_name instead of email
|
||||
assert page_within_a_page_doc.primary_owners[0].display_name == "Hagen O'Neill"
|
||||
assert page_within_a_page_doc.primary_owners[0].email is None
|
||||
assert len(page_within_a_page_doc.sections) == 1
|
||||
|
||||
page_within_a_page_section = page_within_a_page_doc.sections[0]
|
||||
page_within_a_page_text = "@Chris Weaver loves cherry pie"
|
||||
assert page_within_a_page_section.text == page_within_a_page_text
|
||||
# Updated link assertion
|
||||
assert (
|
||||
page_within_a_page_section.link
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
|
||||
page_within_a_page_section.link.endswith(
|
||||
"/wiki/spaces/DailyConne/pages/200769540/Page+Within+A+Page"
|
||||
)
|
||||
)
|
||||
|
||||
assert page_doc is not None
|
||||
assert page_doc.semantic_identifier == "DailyConnectorTestSpace Home"
|
||||
assert page_doc.metadata["labels"] == ["testlabel"]
|
||||
assert page_doc.primary_owners
|
||||
assert page_doc.primary_owners[0].email == "hagen@danswer.ai"
|
||||
assert len(page_doc.sections) == 1
|
||||
assert page_doc.primary_owners[0].display_name == "Hagen O'Neill"
|
||||
assert page_doc.primary_owners[0].email is None
|
||||
assert len(page_doc.sections) == 2
|
||||
|
||||
page_section = page_doc.sections[0]
|
||||
assert page_section.text == "test123 " + page_within_a_page_text
|
||||
assert (
|
||||
page_section.link
|
||||
== "https://danswerai.atlassian.net/wiki/spaces/DailyConne/overview"
|
||||
)
|
||||
|
||||
assert txt_doc is not None
|
||||
assert txt_doc.semantic_identifier == "small-file.txt"
|
||||
assert len(txt_doc.sections) == 1
|
||||
assert txt_doc.sections[0].text == "small"
|
||||
assert txt_doc.primary_owners
|
||||
assert txt_doc.primary_owners[0].email == "chris@onyx.app"
|
||||
assert (
|
||||
txt_doc.sections[0].link
|
||||
== "https://danswerai.atlassian.net/wiki/pages/viewpageattachments.action?pageId=52494430&preview=%2F52494430%2F52527123%2Fsmall-file.txt"
|
||||
)
|
||||
assert page_section.link.endswith("/wiki/spaces/DailyConne/overview")
|
||||
@@ -41,5 +41,9 @@ def test_confluence_connector_permissions(
|
||||
for slim_doc_batch in confluence_connector.retrieve_all_slim_documents():
|
||||
all_slim_doc_ids.update([doc.id for doc in slim_doc_batch])
|
||||
|
||||
# Find IDs that are in full but not in slim
|
||||
difference = all_full_doc_ids - all_slim_doc_ids
|
||||
|
||||
# The set of full doc IDs should be always be a subset of the slim doc IDs
|
||||
assert all_full_doc_ids.issubset(all_slim_doc_ids)
|
||||
assert all_full_doc_ids.issubset(all_slim_doc_ids), \
|
||||
f"Full doc IDs are not a subset of slim doc IDs. Found {len(difference)} IDs in full docs but not in slim docs."
|
||||
|
||||
@@ -1,8 +1,14 @@
|
||||
import contextvars
|
||||
import time
|
||||
|
||||
import pytest
|
||||
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Create a context variable for testing
|
||||
test_context_var = contextvars.ContextVar("test_var", default="default")
|
||||
|
||||
|
||||
def test_run_with_timeout_completes() -> None:
|
||||
@@ -59,3 +65,86 @@ def test_run_with_timeout_with_args_and_kwargs() -> None:
|
||||
# Test with positional and keyword args
|
||||
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
|
||||
assert result2 == 15
|
||||
|
||||
|
||||
def test_run_in_background_and_wait_success() -> None:
|
||||
"""Test that run_in_background and wait_on_background work correctly for successful execution"""
|
||||
|
||||
def background_function(x: int) -> int:
|
||||
time.sleep(0.1) # Small delay to ensure it's actually running in background
|
||||
return x * 2
|
||||
|
||||
# Start the background task
|
||||
task = run_in_background(background_function, 21)
|
||||
|
||||
# Verify we can do other work while task is running
|
||||
start_time = time.time()
|
||||
result = wait_on_background(task)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
assert result == 42
|
||||
assert elapsed >= 0.1 # Verify we actually waited for the sleep
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
|
||||
def test_run_in_background_propagates_exceptions() -> None:
|
||||
"""Test that exceptions in background tasks are properly propagated"""
|
||||
|
||||
def error_function() -> None:
|
||||
time.sleep(0.1) # Small delay to ensure it's actually running in background
|
||||
raise ValueError("Test background error")
|
||||
|
||||
task = run_in_background(error_function)
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
wait_on_background(task)
|
||||
|
||||
assert "Test background error" in str(exc_info.value)
|
||||
|
||||
|
||||
def test_run_in_background_with_args_and_kwargs() -> None:
|
||||
"""Test that args and kwargs are properly passed to the background function"""
|
||||
|
||||
def complex_function(x: int, y: int, multiply: bool = False) -> int:
|
||||
time.sleep(0.1) # Small delay to ensure it's actually running in background
|
||||
if multiply:
|
||||
return x * y
|
||||
return x + y
|
||||
|
||||
# Test with args
|
||||
task1 = run_in_background(complex_function, 5, 3)
|
||||
result1 = wait_on_background(task1)
|
||||
assert result1 == 8
|
||||
|
||||
# Test with args and kwargs
|
||||
task2 = run_in_background(complex_function, 5, 3, multiply=True)
|
||||
result2 = wait_on_background(task2)
|
||||
assert result2 == 15
|
||||
|
||||
|
||||
def test_multiple_background_tasks() -> None:
|
||||
"""Test running multiple background tasks concurrently"""
|
||||
|
||||
def slow_add(x: int, y: int) -> int:
|
||||
time.sleep(0.2) # Make each task take some time
|
||||
return x + y
|
||||
|
||||
# Start multiple tasks
|
||||
start_time = time.time()
|
||||
task1 = run_in_background(slow_add, 1, 2)
|
||||
task2 = run_in_background(slow_add, 3, 4)
|
||||
task3 = run_in_background(slow_add, 5, 6)
|
||||
|
||||
# Wait for all results
|
||||
result1 = wait_on_background(task1)
|
||||
result2 = wait_on_background(task2)
|
||||
result3 = wait_on_background(task3)
|
||||
elapsed = time.time() - start_time
|
||||
|
||||
# Verify results
|
||||
assert result1 == 3
|
||||
assert result2 == 7
|
||||
assert result3 == 11
|
||||
|
||||
# Verify tasks ran in parallel (total time should be ~0.2s, not ~0.6s)
|
||||
assert 0.2 <= elapsed < 0.4 # Allow some buffer for test environment variations
|
||||
|
||||
@@ -4,7 +4,9 @@ import time
|
||||
from onyx.utils.threadpool_concurrency import FunctionCall
|
||||
from onyx.utils.threadpool_concurrency import run_functions_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import run_functions_tuples_in_parallel
|
||||
from onyx.utils.threadpool_concurrency import run_in_background
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.threadpool_concurrency import wait_on_background
|
||||
|
||||
# Create a test contextvar
|
||||
test_var = contextvars.ContextVar("test_var", default="default")
|
||||
@@ -129,3 +131,39 @@ def test_contextvar_isolation_between_runs() -> None:
|
||||
|
||||
# Verify second run results
|
||||
assert all(result in ["thread3", "thread4"] for result in second_results)
|
||||
|
||||
|
||||
def test_run_in_background_preserves_contextvar() -> None:
|
||||
"""Test that run_in_background preserves contextvar values and modifications are isolated"""
|
||||
|
||||
def modify_and_sleep() -> tuple[str, str]:
|
||||
"""Modifies contextvar, sleeps, and returns original, modified, and final values"""
|
||||
original = test_var.get()
|
||||
test_var.set("modified_in_background")
|
||||
time.sleep(0.1) # Ensure we can check main thread during execution
|
||||
final = test_var.get()
|
||||
return original, final
|
||||
|
||||
# Set initial value in main thread
|
||||
token = test_var.set("initial_value")
|
||||
try:
|
||||
# Start background task
|
||||
task = run_in_background(modify_and_sleep)
|
||||
|
||||
# Verify main thread value remains unchanged while task runs
|
||||
assert test_var.get() == "initial_value"
|
||||
|
||||
# Get results from background thread
|
||||
original, modified = wait_on_background(task)
|
||||
|
||||
# Verify the background thread:
|
||||
# 1. Saw the initial value
|
||||
assert original == "initial_value"
|
||||
# 2. Successfully modified its own copy
|
||||
assert modified == "modified_in_background"
|
||||
|
||||
# Verify main thread value is still unchanged after task completion
|
||||
assert test_var.get() == "initial_value"
|
||||
finally:
|
||||
# Clean up
|
||||
test_var.reset(token)
|
||||
|
||||
@@ -80,3 +80,13 @@ prod cluster**
|
||||
- `kubectl delete -f .`
|
||||
- To not delete the persistent volumes (Document indexes and Users), specify the specific `.yaml` files instead of
|
||||
`.` without specifying delete on persistent-volumes.yaml.
|
||||
|
||||
### Using Helm to deploy to an existing cluster
|
||||
|
||||
Onyx has a helm chart that is convenient to install all services to an existing Kubernetes cluster. To install:
|
||||
|
||||
* Currently the helm chart is not published so to install, clone the repo.
|
||||
* Configure access to the cluster via kubectl. Ensure the kubectl context is set to the cluster that you want to use
|
||||
* The default secrets, environment variables and other service level configuration are stored in `deployment/helm/charts/onyx/values.yml`. You may create another `override.yml`
|
||||
* `cd deployment/helm/charts/onyx` and run `helm install onyx -n onyx -f override.yaml .`. This will install onyx on the cluster under the `onyx` namespace.
|
||||
* Check the status of the deploy using `kubectl get pods -n onyx`
|
||||
27
deployment/helm/charts/onyx/templates/ingress-api.yaml
Normal file
27
deployment/helm/charts/onyx/templates/ingress-api.yaml
Normal file
@@ -0,0 +1,27 @@
|
||||
{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-ingress-api
|
||||
annotations:
|
||||
kubernetes.io/ingress.class: nginx
|
||||
nginx.ingress.kubernetes.io/rewrite-target: /$2
|
||||
nginx.ingress.kubernetes.io/use-regex: "true"
|
||||
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
spec:
|
||||
rules:
|
||||
- host: {{ .Values.ingress.api.host }}
|
||||
http:
|
||||
paths:
|
||||
- path: /api(/|$)(.*)
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "onyx-stack.fullname" . }}-api-service
|
||||
port:
|
||||
number: {{ .Values.api.service.servicePort }}
|
||||
tls:
|
||||
- hosts:
|
||||
- {{ .Values.ingress.api.host }}
|
||||
secretName: {{ include "onyx-stack.fullname" . }}-ingress-api-tls
|
||||
{{- end }}
|
||||
26
deployment/helm/charts/onyx/templates/ingress-webserver.yaml
Normal file
26
deployment/helm/charts/onyx/templates/ingress-webserver.yaml
Normal file
@@ -0,0 +1,26 @@
|
||||
{{- if .Values.ingress.enabled -}}
|
||||
apiVersion: networking.k8s.io/v1
|
||||
kind: Ingress
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-ingress-webserver
|
||||
annotations:
|
||||
kubernetes.io/ingress.class: nginx
|
||||
cert-manager.io/cluster-issuer: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
kubernetes.io/tls-acme: "true"
|
||||
spec:
|
||||
rules:
|
||||
- host: {{ .Values.ingress.webserver.host }}
|
||||
http:
|
||||
paths:
|
||||
- path: /
|
||||
pathType: Prefix
|
||||
backend:
|
||||
service:
|
||||
name: {{ include "onyx-stack.fullname" . }}-webserver
|
||||
port:
|
||||
number: {{ .Values.webserver.service.servicePort }}
|
||||
tls:
|
||||
- hosts:
|
||||
- {{ .Values.ingress.webserver.host }}
|
||||
secretName: {{ include "onyx-stack.fullname" . }}-ingress-webserver-tls
|
||||
{{- end }}
|
||||
20
deployment/helm/charts/onyx/templates/lets-encrypt.yaml
Normal file
20
deployment/helm/charts/onyx/templates/lets-encrypt.yaml
Normal file
@@ -0,0 +1,20 @@
|
||||
{{- if .Values.letsencrypt.enabled -}}
|
||||
apiVersion: cert-manager.io/v1
|
||||
kind: ClusterIssuer
|
||||
metadata:
|
||||
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
spec:
|
||||
acme:
|
||||
# The ACME server URL
|
||||
server: https://acme-v02.api.letsencrypt.org/directory
|
||||
# Email address used for ACME registration
|
||||
email: {{ .Values.letsencrypt.email }}
|
||||
# Name of a secret used to store the ACME account private key
|
||||
privateKeySecretRef:
|
||||
name: {{ include "onyx-stack.fullname" . }}-letsencrypt
|
||||
# Enable the HTTP-01 challenge provider
|
||||
solvers:
|
||||
- http01:
|
||||
ingress:
|
||||
class: nginx
|
||||
{{- end }}
|
||||
@@ -376,22 +376,17 @@ redis:
|
||||
existingSecret: onyx-secrets
|
||||
existingSecretPasswordKey: redis_password
|
||||
|
||||
# ingress:
|
||||
# enabled: false
|
||||
# className: ""
|
||||
# annotations: {}
|
||||
# # kubernetes.io/ingress.class: nginx
|
||||
# # kubernetes.io/tls-acme: "true"
|
||||
# hosts:
|
||||
# - host: chart-example.local
|
||||
# paths:
|
||||
# - path: /
|
||||
# pathType: ImplementationSpecific
|
||||
# tls: []
|
||||
# # - secretName: chart-example-tls
|
||||
# # hosts:
|
||||
# # - chart-example.local
|
||||
ingress:
|
||||
enabled: false
|
||||
className: ""
|
||||
api:
|
||||
host: onyx.local
|
||||
webserver:
|
||||
host: onyx.local
|
||||
|
||||
letsencrypt:
|
||||
enabled: false
|
||||
email: "abc@abc.com"
|
||||
|
||||
auth:
|
||||
# existingSecret onyx-secret for storing smtp, oauth, slack, and other secrets
|
||||
|
||||
@@ -1,17 +1,14 @@
|
||||
import {
|
||||
AnthropicIcon,
|
||||
AmazonIcon,
|
||||
AWSIcon,
|
||||
AzureIcon,
|
||||
CPUIcon,
|
||||
MicrosoftIconSVG,
|
||||
MistralIcon,
|
||||
MetaIcon,
|
||||
GeminiIcon,
|
||||
AnthropicSVG,
|
||||
IconProps,
|
||||
OpenAIISVG,
|
||||
DeepseekIcon,
|
||||
OpenAISVG,
|
||||
} from "@/components/icons/icons";
|
||||
|
||||
export interface CustomConfigKey {
|
||||
@@ -74,7 +71,7 @@ export interface LLMProviderDescriptor {
|
||||
}
|
||||
|
||||
export const getProviderIcon = (providerName: string, modelName?: string) => {
|
||||
const modelIconMap: Record<
|
||||
const iconMap: Record<
|
||||
string,
|
||||
({ size, className }: IconProps) => JSX.Element
|
||||
> = {
|
||||
@@ -86,34 +83,30 @@ export const getProviderIcon = (providerName: string, modelName?: string) => {
|
||||
gemini: GeminiIcon,
|
||||
deepseek: DeepseekIcon,
|
||||
claude: AnthropicIcon,
|
||||
anthropic: AnthropicIcon,
|
||||
openai: OpenAISVG,
|
||||
microsoft: MicrosoftIconSVG,
|
||||
meta: MetaIcon,
|
||||
google: GeminiIcon,
|
||||
};
|
||||
|
||||
const modelNameToIcon = (
|
||||
modelName: string,
|
||||
fallbackIcon: ({ size, className }: IconProps) => JSX.Element
|
||||
): (({ size, className }: IconProps) => JSX.Element) => {
|
||||
const lowerModelName = modelName?.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(modelIconMap)) {
|
||||
if (lowerModelName?.includes(key)) {
|
||||
// First check if provider name directly matches an icon
|
||||
if (providerName.toLowerCase() in iconMap) {
|
||||
return iconMap[providerName.toLowerCase()];
|
||||
}
|
||||
|
||||
// Then check if model name contains any of the keys
|
||||
if (modelName) {
|
||||
const lowerModelName = modelName.toLowerCase();
|
||||
for (const [key, icon] of Object.entries(iconMap)) {
|
||||
if (lowerModelName.includes(key)) {
|
||||
return icon;
|
||||
}
|
||||
}
|
||||
return fallbackIcon;
|
||||
};
|
||||
|
||||
switch (providerName) {
|
||||
case "openai":
|
||||
// Special cases for openai based on modelName
|
||||
return modelNameToIcon(modelName || "", OpenAIISVG);
|
||||
case "anthropic":
|
||||
return AnthropicSVG;
|
||||
case "bedrock":
|
||||
return AWSIcon;
|
||||
case "azure":
|
||||
return AzureIcon;
|
||||
default:
|
||||
return modelNameToIcon(modelName || "", CPUIcon);
|
||||
}
|
||||
|
||||
// Fallback to CPU icon if no matches
|
||||
return CPUIcon;
|
||||
};
|
||||
|
||||
export const isAnthropic = (provider: string, modelName: string) =>
|
||||
|
||||
@@ -185,7 +185,10 @@ export const FilterComponent = forwardRef<
|
||||
hasActiveFilters ? "border-primary bg-primary/5" : ""
|
||||
}`}
|
||||
>
|
||||
<SortIcon size={20} className="text-neutral-800" />
|
||||
<SortIcon
|
||||
size={20}
|
||||
className="text-neutral-800 dark:text-neutral-200"
|
||||
/>
|
||||
</Button>
|
||||
</DropdownMenuTrigger>
|
||||
<DropdownMenuContent
|
||||
@@ -365,7 +368,7 @@ export const FilterComponent = forwardRef<
|
||||
|
||||
{hasActiveFilters && (
|
||||
<div className="absolute -top-1 -right-1">
|
||||
<Badge className="h-2 bg-red-400 border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
|
||||
<Badge className="h-2 !bg-red-400 !border-red-400 w-2 p-0 border-2 flex items-center justify-center" />
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
|
||||
@@ -290,21 +290,24 @@ export function SettingsForm() {
|
||||
id="chatRetentionInput"
|
||||
placeholder="Infinite Retention"
|
||||
/>
|
||||
<Button
|
||||
onClick={handleSetChatRetention}
|
||||
variant="submit"
|
||||
size="sm"
|
||||
className="mr-3"
|
||||
>
|
||||
Set Retention Limit
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleClearChatRetention}
|
||||
variant="default"
|
||||
size="sm"
|
||||
>
|
||||
Retain All
|
||||
</Button>
|
||||
<div className="mr-auto flex gap-2">
|
||||
<Button
|
||||
onClick={handleSetChatRetention}
|
||||
variant="submit"
|
||||
size="sm"
|
||||
className="mr-auto"
|
||||
>
|
||||
Set Retention Limit
|
||||
</Button>
|
||||
<Button
|
||||
onClick={handleClearChatRetention}
|
||||
variant="default"
|
||||
size="sm"
|
||||
className="mr-auto"
|
||||
>
|
||||
Retain All
|
||||
</Button>
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
|
||||
|
||||
@@ -61,6 +61,7 @@ export function EmailPasswordForm({
|
||||
|
||||
if (!response.ok) {
|
||||
setIsWorking(false);
|
||||
|
||||
const errorDetail = (await response.json()).detail;
|
||||
let errorMsg = "Unknown error";
|
||||
if (typeof errorDetail === "object" && errorDetail.reason) {
|
||||
@@ -96,12 +97,13 @@ export function EmailPasswordForm({
|
||||
} else {
|
||||
setIsWorking(false);
|
||||
const errorDetail = (await loginResponse.json()).detail;
|
||||
|
||||
let errorMsg = "Unknown error";
|
||||
if (errorDetail === "LOGIN_BAD_CREDENTIALS") {
|
||||
errorMsg = "Invalid email or password";
|
||||
} else if (errorDetail === "NO_WEB_LOGIN_AND_HAS_NO_PASSWORD") {
|
||||
errorMsg = "Create an account to set a password";
|
||||
} else if (typeof errorDetail === "string") {
|
||||
errorMsg = errorDetail;
|
||||
}
|
||||
if (loginResponse.status === 429) {
|
||||
errorMsg = "Too many requests. Please try again later.";
|
||||
|
||||
@@ -191,6 +191,7 @@ export const FolderDropdown = forwardRef<HTMLDivElement, FolderDropdownProps>(
|
||||
onChange={(e) => setNewFolderName(e.target.value)}
|
||||
className="text-sm font-medium bg-transparent outline-none w-full pb-1 border-b border-background-500 transition-colors duration-200"
|
||||
onKeyDown={(e) => {
|
||||
e.stopPropagation();
|
||||
if (e.key === "Enter") {
|
||||
handleEdit();
|
||||
}
|
||||
|
||||
@@ -303,7 +303,6 @@ const FolderItem = ({
|
||||
key={chatSession.id}
|
||||
chatSession={chatSession}
|
||||
isSelected={chatSession.id === currentChatId}
|
||||
skipGradient={isDragOver}
|
||||
showShareModal={showShareModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
/>
|
||||
|
||||
@@ -32,21 +32,17 @@ export function ChatSessionDisplay({
|
||||
chatSession,
|
||||
search,
|
||||
isSelected,
|
||||
skipGradient,
|
||||
closeSidebar,
|
||||
showShareModal,
|
||||
showDeleteModal,
|
||||
foldersExisting,
|
||||
isDragging,
|
||||
}: {
|
||||
chatSession: ChatSession;
|
||||
isSelected: boolean;
|
||||
search?: boolean;
|
||||
skipGradient?: boolean;
|
||||
closeSidebar?: () => void;
|
||||
showShareModal?: (chatSession: ChatSession) => void;
|
||||
showDeleteModal?: (chatSession: ChatSession) => void;
|
||||
foldersExisting?: boolean;
|
||||
isDragging?: boolean;
|
||||
}) {
|
||||
const router = useRouter();
|
||||
@@ -238,8 +234,12 @@ export function ChatSessionDisplay({
|
||||
e.preventDefault();
|
||||
e.stopPropagation();
|
||||
}}
|
||||
onChange={(e) => setChatName(e.target.value)}
|
||||
onChange={(e) => {
|
||||
setChatName(e.target.value);
|
||||
}}
|
||||
onKeyDown={(event) => {
|
||||
event.stopPropagation();
|
||||
|
||||
if (event.key === "Enter") {
|
||||
onRename();
|
||||
event.preventDefault();
|
||||
|
||||
@@ -264,7 +264,6 @@ export function PagesTab({
|
||||
>
|
||||
<ChatSessionDisplay
|
||||
chatSession={chat}
|
||||
foldersExisting={foldersExisting}
|
||||
isSelected={currentChatId === chat.id}
|
||||
showShareModal={showShareModal}
|
||||
showDeleteModal={showDeleteModal}
|
||||
|
||||
@@ -40,8 +40,12 @@ export const ConnectorTitle = ({
|
||||
const typedConnector = connector as Connector<GithubConfig>;
|
||||
additionalMetadata.set(
|
||||
"Repo",
|
||||
typedConnector.connector_specific_config.repo_name
|
||||
? `${typedConnector.connector_specific_config.repo_owner}/${typedConnector.connector_specific_config.repo_name}`
|
||||
typedConnector.connector_specific_config.repositories
|
||||
? `${typedConnector.connector_specific_config.repo_owner}/${
|
||||
typedConnector.connector_specific_config.repositories.includes(",")
|
||||
? "multiple repos"
|
||||
: typedConnector.connector_specific_config.repositories
|
||||
}`
|
||||
: `${typedConnector.connector_specific_config.repo_owner}/*`
|
||||
);
|
||||
} else if (connector.source === "gitlab") {
|
||||
|
||||
@@ -3102,27 +3102,6 @@ export const OpenAISVG = ({
|
||||
);
|
||||
};
|
||||
|
||||
export const AnthropicSVG = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
}: IconProps) => {
|
||||
return (
|
||||
<svg
|
||||
style={{ width: `${size}px`, height: `${size}px` }}
|
||||
className={`w-[${size}px] h-[${size}px] ` + className}
|
||||
xmlns="http://www.w3.org/2000/svg"
|
||||
viewBox="0 0 92.2 65"
|
||||
xmlSpace="preserve"
|
||||
fill="currentColor"
|
||||
>
|
||||
<path
|
||||
fill="currentColor"
|
||||
d="M66.5,0H52.4l25.7,65h14.1L66.5,0z M25.7,0L0,65h14.4l5.3-13.6h26.9L51.8,65h14.4L40.5,0C40.5,0,25.7,0,25.7,0z M24.3,39.3l8.8-22.8l8.8,22.8H24.3z"
|
||||
/>
|
||||
</svg>
|
||||
);
|
||||
};
|
||||
|
||||
export const SourcesIcon = ({
|
||||
size = 16,
|
||||
className = defaultTailwindCSS,
|
||||
|
||||
@@ -190,10 +190,12 @@ export const connectorConfigs: Record<
|
||||
fields: [
|
||||
{
|
||||
type: "text",
|
||||
query: "Enter the repository name:",
|
||||
label: "Repository Name",
|
||||
name: "repo_name",
|
||||
query: "Enter the repository name(s):",
|
||||
label: "Repository Name(s)",
|
||||
name: "repositories",
|
||||
optional: false,
|
||||
description:
|
||||
"For multiple repositories, enter comma-separated names (e.g., repo1,repo2,repo3)",
|
||||
},
|
||||
],
|
||||
},
|
||||
@@ -1358,7 +1360,7 @@ export interface WebConfig {
|
||||
|
||||
export interface GithubConfig {
|
||||
repo_owner: string;
|
||||
repo_name: string;
|
||||
repositories: string; // Comma-separated list of repository names
|
||||
include_prs: boolean;
|
||||
include_issues: boolean;
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user