Compare commits

..

17 Commits

Author  SHA1        Message                  Date
Weves   4d1d5bdcfe  tweak worker count       2025-02-16 14:51:10 -08:00
Weves   0d468b49a1  Final fixes              2025-02-16 14:50:51 -08:00
Weves   67b87ced39  fix paths                2025-02-16 14:23:51 -08:00
Weves   8b4e4a6c80  Fix paths                2025-02-16 14:22:12 -08:00
Weves   e26bcf5a05  Misc fixes               2025-02-16 14:02:16 -08:00
Weves   435959cf90  test                     2025-02-16 14:02:16 -08:00
Weves   fcbe305dc0  test                     2025-02-16 14:02:16 -08:00
Weves   6f13d44564  move                     2025-02-16 14:02:16 -08:00
Weves   c1810a35cd  Fix redis port           2025-02-16 14:02:16 -08:00
Weves   4003e7346a  test                     2025-02-16 14:02:16 -08:00
Weves   8057f1eb0d  Add logging              2025-02-16 14:02:16 -08:00
Weves   7eebd3cff1  test not removing files  2025-02-16 14:02:16 -08:00
Weves   bac2aeb8b7  Fix                      2025-02-16 14:02:16 -08:00
Weves   9831697acc  Make migrations work     2025-02-16 14:02:15 -08:00
Weves   5da766dd3b  testing                  2025-02-16 14:01:50 -08:00
Weves   180608694a  improvements             2025-02-16 14:01:49 -08:00
Weves   96b92edfdb  Parallelize IT
Parallelization

Full draft of first pass

Adjsut test name

test

test

Fix

Update cmd

test

Fix

test

Test with all tests

Resource bump + limit num parallel runs

Add retries
2025-02-16 14:01:12 -08:00
71 changed files with 1769 additions and 955 deletions

View File

@@ -0,0 +1,153 @@
name: Run Integration Tests v3
concurrency:
group: Run-Integration-Tests-Parallel-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=32cpu-linux-x64, ram=64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Run Standard Integration Tests
run: |
# Print a message indicating that tests are starting
echo "Running integration tests..."
# Create a directory for test logs that will be mounted into the container
mkdir -p ${{ github.workspace }}/test_logs
chmod 777 ${{ github.workspace }}/test_logs
# Run the integration tests in a Docker container
# Mount the Docker socket to allow Docker-in-Docker (DinD)
# Mount the test_logs directory to capture logs
# Use host network for easier communication with other services
docker run \
-v /var/run/docker.sock:/var/run/docker.sock \
-v ${{ github.workspace }}/test_logs:/tmp \
--network host \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
danswer/danswer-integration:test \
python /app/tests/integration/run.py
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Collect log files
if: success() || failure()
run: |
# Create a directory for logs
mkdir -p ${{ github.workspace }}/logs
mkdir -p ${{ github.workspace }}/logs/shared_services
# Copy all relevant log files from the mounted directory
cp ${{ github.workspace }}/test_logs/api_server_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/background_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/shared_model_server.txt ${{ github.workspace }}/logs/ || true
# Collect logs from shared services (Docker containers)
# Note: using a wildcard for the UUID part of the stack name
docker ps -a --filter "name=base-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# Also collect Redis container logs
docker ps -a --filter "name=redis-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# List collected logs
echo "Collected log files:"
ls -l ${{ github.workspace }}/logs/
echo "Collected shared services logs:"
ls -l ${{ github.workspace }}/logs/shared_services/
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: integration-test-logs
path: |
${{ github.workspace }}/logs/
${{ github.workspace }}/logs/shared_services/
retention-days: 5
# save before stopping the containers so the logs can be captured
# - name: Save Docker logs
# if: success() || failure()
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
# mv docker-compose.log ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# - name: Upload logs
# if: success() || failure()
# uses: actions/upload-artifact@v4
# with:
# name: docker-logs
# path: ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
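
The workflow above hands the whole suite to "python /app/tests/integration/run.py" inside the test image; the runner itself is not part of this diff. Purely as an illustration of the parallelization idea behind this commit series (and the later "tweak worker count" adjustment), a minimal sketch of such a runner follows — the directory layout, worker count, and pytest invocation are assumptions, not the actual implementation.

# Hypothetical sketch of a parallel integration-test runner; NOT the actual run.py.
# Assumes test files under /app/tests/integration/tests are independent enough to
# run in separate pytest processes.
import subprocess
import sys
from concurrent.futures import ProcessPoolExecutor
from pathlib import Path

NUM_WORKERS = 8  # tuned to the runner size, cf. the "tweak worker count" commit


def run_one(test_file: Path) -> tuple[Path, int]:
    # One pytest process per file so fixtures and global state do not leak between tests.
    proc = subprocess.run(["pytest", "-q", str(test_file)])
    return test_file, proc.returncode


def main() -> int:
    test_files = sorted(Path("/app/tests/integration/tests").rglob("test_*.py"))
    failed = []
    with ProcessPoolExecutor(max_workers=NUM_WORKERS) as pool:
        for test_file, rc in pool.map(run_one, test_files):
            print(f"{test_file}: {'OK' if rc == 0 else 'FAILED'}")
            if rc != 0:
                failed.append(test_file)
    return 1 if failed else 0


if __name__ == "__main__":
    sys.exit(main())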

View File

@@ -5,10 +5,10 @@ concurrency:
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
# pull_request:
# branches:
# - main
# - "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -28,11 +28,11 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
libgnutls30 \
libblkid1 \
libmount1 \
libsmartcols1 \
libuuid1 \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
libsmartcols1=2.38.1-5+deb12u1 \
libuuid1=2.38.1-5+deb12u1 \
libxmlsec1-dev \
pkg-config \
gcc \

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.db.engine import SYNC_DB_API, get_iam_auth_token
from onyx.configs.app_configs import POSTGRES_DB, USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
@@ -13,12 +13,11 @@ from sqlalchemy import text
from sqlalchemy.engine.base import Connection
import os
import ssl
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
@@ -133,17 +132,32 @@ def provide_iam_token_for_alembic(
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
def run_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
# Get any environment variables passed through alembic config
env_vars = context.config.attributes.get("env_vars", {})
# Use env vars if provided, otherwise fall back to defaults
postgres_host = env_vars.get("POSTGRES_HOST", POSTGRES_HOST)
postgres_port = env_vars.get("POSTGRES_PORT", POSTGRES_PORT)
postgres_user = env_vars.get("POSTGRES_USER", POSTGRES_USER)
postgres_db = env_vars.get("POSTGRES_DB", POSTGRES_DB)
engine = create_engine(
build_connection_string(
db=postgres_db,
user=postgres_user,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
),
poolclass=pool.NullPool,
)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
@@ -152,31 +166,26 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
if schema is None:
continue
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema_name, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
engine.dispose()
def run_migrations_offline() -> None:
@@ -184,18 +193,18 @@ def run_migrations_offline() -> None:
url = build_connection_string()
if upgrade_all_tenants:
engine = create_async_engine(url)
engine = create_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
@@ -230,7 +239,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
run_migrations()
if context.is_offline_mode():
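
The rewritten env.py drops the async engine and, more importantly for parallel test runs, reads connection details from context.config.attributes["env_vars"] with the process-wide defaults as fallback. That hook lets a test harness run migrations against a per-run Postgres database programmatically. A minimal sketch of such a caller, assuming a standard alembic.ini and placeholder connection values:

# Sketch: run migrations against an isolated test database by passing env vars
# through alembic's Config.attributes (consumed by run_migrations() above).
from alembic import command
from alembic.config import Config


def upgrade_test_db(host: str, port: str, user: str, db: str) -> None:
    cfg = Config("alembic.ini")  # config path is an assumption
    cfg.attributes["env_vars"] = {
        "POSTGRES_HOST": host,
        "POSTGRES_PORT": port,
        "POSTGRES_USER": user,
        "POSTGRES_DB": db,
    }
    command.upgrade(cfg, "head")


# e.g. one database per parallel test worker
upgrade_test_db(host="localhost", port="5433", user="postgres", db="it_worker_1")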

View File

@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
from onyx.agents.agent_search.basic.states import BasicInput
from onyx.agents.agent_search.basic.states import BasicOutput
from onyx.agents.agent_search.basic.states import BasicState
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
)
graph.add_node(
node="choose_tool",
action=choose_tool,
node="llm_tool_choice",
action=llm_tool_choice,
)
graph.add_node(
node="call_tool",
action=call_tool,
node="tool_call",
action=tool_call,
)
graph.add_node(
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:
graph.add_edge(start_key=START, end_key="prepare_tool_input")
graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")
graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge(
start_key="call_tool",
start_key="tool_call",
end_key="basic_use_tool_response",
)
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
# If there are no tool calls, basic graph already streamed the answer
END
if state.tool_choice is None
else "call_tool"
else "tool_call"
)

View File

@@ -31,14 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -87,11 +85,9 @@ def check_sub_answer(
agent_error: AgentErrorLog | None = None
response: BaseMessage | None = None
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
fast_llm.invoke,
response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
)
quality_str: str = cast(str, response.content)
@@ -100,7 +96,7 @@ def check_sub_answer(
)
log_result = f"Answer quality: {quality_str}"
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
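
This file shows a change that repeats through the agent nodes below: the thread-based run_with_timeout wrapper (and its separate connect timeout) is replaced by a single timeout_override passed straight to the LLM call, using the renamed AGENT_TIMEOUT_OVERRIDE_* configs, and the except clause narrows to LLMTimeoutError because no wrapper thread raises TimeoutError anymore. Condensed from the hunks above:

# Before: wall-clock timeout enforced by a wrapper thread + separate connect timeout.
try:
    response = run_with_timeout(
        AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
        fast_llm.invoke,
        prompt=msg,
        timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
    )
except (LLMTimeoutError, TimeoutError):
    ...  # record AgentErrorLog with AgentLLMErrorType.TIMEOUT

# After: one timeout handed to the LLM client itself; only LLMTimeoutError can be raised.
try:
    response = fast_llm.invoke(
        prompt=msg,
        timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
    )
except LLMTimeoutError:
    ...  # record AgentErrorLog with AgentLLMErrorType.TIMEOUT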

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import merge_message_runs
@@ -46,13 +47,11 @@ from onyx.chat.models import StreamStopInfo
from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import NO_RECOVERED_DOCS
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -111,14 +110,15 @@ def generate_sub_answer(
config=fast_llm.config,
)
response: list[str | list[str | dict[str, Any]]] = []
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
response: list[str] = []
def stream_sub_answer() -> list[str]:
agent_error: AgentErrorLog | None = None
try:
for message in fast_llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -142,15 +142,8 @@ def generate_sub_answer(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
stream_sub_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -59,15 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -80,7 +77,6 @@ from onyx.prompts.agent_search import (
)
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
_llm_node_error_strings = LLMNodeErrorStrings(
@@ -234,11 +230,7 @@ def generate_initial_answer(
sub_questions = all_sub_questions # Replace the original assignment
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
model = graph_config.tooling.fast_llm
doc_context = format_docs(answer_generation_documents.context_documents)
doc_context = trim_prompt_piece(
@@ -268,16 +260,15 @@ def generate_initial_answer(
)
]
streamed_tokens: list[str] = [""]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
def stream_initial_answer() -> list[str]:
response: list[str] = []
try:
for message in model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -301,16 +292,9 @@ def generate_initial_answer(
dispatch_timings.append(
(end_stream_token - start_stream_token).microseconds
)
response.append(content)
return response
streamed_tokens.append(content)
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
stream_initial_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -36,10 +36,7 @@ from onyx.chat.models import StreamType
from onyx.chat.models import SubQuestionPiece
from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -50,7 +47,6 @@ from onyx.prompts.agent_search import (
INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -135,12 +131,10 @@ def decompose_orig_question(
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
dispatch_separated,
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
),
dispatch_subquestion(0, writer),
sep_callback=dispatch_subquestion_sep(0, writer),
@@ -160,7 +154,7 @@ def decompose_orig_question(
)
write_custom_event("stream_finished", stop_event, writer)
except (LLMTimeoutError, TimeoutError) as e:
except LLMTimeoutError as e:
logger.error("LLM Timeout Error - decompose orig question")
raise e # fail loudly on this critical step
except LLMRateLimitError as e:

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
def route_initial_tool_choice(
state: MainState, config: RunnableConfig
) -> Literal["call_tool", "start_agent_search", "logging_node"]:
) -> Literal["tool_call", "start_agent_search", "logging_node"]:
"""
LangGraph edge to route to agent search.
"""
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
):
return "start_agent_search"
else:
return "call_tool"
return "tool_call"
else:
return "logging_node"

View File

@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
answer_refined_query_graph_builder,
)
from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
prepare_tool_input,
)
from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
basic_use_tool_response,
)
from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
from onyx.utils.logger import setup_logger
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
# Choose the initial tool
graph.add_node(
node="initial_tool_choice",
action=choose_tool,
action=llm_tool_choice,
)
# Call the tool, if required
graph.add_node(
node="call_tool",
action=call_tool,
node="tool_call",
action=tool_call,
)
# Use the tool response
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
graph.add_conditional_edges(
"initial_tool_choice",
route_initial_tool_choice,
["call_tool", "start_agent_search", "logging_node"],
["tool_call", "start_agent_search", "logging_node"],
)
graph.add_edge(
start_key="call_tool",
start_key="tool_call",
end_key="basic_use_tool_response",
)
graph.add_edge(

View File

@@ -33,15 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import RefinedAnswerImprovement
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -107,14 +105,11 @@ def compare_answers(
refined_answer_improvement: bool | None = None
# no need to stream this
try:
resp = run_with_timeout(
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
resp = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -44,10 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -56,7 +53,6 @@ from onyx.prompts.agent_search import (
)
from onyx.tools.models import ToolCallKickoff
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -138,17 +134,15 @@ def create_refined_sub_questions(
agent_error: AgentErrorLog | None = None
streamed_tokens: list[BaseMessage_Content] = []
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
dispatch_separated,
streamed_tokens = dispatch_separated(
model.stream(
msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
),
dispatch_subquestion(1, writer),
sep_callback=dispatch_subquestion_sep(1, writer),
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -22,17 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
from onyx.configs.constants import NUM_EXPLORATORY_DOCS
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
@@ -90,42 +84,30 @@ def extract_entities_terms(
]
fast_llm = graph_config.tooling.fast_llm
# Grader
llm_response = fast_llm.invoke(
prompt=msg,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
llm_response = run_with_timeout(
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
cleaned_response = (
str(llm_response.content).replace("```json\n", "").replace("\n```", "")
)
first_bracket = cleaned_response.find("{")
last_bracket = cleaned_response.rfind("}")
cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
try:
entity_extraction_result = EntityExtractionResult.model_validate_json(
cleaned_response
)
except ValueError:
logger.error(
"Failed to parse LLM response as JSON in Entity-Term Extraction"
)
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except (LLMTimeoutError, TimeoutError):
logger.error("LLM Timeout Error - extract entities terms")
except ValueError:
logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
)
except LLMRateLimitError:
logger.error("LLM Rate Limit Error - extract entities terms")
entity_extraction_result = EntityExtractionResult(
retrieved_entities_relationships=EntityRelationshipTermExtraction(),
retrieved_entities_relationships=EntityRelationshipTermExtraction(
entities=[],
relationships=[],
terms=[],
),
)
return EntityTermExtractionUpdate(

View File

@@ -1,4 +1,5 @@
from datetime import datetime
from typing import Any
from typing import cast
from langchain_core.messages import HumanMessage
@@ -65,21 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
from onyx.chat.models import AgentAnswerPiece
from onyx.chat.models import ExtendedToolResponse
from onyx.chat.models import StreamingError
from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
)
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
)
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
@@ -98,7 +92,6 @@ from onyx.prompts.agent_search import (
from onyx.prompts.agent_search import UNKNOWN_ANSWER
from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -260,12 +253,7 @@ def generate_validate_refined_answer(
else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
)
model = (
graph_config.tooling.fast_llm
if AGENT_ANSWER_GENERATION_BY_FAST_LLM
else graph_config.tooling.primary_llm
)
model = graph_config.tooling.fast_llm
relevant_docs_str = format_docs(answer_generation_documents.context_documents)
relevant_docs_str = trim_prompt_piece(
model.config,
@@ -296,13 +284,13 @@ def generate_validate_refined_answer(
)
]
streamed_tokens: list[str] = [""]
streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
dispatch_timings: list[float] = []
agent_error: AgentErrorLog | None = None
def stream_refined_answer() -> list[str]:
try:
for message in model.stream(
msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
):
# TODO: in principle, the answer here COULD contain images, but we don't support that yet
content = message.content
@@ -327,15 +315,8 @@ def generate_validate_refined_answer(
(end_stream_token - start_stream_token).microseconds
)
streamed_tokens.append(content)
return streamed_tokens
try:
streamed_tokens = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
stream_refined_answer,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -402,20 +383,16 @@ def generate_validate_refined_answer(
)
]
validation_model = graph_config.tooling.fast_llm
try:
validation_response = run_with_timeout(
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
validation_model.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
validation_response = model.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
)
refined_answer_quality = binary_string_test_after_answer_separator(
text=cast(str, validation_response.content),
positive_value=AGENT_POSITIVE_VALUE_STR,
separator=AGENT_ANSWER_SEPARATOR,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
refined_answer_quality = True
logger.error("LLM Timeout Error - validate refined answer")

View File

@@ -34,16 +34,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
)
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
QUERY_REWRITING_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -71,7 +69,7 @@ def expand_queries(
node_start_time = datetime.now()
question = state.question
model = graph_config.tooling.fast_llm
llm = graph_config.tooling.fast_llm
sub_question_id = state.sub_question_id
if sub_question_id is None:
level, question_num = 0, 0
@@ -90,12 +88,10 @@ def expand_queries(
rewritten_queries = []
try:
llm_response_list = run_with_timeout(
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
dispatch_separated,
model.stream(
llm_response_list = dispatch_separated(
llm.stream(
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
),
dispatch_subquery(level, question_num, writer),
)
@@ -105,7 +101,7 @@ def expand_queries(
rewritten_queries = llm_response.split("\n")
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
agent_error = AgentErrorLog(
error_type=AgentLLMErrorType.TIMEOUT,
error_message=AGENT_LLM_TIMEOUT_MESSAGE,

View File

@@ -55,7 +55,6 @@ def rerank_documents(
# Note that these are passed in values from the API and are overrides which are typically None
rerank_settings = graph_config.inputs.search_request.rerank_settings
allow_agent_reranking = graph_config.behavior.allow_agent_reranking
if rerank_settings is None:
with get_session_context_manager() as db_session:
@@ -63,31 +62,23 @@ def rerank_documents(
if not search_settings.disable_rerank_for_streaming:
rerank_settings = RerankingDetails.from_db_model(search_settings)
# Initial default: no reranking. Will be overwritten below if reranking is warranted
reranked_documents = verified_documents
if should_rerank(rerank_settings) and len(verified_documents) > 0:
if len(verified_documents) > 1:
if not allow_agent_reranking:
logger.info("Use of local rerank model without GPU, skipping reranking")
# No reranking, stay with verified_documents as default
else:
# Reranking is warranted, use the rerank_sections functon
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
reranked_documents = rerank_sections(
query_str=question,
# if runnable, then rerank_settings is not None
rerank_settings=cast(RerankingDetails, rerank_settings),
sections_to_rerank=verified_documents,
)
else:
logger.warning(
f"{len(verified_documents)} verified document(s) found, skipping reranking"
)
# No reranking, stay with verified_documents as default
reranked_documents = verified_documents
else:
logger.warning("No reranking settings found, using unranked documents")
# No reranking, stay with verified_documents as default
reranked_documents = verified_documents
if AGENT_RERANKING_STATS:
fit_scores = get_fit_scores(verified_documents, reranked_documents)
else:

View File

@@ -25,15 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrin
from onyx.agents.agent_search.shared_graph_utils.utils import (
get_langgraph_node_log_string,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
from onyx.llm.chat_llm import LLMRateLimitError
from onyx.llm.chat_llm import LLMTimeoutError
from onyx.prompts.agent_search import (
DOCUMENT_VERIFICATION_PROMPT,
)
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
from onyx.utils.timing import log_function_time
logger = setup_logger()
@@ -88,11 +86,8 @@ def verify_documents(
] # default is to treat document as relevant
try:
response = run_with_timeout(
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
fast_llm.invoke,
prompt=msg,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
response = fast_llm.invoke(
msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
)
assert isinstance(response.content, str)
@@ -101,7 +96,7 @@ def verify_documents(
):
verified_documents = []
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
# In this case, we decide to continue and don't raise an error, as
# little harm in letting some docs through that are less relevant.
logger.error("LLM Timeout Error - verify documents")

View File

@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
# Whether to allow creation of refinement questions (and entity extraction, etc.)
allow_refinement: bool = True
skip_gen_ai_answer_generation: bool = False
allow_agent_reranking: bool = False
class GraphConfig(BaseModel):

View File

@@ -25,7 +25,7 @@ logger = setup_logger()
# and a function that handles extracting the necessary fields
# from the state and config
# TODO: fan-out to multiple tool call nodes? Make this configurable?
def choose_tool(
def llm_tool_choice(
state: ToolChoiceState,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
write_custom_event("basic_response", packet, writer)
def call_tool(
def tool_call(
state: ToolChoiceUpdate,
config: RunnableConfig,
writer: StreamWriter = lambda _: None,

View File

@@ -43,9 +43,8 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.models import StreamType
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.configs.agent_configs import (
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -81,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSum
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.logger import setup_logger
from onyx.utils.threadpool_concurrency import run_with_timeout
logger = setup_logger()
@@ -397,13 +395,11 @@ def summarize_history(
)
try:
history_response = run_with_timeout(
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
llm.invoke,
history_response = llm.invoke(
history_context_prompt,
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
)
except (LLMTimeoutError, TimeoutError):
except LLMTimeoutError:
logger.error("LLM Timeout Error - summarize history")
return (
history # this is what is done at this point anyway, so we default to this

View File

@@ -94,7 +94,6 @@ from onyx.db.models import User
from onyx.db.users import get_user_by_email
from onyx.redis.redis_pool import get_async_redis_connection
from onyx.redis.redis_pool import get_redis_client
from onyx.server.utils import BasicAuthenticationError
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
from onyx.utils.telemetry import optional_telemetry
@@ -108,6 +107,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
logger = setup_logger()
class BasicAuthenticationError(HTTPException):
def __init__(self, detail: str):
super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
def is_user_admin(user: User | None) -> bool:
if AUTH_TYPE == AuthType.DISABLED:
return True

View File

@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
from onyx.tools.tool import Tool
from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
from onyx.tools.tool_implementations.search.search_tool import SearchTool
from onyx.tools.utils import explicit_tool_calling_supported
from onyx.utils.gpu_utils import gpu_status_request
from onyx.utils.logger import setup_logger
logger = setup_logger()
@@ -82,26 +80,6 @@ class Answer:
and not skip_explicit_tool_calling
)
rerank_settings = search_request.rerank_settings
using_cloud_reranking = (
rerank_settings is not None
and rerank_settings.rerank_provider_type is not None
)
allow_agent_reranking = gpu_status_request() or using_cloud_reranking
# TODO: this is a hack to force the query to be used for the search tool
# this should be removed once we fully unify graph inputs (i.e.
# remove SearchQuery entirely)
if (
force_use_tool.force_use
and search_tool
and force_use_tool.args
and force_use_tool.tool_name == search_tool.name
and QUERY_FIELD in force_use_tool.args
):
search_request.query = force_use_tool.args[QUERY_FIELD]
self.graph_inputs = GraphInputs(
search_request=search_request,
prompt_builder=prompt_builder,
@@ -116,6 +94,7 @@ class Answer:
force_use_tool=force_use_tool,
using_tool_calling_llm=using_tool_calling_llm,
)
assert db_session, "db_session must be provided for agentic persistence"
self.graph_persistence = GraphPersistence(
db_session=db_session,
chat_session_id=chat_session_id,
@@ -125,7 +104,6 @@ class Answer:
use_agentic_search=use_agentic_search,
skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
allow_refinement=True,
allow_agent_reranking=allow_agent_reranking,
)
self.graph_config = GraphConfig(
inputs=self.graph_inputs,

View File

@@ -7,7 +7,7 @@ from typing import cast
from sqlalchemy.orm import Session
from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
from onyx.chat.answer import Answer
from onyx.chat.chat_utils import create_chat_chain
from onyx.chat.chat_utils import create_temporary_persona

View File

@@ -31,9 +31,22 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000
AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
)
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_RETRIEVAL_STATS = (
not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
@@ -165,172 +178,80 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
) # 2000
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30 # in seconds
AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
)
AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
) # 3
AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4 # in seconds
AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
) # 30
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
) # 12
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25 # in seconds
AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
) # 25
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
)
AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8 # in seconds
AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
) # 6
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
) # 1
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3 # in seconds
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
) # 4
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
) # 8
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
)
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2 # in seconds
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
)
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
) # 8
GRAPH_VERSION_NAME: str = "a"

View File

@@ -0,0 +1,6 @@
import os
SKIP_CONNECTION_POOL_WARM_UP = (
os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"
)

View File

@@ -628,7 +628,7 @@ def create_new_chat_message(
commit: bool = True,
reserved_message_id: int | None = None,
overridden_model: str | None = None,
refined_answer_improvement: bool | None = None,
refined_answer_improvement: bool = True,
) -> ChatMessage:
if reserved_message_id is not None:
# Edit existing message

View File

@@ -191,9 +191,16 @@ class SqlEngine:
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
def _init_engine(
cls, host: str, port: str, db: str, **engine_kwargs: Any
) -> Engine:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
db_api=SYNC_DB_API,
host=host,
port=port,
db=db,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
# Start with base kwargs that are valid for all pool types
@@ -231,15 +238,19 @@ class SqlEngine:
def init_engine(cls, **engine_kwargs: Any) -> None:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
cls._engine = cls._init_engine(
host=engine_kwargs.get("host", POSTGRES_HOST),
port=engine_kwargs.get("port", POSTGRES_PORT),
db=engine_kwargs.get("db", POSTGRES_DB),
**engine_kwargs,
)
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
cls.init_engine()
return cls._engine # type: ignore
@classmethod
def set_app_name(cls, app_name: str) -> None:

View File

@@ -6,6 +6,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
def get_default_document_index(
@@ -23,14 +24,27 @@ def get_default_document_index(
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# modify index names for integration tests so that we can run many tests
# using the same Vespa instance w/o having them collide
primary_index_name = search_settings.index_name
if VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY:
primary_index_name = (
f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{primary_index_name}"
)
if secondary_index_name:
secondary_index_name = f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{secondary_index_name}"
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
preserve_existing_indices=bool(
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
),
)
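
The factory now namespaces Vespa index names when VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY is set and flags the index with preserve_existing_indices so the application package is not redeployed, which is what lets many parallel integration-test runs share one Vespa instance without colliding. Illustratively (prefix and index name are placeholders):

# Illustrative only: how the prefix rewrites the index names used by VespaIndex.
prefix = "it_run_42"  # value of VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
primary_index_name = "danswer_chunk_example_model"  # from current search settings
if prefix:
    primary_index_name = f"{prefix}_{primary_index_name}"
print(primary_index_name)  # -> it_run_42_danswer_chunk_example_model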

View File

@@ -136,6 +136,7 @@ class VespaIndex(DocumentIndex):
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
preserve_existing_indices: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
@@ -161,18 +162,18 @@ class VespaIndex(DocumentIndex):
secondary_index_name
] = secondary_large_chunks_enabled
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
self.preserve_existing_indices = preserve_existing_indices
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
@classmethod
def create_indices(
cls,
indices: list[tuple[str, int, bool]],
application_endpoint: str = VESPA_APPLICATION_ENDPOINT,
) -> None:
"""
Create indices in Vespa based on the passed in configuration(s).
"""
deploy_url = f"{application_endpoint}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
@@ -185,7 +186,7 @@ class VespaIndex(DocumentIndex):
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
schema_names = [index_name for (index_name, _, _) in indices]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
@@ -193,14 +194,6 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
@@ -221,29 +214,63 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
for index_name, index_embedding_dim, needs_reindexing in indices:
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
logger.info(
f"Creating index: {index_name} with embedding "
f"dimension: {index_embedding_dim}. Schema:\n\n {schema}"
)
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
logger.error(f"Failed to create Vespa indices: {response.text}")
raise RuntimeError(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if self.multitenant or MULTI_TENANT: # be extra safe here
logger.info(
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
# Used in IT
# NOTE: this means that we can't switch embedding models
if self.preserve_existing_indices:
logger.info("Preserving existing indices")
return None
kv_store = get_kv_store()
primary_needs_reindexing = False
try:
primary_needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
indices = [
(self.index_name, index_embedding_dim, primary_needs_reindexing),
]
if self.secondary_index_name and secondary_index_embedding_dim:
indices.append(
(self.secondary_index_name, secondary_index_embedding_dim, False)
)
self.create_indices(indices)
@staticmethod
def register_multitenant_indices(
indices: list[str],

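For reference, the new `create_indices` classmethod deploys several schemas in a single Vespa application package, which is what the parallel integration-test kickoff relies on. A minimal sketch of a caller, assuming Vespa's config endpoint is exposed on the default tenant port (the index names and dimension below are illustrative, not taken from a real deployment):

```python
from onyx.document_index.vespa.index import VespaIndex

# Each entry is (index_name, embedding_dimension, needs_reindexing);
# the names and port below are illustrative only.
indices = [
    ("test_instance_1_danswer_chunk_nomic_ai_nomic_embed_text_v1", 768, False),
    ("test_instance_2_danswer_chunk_nomic_ai_nomic_embed_text_v1", 768, False),
]

VespaIndex.create_indices(
    indices,
    application_endpoint="http://localhost:19071/application/v2",
)
```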
View File

@@ -409,6 +409,10 @@ class DefaultMultiLLM(LLM):
self._record_call(processed_prompt)
try:
print(
"model is",
f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
)
return litellm.completion(
mock_response=MOCK_LLM_RESPONSE,
# model choice

View File

@@ -43,6 +43,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.configs.integration_test_configs import SKIP_CONNECTION_POOL_WARM_UP
from onyx.db.engine import SqlEngine
from onyx.db.engine import warm_up_connections
from onyx.server.api_key.api import router as api_key_router
@@ -208,8 +209,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if DISABLE_GENERATIVE_AI:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
await warm_up_connections()
# only used for IT. Need to skip since it overloads postgres when we have 50+
# instances running
if not SKIP_CONNECTION_POOL_WARM_UP:
# fill up Postgres connection pools
await warm_up_connections()
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry

View File

@@ -146,10 +146,10 @@ class RedisPool:
cls._instance._init_pools()
return cls._instance
def _init_pools(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def _init_pools(self, redis_port: int = REDIS_PORT) -> None:
self._pool = RedisPool.create_pool(port=redis_port, ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
host=REDIS_REPLICA_HOST, port=redis_port, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:

View File

@@ -213,6 +213,8 @@ def get_chat_session(
# we need the tool call objs anyways, so just fetch them in a single call
prefetch_tool_calls=True,
)
for message in session_messages:
translate_db_message_to_chat_message_detail(message)
return ChatSessionDetailResponse(
chat_session_id=session_id,

View File

@@ -251,7 +251,8 @@ def setup_vespa(
logger.notice("Vespa setup complete.")
return True
except Exception:
except Exception as e:
logger.debug(f"Error creating Vespa indices: {e}")
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)

View File

@@ -58,7 +58,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"
class SearchResponseSummary(SearchQueryInfo):
@@ -180,12 +179,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
"parameters": {
"type": "object",
"properties": {
QUERY_FIELD: {
"query": {
"type": "string",
"description": "What to search for",
},
},
"required": [QUERY_FIELD],
"required": ["query"],
},
},
}
@@ -224,7 +223,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
rephrased_query = history_based_query_rephrase(
query=query, history=history, llm=llm
)
return {QUERY_FIELD: rephrased_query}
return {"query": rephrased_query}
"""Actual tool execution"""
@@ -280,7 +279,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
def run(
self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
) -> Generator[ToolResponse, None, None]:
query = cast(str, llm_kwargs[QUERY_FIELD])
query = cast(str, llm_kwargs["query"])
force_no_rerank = False
alternate_db_session = None
retrieved_sections_callback = None

View File

@@ -1,4 +1,3 @@
import threading
import uuid
from collections.abc import Callable
from concurrent.futures import as_completed
@@ -14,10 +13,6 @@ logger = setup_logger()
R = TypeVar("R")
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
functions_with_args: list[tuple[Callable, tuple]],
allow_failures: bool = False,
@@ -83,10 +78,6 @@ class FunctionCall(Generic[R]):
return self.func(*self.args, **self.kwargs)
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
function_calls: list[FunctionCall],
allow_failures: bool = False,
@@ -118,49 +109,3 @@ def run_functions_in_parallel(
raise
return results
class TimeoutThread(threading.Thread):
def __init__(
self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
):
super().__init__()
self.timeout = timeout
self.func = func
self.args = args
self.kwargs = kwargs
self.exception: Exception | None = None
def run(self) -> None:
try:
self.result = self.func(*self.args, **self.kwargs)
except Exception as e:
self.exception = e
def end(self) -> None:
raise TimeoutError(
f"Function {self.func.__name__} timed out after {self.timeout} seconds"
)
# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
"""
Executes a function with a timeout. If the function doesn't complete within the specified
timeout, raises TimeoutError.
"""
task = TimeoutThread(timeout, func, *args, **kwargs)
task.start()
task.join(timeout)
if task.exception is not None:
raise task.exception
if task.is_alive():
task.end()
return task.result

View File

@@ -270,3 +270,10 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
"""
INTEGRATION TEST ONLY SETTINGS
"""
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY = os.getenv(
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY"
)
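To make the isolation concrete: each parallel test instance exports its own prefix, and the document index layer prepends it to every schema name, so instances never touch each other's Vespa documents. A small sketch of the effect (the prefix and base schema name are illustrative):

```python
import os

# Hypothetical per-instance setup; in practice the kickoff script sets this
# env var before starting each API server / background worker.
os.environ["VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY"] = "test_instance_3"

prefix = os.getenv("VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY")
index_name = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
if prefix:
    # mirrors the prefixing done in get_default_document_index
    index_name = f"{prefix}_{index_name}"

print(index_name)  # test_instance_3_danswer_chunk_nomic_ai_nomic_embed_text_v1
```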

View File

@@ -3,6 +3,19 @@ FROM python:3.11.7-slim-bookworm
# Currently needs all dependencies, since the ITs use some of the Onyx
# backend code.
# Add Docker's official GPG key and repository for Debian
RUN apt-get update && \
apt-get install -y ca-certificates curl && \
install -m 0755 -d /etc/apt/keyrings && \
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
chmod a+r /etc/apt/keyrings/docker.asc && \
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
tee /etc/apt/sources.list.d/docker.list > /dev/null && \
apt-get update
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -15,6 +28,9 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
postgresql-client \
# Install Docker for DinD
docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
@@ -29,37 +45,19 @@ RUN apt-get update && \
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/model_server.txt /tmp/model_server-requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/model_server-requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Set up application files
WORKDIR /app
@@ -76,6 +74,9 @@ COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# need to copy over model server as well, since we're running it in the same container
COPY ./model_server /app/model_server
# Integration test stuff
COPY ./requirements/dev.txt /tmp/dev-requirements.txt
RUN pip install --no-cache-dir --upgrade \
@@ -84,5 +85,6 @@ COPY ./tests/integration /app/tests/integration
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
ENTRYPOINT []
# let caller specify the command
CMD ["tail", "-f", "/dev/null"]

View File

@@ -1,6 +1,7 @@
import os
ADMIN_USER_NAME = "admin_user"
GUARANTEED_FRESH_SETUP = os.getenv("GUARANTEED_FRESH_SETUP") == "true"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"

View File

@@ -1,5 +1,9 @@
import contextlib
import io
import logging
import sys
import time
from logging import Logger
from types import SimpleNamespace
import psycopg2
@@ -11,10 +15,12 @@ from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
@@ -22,17 +28,20 @@ from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import redis_pool
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout
logger = setup_logger()
def _run_migrations(
database_url: str,
database: str,
config_name: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
direction: str = "upgrade",
revision: str = "head",
schema: str = "public",
@@ -46,9 +55,28 @@ def _run_migrations(
alembic_cfg.attributes["configure_logger"] = False
alembic_cfg.config_ini_section = config_name
# Add environment variables to the config attributes
alembic_cfg.attributes["env_vars"] = {
"POSTGRES_HOST": postgres_host,
"POSTGRES_PORT": postgres_port,
"POSTGRES_DB": database,
# some migrations call redis directly, so we need to pass the port
"REDIS_PORT": str(redis_port),
}
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema}"] # type: ignore
# Build the database URL
database_url = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
)
# Set the SQLAlchemy URL in the Alembic configuration
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
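Note that the helper only stashes these values on `alembic_cfg.attributes`; it assumes the project's Alembic `env.py` reads them back before configuring the migration run. A rough sketch of what that consumer side could look like (the attribute name is the only real contract here, the rest is illustrative):

```python
# Hypothetical excerpt of an Alembic env.py; the real file may differ.
import os

from alembic import context

# Pick up per-instance connection details passed in by the test harness.
env_vars = context.config.attributes.get("env_vars", {})
os.environ.update(env_vars)  # e.g. POSTGRES_HOST/PORT/DB and REDIS_PORT
```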
@@ -71,6 +99,9 @@ def downgrade_postgres(
config_name: str = "alembic",
revision: str = "base",
clear_data: bool = False,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Downgrade Postgres database to base state."""
if clear_data:
@@ -81,8 +112,8 @@ def downgrade_postgres(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
host=postgres_host,
port=postgres_port,
)
conn.autocommit = True # Need autocommit for dropping schema
cur = conn.cursor()
@@ -112,37 +143,32 @@ def downgrade_postgres(
return
# Downgrade to base
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="downgrade",
revision=revision,
)
def upgrade_postgres(
database: str = "postgres", config_name: str = "alembic", revision: str = "head"
database: str = "postgres",
config_name: str = "alembic",
revision: str = "head",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Upgrade Postgres database to latest version."""
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="upgrade",
revision=revision,
)
@@ -152,46 +178,44 @@ def reset_postgres(
database: str = "postgres",
config_name: str = "alembic",
setup_onyx: bool = True,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Reset the Postgres database."""
# this seems to hang due to locking issues, so run with a timeout with a few retries
NUM_TRIES = 10
TIMEOUT = 10
success = False
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout(
downgrade_postgres,
TIMEOUT,
kwargs={
"database": database,
"config_name": config_name,
"revision": "base",
"clear_data": True,
},
)
success = True
break
except TimeoutError:
logger.warning(
f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
)
if not success:
raise RuntimeError("Postgres downgrade failed after 10 timeouts.")
logger.info("Upgrading Postgres...")
upgrade_postgres(database=database, config_name=config_name, revision="head")
downgrade_postgres(
database=database,
config_name=config_name,
revision="base",
clear_data=True,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
upgrade_postgres(
database=database,
config_name=config_name,
revision="head",
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
if setup_onyx:
logger.info("Setting up Postgres...")
with get_session_context_manager() as db_session:
setup_postgres(db_session)
def reset_vespa() -> None:
"""Wipe all data from the Vespa index."""
def reset_vespa(
skip_creating_indices: bool, document_id_endpoint: str = DOCUMENT_ID_ENDPOINT
) -> None:
"""Wipe all data from the Vespa index.
Args:
skip_creating_indices: If True, the indices will not be recreated.
This is useful if the indices already exist and you do not want to
recreate them (e.g. when running parallel tests).
"""
with get_session_context_manager() as db_session:
# swap to the correct default model
check_index_swap(db_session)
@@ -200,18 +224,21 @@ def reset_vespa() -> None:
multipass_config = get_multipass_config(search_settings)
index_name = search_settings.index_name
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError("Could not connect to Vespa within the specified timeout.")
if not skip_creating_indices:
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError(
"Could not connect to Vespa within the specified timeout."
)
for _ in range(5):
try:
@@ -222,7 +249,7 @@ def reset_vespa() -> None:
if continuation:
params = {**params, "continuation": continuation}
response = requests.delete(
DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
document_id_endpoint.format(index_name=index_name), params=params
)
response.raise_for_status()
@@ -336,11 +363,99 @@ def reset_vespa_multitenant() -> None:
time.sleep(5)
def reset_all() -> None:
def reset_all(
database: str = "postgres",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
silence_logs: bool = False,
skip_creating_indices: bool = False,
document_id_endpoint: str = DOCUMENT_ID_ENDPOINT,
) -> None:
if not silence_logs:
with contextlib.redirect_stdout(sys.stdout), contextlib.redirect_stderr(
sys.stderr
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
return
# Store original logging levels
loggers_to_silence: list[Logger] = [
logging.getLogger(), # Root logger
logging.getLogger("alembic"),
logger.logger, # Our custom logger
]
original_levels = [logger.level for logger in loggers_to_silence]
# Temporarily set all loggers to ERROR level
for log in loggers_to_silence:
log.setLevel(logging.ERROR)
stdout_redirect = io.StringIO()
stderr_redirect = io.StringIO()
try:
with contextlib.redirect_stdout(stdout_redirect), contextlib.redirect_stderr(
stderr_redirect
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
except Exception as e:
print(stdout_redirect.getvalue(), file=sys.stdout)
print(stderr_redirect.getvalue(), file=sys.stderr)
raise e
finally:
# Restore original logging levels
for logger_, level in zip(loggers_to_silence, original_levels):
logger_.setLevel(level)
def _do_reset(
database: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
skip_creating_indices: bool,
document_id_endpoint: str,
) -> None:
"""NOTE: should only be be running in one worker/thread a time."""
# force re-create the engine to allow for the same worker to reset
# different databases
with SqlEngine._lock:
SqlEngine._engine = SqlEngine._init_engine(
host=postgres_host,
port=postgres_port,
db=database,
)
# same with redis
redis_pool._init_pools(redis_port=redis_port)
logger.info("Resetting Postgres...")
reset_postgres()
reset_postgres(
database=database,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
logger.info("Resetting Vespa...")
reset_vespa()
reset_vespa(
skip_creating_indices=skip_creating_indices,
document_id_endpoint=document_id_endpoint,
)
def reset_all_multitenant() -> None:

View File

@@ -7,6 +7,7 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import GUARANTEED_FRESH_SETUP
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
@@ -14,25 +15,11 @@ from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
from tests.integration.introspection import load_env_vars
# Load environment variables at the module level
load_env_vars()
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
"""NOTE: for some reason using this seems to lead to misc
@@ -57,6 +44,10 @@ def vespa_client() -> vespa_fixture:
@pytest.fixture
def reset() -> None:
if GUARANTEED_FRESH_SETUP:
print("GUARANTEED_FRESH_SETUP is true, skipping reset")
return None
reset_all()

View File

@@ -118,6 +118,7 @@ def test_google_permission_sync(
GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
],
) -> None:
print("Running test_google_permission_sync")
(
drive_service,
drive_id,

View File

@@ -0,0 +1,71 @@
import os
from pathlib import Path
import pytest
from _pytest.nodes import Item
def list_all_tests(directory: str | Path = ".") -> list[str]:
"""
List all pytest test functions under the specified directory.
Args:
directory: Directory path to search for tests (defaults to current directory)
Returns:
List of test function names with their module paths
"""
directory = Path(directory).absolute()
print(f"Searching for tests in: {directory}")
class TestCollector:
def __init__(self) -> None:
self.collected: list[str] = []
def pytest_collection_modifyitems(self, items: list[Item]) -> None:
for item in items:
if isinstance(item, Item):
# Get the relative path from the test file to the directory we're searching from
rel_path = Path(item.fspath).relative_to(directory)
# Remove the .py extension
module_path = str(rel_path.with_suffix(""))
# Replace directory separators with dots
module_path = module_path.replace("/", ".")
test_name = item.name
self.collected.append(f"{module_path}::{test_name}")
collector = TestCollector()
# Run pytest in collection-only mode
pytest.main(
[
str(directory),
"--collect-only",
"-q", # quiet mode
],
plugins=[collector],
)
return sorted(collector.collected)
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
if __name__ == "__main__":
tests = list_all_tests()
print("\nFound tests:")
for test in tests:
print(f"- {test}")

View File

@@ -0,0 +1,637 @@
#!/usr/bin/env python3
import atexit
import os
import random
import signal
import socket
import subprocess
import sys
import time
import uuid
from collections.abc import Callable
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import NamedTuple
import requests
import yaml
from onyx.configs.constants import AuthType
from onyx.document_index.vespa.index import VespaIndex
BACKEND_DIR_PATH = Path(__file__).parent.parent.parent
COMPOSE_DIR_PATH = BACKEND_DIR_PATH.parent / "deployment/docker_compose"
DEFAULT_EMBEDDING_DIMENSION = 768
DEFAULT_SCHEMA_NAME = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
class DeploymentConfig(NamedTuple):
instance_num: int
api_port: int
web_port: int
nginx_port: int
redis_port: int
postgres_db: str
class SharedServicesConfig(NamedTuple):
run_id: uuid.UUID
postgres_port: int
vespa_port: int
vespa_tenant_port: int
model_server_port: int
def get_random_port() -> int:
"""Find a random available port."""
while True:
port = random.randint(10000, 65535)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex(("localhost", port)) != 0:
return port
def cleanup_pid(pid: int) -> None:
"""Cleanup a specific PID."""
print(f"Killing process {pid}")
try:
os.kill(pid, signal.SIGTERM)
except ProcessLookupError:
print(f"Process {pid} not found")
def get_shared_services_stack_name(run_id: uuid.UUID) -> str:
return f"base-onyx-{run_id}"
def get_db_name(instance_num: int) -> str:
"""Get the database name for a given instance number."""
return f"onyx_{instance_num}"
def get_vector_db_prefix(instance_num: int) -> str:
"""Get the vector DB prefix for a given instance number."""
return f"test_instance_{instance_num}"
def setup_db(
instance_num: int,
postgres_port: int,
) -> str:
env = os.environ.copy()
# Wait for postgres to be ready
max_attempts = 10
for attempt in range(max_attempts):
try:
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
"SELECT 1",
],
env={**env, "PGPASSWORD": "password"},
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise RuntimeError("Postgres failed to become ready within timeout")
time.sleep(1)
db_name = get_db_name(instance_num)
# Create the database first
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
f"CREATE DATABASE {db_name}",
],
env={**env, "PGPASSWORD": "password"},
check=True,
)
# NEW: Stamp this brand-new DB at 'base' so Alembic doesn't fail
subprocess.run(
[
"alembic",
"stamp",
"base",
],
env={
**env,
"PGPASSWORD": "password",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
# Run alembic upgrade to create tables
max_attempts = 3
for attempt in range(max_attempts):
try:
subprocess.run(
["alembic", "upgrade", "head"],
env={
**env,
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise
print("Alembic upgrade failed, retrying in 5 seconds...")
time.sleep(5)
return db_name
def start_api_server(
instance_num: int,
model_server_port: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start the API server.
NOTE: assumes that Postgres is all set up (database exists, migrations ran)
"""
print("Starting API server...")
db_name = get_db_name(instance_num)
vector_db_prefix = get_vector_db_prefix(instance_num)
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"MODEL_SERVER_PORT": str(model_server_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": vector_db_prefix,
"LOG_LEVEL": "debug",
"AUTH_TYPE": AuthType.BASIC,
}
)
port = get_random_port()
# Open log file for API server in /tmp
log_file = open(f"/tmp/api_server_{instance_num}.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"onyx.main:app",
"--host",
"localhost",
"--port",
str(port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
return port
def start_background(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> None:
"""Start the background process."""
print("Starting background process...")
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": get_db_name(instance_num),
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": get_vector_db_prefix(
instance_num
),
"LOG_LEVEL": "debug",
}
)
str(Path(__file__).parent / "backend")
# Open log file for background process in /tmp
log_file = open(f"/tmp/background_{instance_num}.txt", "w")
process = subprocess.Popen(
["supervisord", "-n", "-c", "./supervisord.conf"],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
def start_shared_services(run_id: uuid.UUID) -> SharedServicesConfig:
"""Start Postgres and Vespa using docker-compose.
Returns a SharedServicesConfig: (run_id, postgres_port, vespa_port, vespa_tenant_port, model_server_port)
"""
print("Starting database services...")
postgres_port = get_random_port()
vespa_port = get_random_port()
vespa_tenant_port = get_random_port()
model_server_port = get_random_port()
minimal_compose = {
"services": {
"relational_db": {
"image": "postgres:15.2-alpine",
"command": "-c 'max_connections=1000'",
"environment": {
"POSTGRES_USER": os.getenv("POSTGRES_USER", "postgres"),
"POSTGRES_PASSWORD": os.getenv("POSTGRES_PASSWORD", "password"),
},
"ports": [f"{postgres_port}:5432"],
},
"index": {
"image": "vespaengine/vespa:8.277.17",
"ports": [
f"{vespa_port}:8081", # Main Vespa port
f"{vespa_tenant_port}:19071", # Tenant port
],
},
},
}
# Write the minimal compose file
temp_compose = Path("/tmp/docker-compose.minimal.yml")
with open(temp_compose, "w") as f:
yaml.dump(minimal_compose, f)
# Start the services
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
get_shared_services_stack_name(run_id),
"up",
"-d",
],
check=True,
)
# Start the shared model server
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"LOG_LEVEL": "debug",
}
)
# Open log file for shared model server in /tmp
log_file = open("/tmp/shared_model_server.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"model_server.main:app",
"--host",
"0.0.0.0",
"--port",
str(model_server_port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
atexit.register(cleanup_pid, process.pid)
shared_services_config = SharedServicesConfig(
run_id, postgres_port, vespa_port, vespa_tenant_port, model_server_port
)
print(f"Shared services config: {shared_services_config}")
return shared_services_config
def prepare_vespa(instance_ids: list[int], vespa_tenant_port: int) -> None:
schema_names = [
(
f"{get_vector_db_prefix(instance_id)}_{DEFAULT_SCHEMA_NAME}",
DEFAULT_EMBEDDING_DIMENSION,
False,
)
for instance_id in instance_ids
]
print(f"Creating indices: {schema_names}")
for _ in range(7):
try:
VespaIndex.create_indices(
schema_names, f"http://localhost:{vespa_tenant_port}/application/v2"
)
return
except Exception as e:
print(f"Error creating indices: {e}. Trying again in 5 seconds...")
time.sleep(5)
raise RuntimeError("Failed to create indices in Vespa")
def start_redis(
instance_num: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start a Redis instance for a specific deployment."""
print(f"Starting Redis for instance {instance_num}...")
redis_port = get_random_port()
container_name = f"redis-onyx-{instance_num}"
# Start Redis using docker run
subprocess.run(
[
"docker",
"run",
"-d",
"--name",
container_name,
"-p",
f"{redis_port}:6379",
"redis:7.4-alpine",
"redis-server",
"--save",
'""',
"--appendonly",
"no",
],
check=True,
)
return redis_port
def launch_instance(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
model_server_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> DeploymentConfig:
"""Launch a Docker Compose instance with custom ports."""
api_port = get_random_port()
web_port = get_random_port()
nginx_port = get_random_port()
# Start Redis for this instance
redis_port = start_redis(instance_num, register_process)
try:
db_name = setup_db(instance_num, postgres_port)
api_port = start_api_server(
instance_num,
model_server_port, # Use shared model server port
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
start_background(
instance_num,
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
except Exception as e:
print(f"Failed to start API server for instance {instance_num}: {e}")
raise
return DeploymentConfig(
instance_num, api_port, web_port, nginx_port, redis_port, db_name
)
def wait_for_instance(
ports: DeploymentConfig, max_attempts: int = 120, wait_seconds: int = 2
) -> None:
"""Wait for an instance to be healthy."""
print(f"Waiting for instance {ports.instance_num} to be ready...")
for attempt in range(1, max_attempts + 1):
try:
response = requests.get(f"http://localhost:{ports.api_port}/health")
if response.status_code == 200:
print(
f"Instance {ports.instance_num} is ready on port {ports.api_port}"
)
return
raise ConnectionError(
f"Health check returned status {response.status_code}"
)
except (requests.RequestException, ConnectionError):
if attempt == max_attempts:
raise TimeoutError(
f"Timeout waiting for instance {ports.instance_num} "
f"on port {ports.api_port}"
)
print(
f"Waiting for instance {ports.instance_num} on port "
f" {ports.api_port}... ({attempt}/{max_attempts})"
)
time.sleep(wait_seconds)
def cleanup_instance(instance_num: int) -> None:
"""Cleanup a specific instance."""
print(f"Cleaning up instance {instance_num}...")
temp_compose = Path(f"/tmp/docker-compose.dev.instance{instance_num}.yml")
try:
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
f"onyx-stack-{instance_num}",
"down",
],
check=True,
)
print(f"Instance {instance_num} cleaned up successfully")
except subprocess.CalledProcessError:
print(f"Error cleaning up instance {instance_num}")
except FileNotFoundError:
print(f"No compose file found for instance {instance_num}")
finally:
# Clean up the temporary compose file if it exists
if temp_compose.exists():
temp_compose.unlink()
print(f"Removed temporary compose file for instance {instance_num}")
def run_x_instances(
num_instances: int,
) -> tuple[SharedServicesConfig, list[DeploymentConfig]]:
"""Start x instances of the application and return their configurations."""
run_id = uuid.uuid4()
instance_ids = list(range(1, num_instances + 1))
_pids: list[int] = []
def register_process(process: subprocess.Popen) -> None:
_pids.append(process.pid)
def cleanup_all_instances() -> None:
"""Cleanup all instances."""
print("Cleaning up all instances...")
# Stop the database services
subprocess.run(
[
"docker",
"compose",
"-p",
get_shared_services_stack_name(run_id),
"-f",
"/tmp/docker-compose.minimal.yml",
"down",
],
check=True,
)
# Stop and remove all Redis containers
for instance_id in range(1, num_instances + 1):
container_name = f"redis-onyx-{instance_id}"
try:
subprocess.run(["docker", "rm", "-f", container_name], check=True)
except subprocess.CalledProcessError:
print(f"Error cleaning up Redis container {container_name}")
for pid in _pids:
cleanup_pid(pid)
# Register cleanup handler
atexit.register(cleanup_all_instances)
# Start database services first
print("Starting shared services...")
shared_services_config = start_shared_services(run_id)
# create the per-instance Vespa indices
print("Creating indices in Vespa...")
prepare_vespa(instance_ids, shared_services_config.vespa_tenant_port)
# Use ThreadPool to launch instances in parallel and collect results
# NOTE: only kick off 10 at a time to avoid overwhelming the system
print("Launching instances...")
with ThreadPool(processes=len(instance_ids)) as pool:
# Create list of arguments for each instance
launch_args = [
(
i,
shared_services_config.postgres_port,
shared_services_config.vespa_port,
shared_services_config.vespa_tenant_port,
shared_services_config.model_server_port,
register_process,
)
for i in instance_ids
]
# Launch instances and get results
port_configs = pool.starmap(launch_instance, launch_args)
# Wait for all instances to be healthy
print("Waiting for instances to be healthy...")
with ThreadPool(processes=len(port_configs)) as pool:
pool.map(wait_for_instance, port_configs)
print("All instances launched!")
print("Database Services:")
print(f"Postgres port: {shared_services_config.postgres_port}")
print(f"Vespa main port: {shared_services_config.vespa_port}")
print(f"Vespa tenant port: {shared_services_config.vespa_tenant_port}")
print("\nApplication Instances:")
for ports in port_configs:
print(
f"Instance {ports.instance_num}: "
f"API={ports.api_port}, Web={ports.web_port}, Nginx={ports.nginx_port}"
)
return shared_services_config, port_configs
def main() -> None:
shared_services_config, port_configs = run_x_instances(1)
# Run pytest with the API server port set
api_port = port_configs[0].api_port # Use first instance's API port
try:
subprocess.run(
["pytest", "tests/integration/openai_assistants_api"],
env={**os.environ, "API_SERVER_PORT": str(api_port)},
cwd=str(BACKEND_DIR_PATH),
check=True,
)
except subprocess.CalledProcessError as e:
print(f"Tests failed with exit code {e.returncode}")
sys.exit(e.returncode)
time.sleep(5)
if __name__ == "__main__":
main()

View File

@@ -0,0 +1,266 @@
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from pathlib import Path
from tests.integration.common_utils.reset import reset_all
from tests.integration.introspection import list_all_tests
from tests.integration.introspection import load_env_vars
from tests.integration.kickoff import BACKEND_DIR_PATH
from tests.integration.kickoff import DeploymentConfig
from tests.integration.kickoff import run_x_instances
from tests.integration.kickoff import SharedServicesConfig
@dataclass
class TestResult:
test_name: str
success: bool
output: str
error: str | None = None
def run_single_test(
test_name: str,
deployment_config: DeploymentConfig,
shared_services_config: SharedServicesConfig,
result_queue: multiprocessing.Queue,
) -> None:
"""Run a single test with the given API port."""
test_path, test_name = test_name.split("::")
processed_test_name = f"{test_path.replace('.', '/')}.py::{test_name}"
print(f"Running test: {processed_test_name}", flush=True)
try:
env = {
**os.environ,
"API_SERVER_PORT": str(deployment_config.api_port),
"PYTHONPATH": ".",
"GUARANTEED_FRESH_SETUP": "true",
"POSTGRES_PORT": str(shared_services_config.postgres_port),
"POSTGRES_DB": deployment_config.postgres_db,
"REDIS_PORT": str(deployment_config.redis_port),
"VESPA_PORT": str(shared_services_config.vespa_port),
"VESPA_TENANT_PORT": str(shared_services_config.vespa_tenant_port),
}
result = subprocess.run(
["pytest", processed_test_name, "-v"],
env=env,
cwd=str(BACKEND_DIR_PATH),
capture_output=True,
text=True,
)
result_queue.put(
TestResult(
test_name=test_name,
success=result.returncode == 0,
output=result.stdout,
error=result.stderr if result.returncode != 0 else None,
)
)
except Exception as e:
result_queue.put(
TestResult(
test_name=test_name,
success=False,
output="",
error=str(e),
)
)
def worker(
test_queue: queue.Queue[str],
instance_queue: queue.Queue[int],
result_queue: multiprocessing.Queue,
shared_services_config: SharedServicesConfig,
deployment_configs: list[DeploymentConfig],
reset_lock: LockType,
) -> None:
"""Worker process that runs tests on available instances."""
while True:
# Get the next test from the queue
try:
test = test_queue.get(block=False)
except queue.Empty:
break
# Get an available instance
instance_idx = instance_queue.get()
deployment_config = deployment_configs[
instance_idx - 1
] # Convert to 0-based index
try:
# Run the test
run_single_test(
test, deployment_config, shared_services_config, result_queue
)
# get instance ready for next test
print(
f"Resetting instance for next. DB: {deployment_config.postgres_db}, "
f"Port: {shared_services_config.postgres_port}"
)
# alembic is NOT thread-safe, so we need to make sure only one worker is resetting at a time
with reset_lock:
reset_all(
database=deployment_config.postgres_db,
postgres_port=str(shared_services_config.postgres_port),
redis_port=deployment_config.redis_port,
silence_logs=True,
# indices are created during the kickoff process, no need to recreate them
skip_creating_indices=True,
# use the special vespa port
document_id_endpoint=(
f"http://localhost:{shared_services_config.vespa_port}"
"/document/v1/default/{{index_name}}/docid"
),
)
except Exception as e:
# Log the error and put it in the result queue
error_msg = f"Critical error in worker thread for test {test}: {str(e)}"
print(error_msg, file=sys.stderr)
result_queue.put(
TestResult(
test_name=test,
success=False,
output="",
error=error_msg,
)
)
# Re-raise to stop the worker
raise
finally:
# Put the instance back in the queue
instance_queue.put(instance_idx)
test_queue.task_done()
def main() -> None:
NUM_INSTANCES = 7
# Get all tests
prefixes = ["tests", "connector_job_tests"]
tests = []
for prefix in prefixes:
tests += [
f"tests/integration/{prefix}/{test_path}"
for test_path in list_all_tests(Path(__file__).parent / prefix)
]
print(f"Found {len(tests)} tests to run")
# load env vars which will be passed into the tests
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
# For debugging
# tests = [test for test in tests if "openai_assistants_api" in test]
# tests = tests[:2]
print(f"Running {len(tests)} tests")
# Start all instances at once
shared_services_config, deployment_configs = run_x_instances(NUM_INSTANCES)
# Create queues and lock
test_queue: queue.Queue[str] = queue.Queue()
instance_queue: queue.Queue[int] = queue.Queue()
result_queue: multiprocessing.Queue = multiprocessing.Queue()
reset_lock: LockType = multiprocessing.Lock()
# Fill the instance queue with available instance numbers
for i in range(1, NUM_INSTANCES + 1):
instance_queue.put(i)
# Fill the test queue with all tests
for test in tests:
test_queue.put(test)
# Start worker threads
workers = []
for _ in range(NUM_INSTANCES):
worker_thread = threading.Thread(
target=worker,
args=(
test_queue,
instance_queue,
result_queue,
shared_services_config,
deployment_configs,
reset_lock,
),
)
worker_thread.start()
workers.append(worker_thread)
# Monitor workers and fail fast if any die
try:
while any(w.is_alive() for w in workers):
# Check if all tests are done
if test_queue.empty() and all(not w.is_alive() for w in workers):
break
# Check for dead workers that died with unfinished tests
if not test_queue.empty() and any(not w.is_alive() for w in workers):
print(
"\nCritical: Worker thread(s) died with tests remaining!",
file=sys.stderr,
)
sys.exit(1)
time.sleep(0.1) # Avoid busy waiting
# Collect results
print("Collecting results")
results: list[TestResult] = []
while not result_queue.empty():
results.append(result_queue.get())
# Print results
print("\nTest Results:")
failed = False
failed_tests: list[str] = []
total_tests = len(results)
passed_tests = 0
for result in results:
status = "✅ PASSED" if result.success else "❌ FAILED"
print(f"{status} - {result.test_name}")
if result.success:
passed_tests += 1
else:
failed = True
failed_tests.append(result.test_name)
print("Error output:")
print(result.error)
print("Test output:")
print(result.output)
print("-" * 80)
# Print summary
print("\nTest Summary:")
print(f"Total Tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {len(failed_tests)}")
if failed_tests:
print("\nFailed Tests:")
for test_name in failed_tests:
print(f"{test_name}")
print()
if failed:
sys.exit(1)
except KeyboardInterrupt:
print("\nTest run interrupted by user", file=sys.stderr)
sys.exit(130) # Standard exit code for SIGINT
except Exception as e:
print(f"\nCritical error during result collection: {str(e)}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()

View File

@@ -11,7 +11,6 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session
from onyx.chat.answer import Answer
@@ -26,7 +25,6 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@@ -37,7 +35,6 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN
from onyx.tools.tool_implementations.search_like_tool_utils import (
FINAL_CONTEXT_DOCUMENTS_ID,
)
from shared_configs.enums import RerankerProvider
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.onyx.chat.conftest import QUERY
@@ -47,20 +44,6 @@ def answer_instance(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
) -> Answer:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)
def _answer_fixture_impl(
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
rerank_settings: RerankingDetails | None = None,
) -> Answer:
return Answer(
prompt_builder=AnswerPromptBuilder(
@@ -81,13 +64,13 @@ def _answer_fixture_impl(
llm=mock_llm,
fast_llm=mock_llm,
force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings),
search_request=SearchRequest(query=QUERY),
chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
current_agent_message_id=0,
)
def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_basic_answer(answer_instance: Answer) -> None:
mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
mock_llm.stream.return_value = [
AIMessageChunk(content="This is a "),
@@ -380,49 +363,3 @@ def test_is_cancelled(answer_instance: Answer) -> None:
# Verify LLM calls
mock_llm.stream.assert_called_once()
@pytest.mark.parametrize(
"gpu_enabled,is_local_model",
[
(True, False),
(False, True),
(True, True),
(False, False),
],
)
def test_no_slow_reranking(
gpu_enabled: bool,
is_local_model: bool,
mock_llm: LLM,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
return_value=gpu_enabled,
)
rerank_settings = (
None
if is_local_model
else RerankingDetails(
rerank_model_name="test_model",
rerank_api_url="test_url",
rerank_api_key="test_key",
num_rerank=10,
rerank_provider_type=RerankerProvider.COHERE,
)
)
answer_instance = _answer_fixture_impl(
mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
)
assert (
answer_instance.graph_config.inputs.search_request.rerank_settings
== rerank_settings
)
assert (
answer_instance.graph_config.behavior.allow_agent_reranking == gpu_enabled
or not is_local_model
)

View File

@@ -36,12 +36,7 @@ def test_skip_gen_ai_answer_generation_flag(
mock_search_tool: SearchTool,
answer_style_config: AnswerStyleConfig,
prompt_config: PromptConfig,
mocker: MockerFixture,
) -> None:
mocker.patch(
"onyx.chat.answer.gpu_status_request",
return_value=True,
)
question = config["question"]
skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]

View File

@@ -1,61 +0,0 @@
import time
import pytest
from onyx.utils.threadpool_concurrency import run_with_timeout
def test_run_with_timeout_completes() -> None:
"""Test that a function that completes within timeout works correctly"""
def quick_function(x: int) -> int:
return x * 2
result = run_with_timeout(1.0, quick_function, x=21)
assert result == 42
@pytest.mark.parametrize("slow,timeout", [(1, 0.1), (0.3, 0.2)])
def test_run_with_timeout_raises_on_timeout(slow: float, timeout: float) -> None:
"""Test that a function that exceeds timeout raises TimeoutError"""
def slow_function() -> None:
time.sleep(slow) # Sleep for 2 seconds
with pytest.raises(TimeoutError) as exc_info:
start = time.time()
run_with_timeout(timeout, slow_function) # Set timeout to 0.1 seconds
end = time.time()
assert end - start >= timeout
assert end - start < (slow + timeout) / 2
assert f"timed out after {timeout} seconds" in str(exc_info.value)
@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_run_with_timeout_propagates_exceptions() -> None:
"""Test that other exceptions from the function are propagated properly"""
def error_function() -> None:
raise ValueError("Test error")
with pytest.raises(ValueError) as exc_info:
run_with_timeout(1.0, error_function)
assert "Test error" in str(exc_info.value)
def test_run_with_timeout_with_args_and_kwargs() -> None:
"""Test that args and kwargs are properly passed to the function"""
def complex_function(x: int, y: int, multiply: bool = False) -> int:
if multiply:
return x * y
return x + y
# Test with just positional args
result1 = run_with_timeout(1.0, complex_function, x=5, y=3)
assert result1 == 8
# Test with positional and keyword args
result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
assert result2 == 15

View File

@@ -1,11 +1,6 @@
"use client";
import {
redirect,
usePathname,
useRouter,
useSearchParams,
} from "next/navigation";
import { redirect, useRouter, useSearchParams } from "next/navigation";
import {
BackendChatSession,
BackendMessage,
@@ -135,7 +130,6 @@ import {
} from "@/lib/browserUtilities";
import { Button } from "@/components/ui/button";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";
const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -1151,7 +1145,6 @@ export function ChatPage({
regenerationRequest?: RegenerationRequest | null;
overrideFileDescriptors?: FileDescriptor[];
} = {}) => {
navigatingAway.current = false;
let frozenSessionId = currentSessionId();
updateCanContinue(false, frozenSessionId);
@@ -1274,6 +1267,7 @@ export function ChatPage({
let stackTrace: string | null = null;
let sub_questions: SubQuestionDetail[] = [];
let second_level_sub_questions: SubQuestionDetail[] = [];
let is_generating: boolean = false;
let second_level_generating: boolean = false;
let finalMessage: BackendMessage | null = null;
@@ -1297,7 +1291,7 @@ export function ChatPage({
const stack = new CurrentMessageFIFO();
updateCurrentMessageFIFO(stack, {
signal: controller.signal,
signal: controller.signal, // Add this line
message: currMessage,
alternateAssistantId: currentAssistantId,
fileDescriptors: overrideFileDescriptors || currentMessageFiles,
@@ -1718,10 +1712,7 @@ export function ChatPage({
const newUrl = buildChatUrl(searchParams, currChatSessionId, null);
// newUrl is like /chat?chatId=10
// current page is like /chat
if (pathname == "/chat" && !navigatingAway.current) {
router.push(newUrl, { scroll: false });
}
router.push(newUrl, { scroll: false });
}
}
if (
@@ -2095,31 +2086,6 @@ export function ChatPage({
llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
}, [imageFileInMessageHistory]);
const pathname = usePathname();
useEffect(() => {
return () => {
// Cleanup which only runs when the component unmounts (i.e. when you navigate away).
const currentSession = currentSessionId();
const controller = abortControllersRef.current.get(currentSession);
if (controller) {
controller.abort();
navigatingAway.current = true;
setAbortControllers((prev) => {
const newControllers = new Map(prev);
newControllers.delete(currentSession);
return newControllers;
});
}
};
}, [pathname]);
const navigatingAway = useRef(false);
// Keep a ref to abortControllers to ensure we always have the latest value
const abortControllersRef = useRef(abortControllers);
useEffect(() => {
abortControllersRef.current = abortControllers;
}, [abortControllers]);
useSidebarShortcut(router, toggleSidebar);
const [sharedChatSession, setSharedChatSession] =
@@ -2334,7 +2300,7 @@ export function ChatPage({
fixed
left-0
z-40
bg-neutral-200
bg-background-100
h-screen
transition-all
bg-opacity-80
@@ -2591,21 +2557,12 @@ export function ChatPage({
) {
return <></>;
}
const nextMessage =
messageHistory.length > i + 1
? messageHistory[i + 1]
: null;
return (
<div
id={`message-${message.messageId}`}
key={messageReactComponentKey}
>
<HumanMessage
disableSwitchingForStreaming={
(nextMessage &&
nextMessage.is_generating) ||
false
}
stopGenerating={stopGenerating}
content={message.message}
files={message.files}
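
This hunk drops the look-ahead that disabled message switching while the following assistant message was still streaming. A rough sketch of that look-ahead in isolation (the Msg shape is hypothetical; the real message type lives elsewhere in the repo):

interface Msg {
  messageId: number;
  is_generating?: boolean;
}

// Switching on message i is disabled while the reply that follows it is still generating.
function switchingDisabled(history: Msg[], i: number): boolean {
  const next = history.length > i + 1 ? history[i + 1] : null;
  return (next && next.is_generating) || false;
}
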

View File

@@ -94,7 +94,7 @@ export function AgenticToggle({
Agent Search (BETA)
</h3>
</div>
<p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
<p className="text-xs text-neutarl-600 dark:text-neutral-700 mb-2">
Use AI agents to break down questions and run deep iterative
research through promising pathways. Gives more thorough and
accurate responses but takes slightly longer.

View File

@@ -113,7 +113,7 @@ export default function LLMPopover({
<Popover open={isOpen} onOpenChange={setIsOpen}>
<PopoverTrigger asChild>
<button
className="dark:text-[#fff] text-[#000] focus:outline-none"
className="focus:outline-none"
data-testid="llm-popover-trigger"
>
<ChatInputOption

View File

@@ -250,7 +250,7 @@ export async function* sendMessage({
throw new Error(`HTTP error! status: ${response.status}`);
}
yield* handleSSEStream<PacketType>(response, signal);
yield* handleSSEStream<PacketType>(response);
}
export async function nameChatSession(chatSessionId: string) {
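
For context, sendMessage is an async generator, so callers consume its SSE packets with for-await. A minimal sketch of that consumption loop (drain and the logging are illustrative, not the ChatPage code; whether a signal is still threaded through after this change is not shown here):

// Drain an async generator of streamed packets, one at a time.
async function drain(packets: AsyncGenerator<unknown, void, unknown>) {
  for await (const packet of packets) {
    console.log("received packet", packet); // real callers would update chat state here
  }
}
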

View File

@@ -9,12 +9,6 @@ import React, {
useMemo,
useState,
} from "react";
import {
Tooltip,
TooltipContent,
TooltipProvider,
TooltipTrigger,
} from "@/components/ui/tooltip";
import ReactMarkdown from "react-markdown";
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
import remarkGfm from "remark-gfm";
@@ -314,7 +308,7 @@ export const AgenticMessage = ({
const renderedAlternativeMarkdown = useMemo(() => {
return (
<ReactMarkdown
className="prose dark:prose-invert max-w-full text-base"
className="prose max-w-full text-base"
components={{
...markdownComponents,
code: ({ node, className, children }: any) => {
@@ -341,7 +335,7 @@ export const AgenticMessage = ({
const renderedMarkdown = useMemo(() => {
return (
<ReactMarkdown
className="prose dark:prose-invert max-w-full text-base"
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
@@ -536,7 +530,6 @@ export const AgenticMessage = ({
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
disableForStreaming={!isComplete}
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
@@ -623,7 +616,6 @@ export const AgenticMessage = ({
{includeMessageSwitcher && (
<div className="-mx-1 mr-auto">
<MessageSwitcher
disableForStreaming={!isComplete}
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {
@@ -702,52 +694,27 @@ function MessageSwitcher({
totalPages,
handlePrevious,
handleNext,
disableForStreaming,
}: {
currentPage: number;
totalPages: number;
handlePrevious: () => void;
handleNext: () => void;
disableForStreaming?: boolean;
}) {
return (
<div className="flex items-center text-sm space-x-0.5">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div>
<Hoverable
icon={FiChevronLeft}
onClick={currentPage === 1 ? undefined : handlePrevious}
/>
</div>
</TooltipTrigger>
<TooltipContent>
{disableForStreaming ? "Disabled" : "Previous"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
<Hoverable
icon={FiChevronLeft}
onClick={currentPage === 1 ? undefined : handlePrevious}
/>
<span className="text-text-darker select-none">
{currentPage} / {totalPages}
{disableForStreaming ? "Complete" : "Generating"}
</span>
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div>
<Hoverable
icon={FiChevronRight}
onClick={currentPage === totalPages ? undefined : handleNext}
/>
</div>
</TooltipTrigger>
<TooltipContent>
{disableForStreaming ? "Disabled" : "Next"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
<Hoverable
icon={FiChevronRight}
onClick={currentPage === totalPages ? undefined : handleNext}
/>
</div>
);
}

View File

@@ -383,7 +383,7 @@ export const AIMessage = ({
dangerouslySetInnerHTML={{ __html: htmlContent }}
/>
<ReactMarkdown
className="prose dark:prose-invert max-w-full text-base"
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
@@ -495,10 +495,7 @@ export const AIMessage = ({
{docs && docs.length > 0 && (
<div
className={`mobile:hidden ${
(query ||
toolCall?.tool_name ===
INTERNET_SEARCH_TOOL_NAME) &&
"mt-2"
query && "mt-2"
} -mx-8 w-full mb-4 flex relative`}
>
<div className="w-full">
@@ -798,67 +795,27 @@ function MessageSwitcher({
totalPages,
handlePrevious,
handleNext,
disableForStreaming,
}: {
currentPage: number;
totalPages: number;
handlePrevious: () => void;
handleNext: () => void;
disableForStreaming?: boolean;
}) {
return (
<div className="flex items-center text-sm space-x-0.5">
<TooltipProvider>
<Tooltip>
<TooltipTrigger asChild>
<div>
<Hoverable
icon={FiChevronLeft}
onClick={
disableForStreaming
? () => null
: currentPage === 1
? undefined
: handlePrevious
}
/>
</div>
</TooltipTrigger>
<TooltipContent>
{disableForStreaming
? "Wait for agent message to complete"
: "Previous"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
<Hoverable
icon={FiChevronLeft}
onClick={currentPage === 1 ? undefined : handlePrevious}
/>
<span className="text-text-darker select-none">
{currentPage} / {totalPages}
</span>
<TooltipProvider>
<Tooltip>
<TooltipTrigger>
<div>
<Hoverable
icon={FiChevronRight}
onClick={
disableForStreaming
? () => null
: currentPage === totalPages
? undefined
: handleNext
}
/>
</div>
</TooltipTrigger>
<TooltipContent>
{disableForStreaming
? "Wait for agent message to complete"
: "Next"}
</TooltipContent>
</Tooltip>
</TooltipProvider>
<Hoverable
icon={FiChevronRight}
onClick={currentPage === totalPages ? undefined : handleNext}
/>
</div>
);
}
@@ -872,7 +829,6 @@ export const HumanMessage = ({
onMessageSelection,
shared,
stopGenerating = () => null,
disableSwitchingForStreaming = false,
}: {
shared?: boolean;
content: string;
@@ -882,7 +838,6 @@ export const HumanMessage = ({
onEdit?: (editedContent: string) => void;
onMessageSelection?: (messageId: number) => void;
stopGenerating?: () => void;
disableSwitchingForStreaming?: boolean;
}) => {
const textareaRef = useRef<HTMLTextAreaElement>(null);
@@ -1112,7 +1067,6 @@ export const HumanMessage = ({
otherMessagesCanSwitchTo.length > 1 && (
<div className="ml-auto mr-3">
<MessageSwitcher
disableForStreaming={disableSwitchingForStreaming}
currentPage={currentMessageInd + 1}
totalPages={otherMessagesCanSwitchTo.length}
handlePrevious={() => {

View File

@@ -294,7 +294,7 @@ const SubQuestionDisplay: React.FC<{
const renderedMarkdown = useMemo(() => {
return (
<ReactMarkdown
className="prose dark:prose-invert max-w-full text-base"
className="prose max-w-full text-base"
components={markdownComponents}
remarkPlugins={[remarkGfm, remarkMath]}
rehypePlugins={[rehypeKatex]}
@@ -340,7 +340,7 @@ const SubQuestionDisplay: React.FC<{
{subQuestion?.question || temporaryDisplay?.question}
</div>
<ChevronDown
className={`mt-0.5 flex-none text-text-darker transition-transform duration-500 ease-in-out ${
className={`mt-0.5 text-text-darker transition-transform duration-500 ease-in-out ${
toggled ? "" : "-rotate-90"
}`}
size={20}
@@ -632,7 +632,9 @@ const SubQuestionsDisplay: React.FC<SubQuestionsDisplayProps> = ({
}
`}</style>
<div className="relative">
{/* {subQuestions.map((subQuestion, index) => ( */}
{memoizedSubQuestions.map((subQuestion, index) => (
// {dynamicSubQuestions.map((subQuestion, index) => (
<SubQuestionDisplay
currentlyOpen={
currentlyOpenQuestion?.level === subQuestion.level &&

View File

@@ -131,7 +131,7 @@ const StandardAnswersTableRow = ({
/>,
<ReactMarkdown
key={`answer-${standardAnswer.id}`}
className="prose dark:prose-invert"
className="prose"
remarkPlugins={[remarkGfm]}
>
{standardAnswer.answer}

View File

@@ -562,7 +562,6 @@ body {
.prose :where(pre):not(:where([class~="not-prose"], [class~="not-prose"] *)) {
background-color: theme("colors.code-bg");
font-size: theme("fontSize.code-sm");
color: #fff;
}
pre[class*="language-"],
@@ -656,3 +655,16 @@ ul > li > p {
display: inline;
/* Make paragraphs inline to reduce vertical space */
}
.dark strong {
color: white;
}
.prose.dark li,
.prose.dark h1,
.prose.dark h2,
.prose.dark h3,
.prose.dark h4,
.prose.dark h5 {
color: #e5e5e5;
}

View File

@@ -17,7 +17,7 @@ export const Hoverable: React.FC<{
<div className="flex items-center">
<Icon
size={size}
className="dark:text-[#B4B4B4] text-neutral-600 rounded h-fit cursor-pointer"
className="hover:bg-background-chat-hover dark:text-[#B4B4B4] text-neutral-600 rounded h-fit cursor-pointer"
/>
{hoverText && (
<div className="max-w-0 leading-none whitespace-nowrap overflow-hidden transition-all duration-300 ease-in-out group-hover:max-w-xs group-hover:ml-2">

View File

@@ -50,7 +50,7 @@ export function SearchResultIcon({ url }: { url: string }) {
return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
}
if (url.includes("docs.onyx.app")) {
return <OnyxIcon size={18} className="dark:text-[#fff] text-[#000]" />;
return <OnyxIcon size={18} />;
}
return (

View File

@@ -23,7 +23,7 @@ export function WebResultIcon({
return (
<>
{hostname == "docs.onyx.app" ? (
<OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
<OnyxIcon size={size} />
) : !error ? (
<img
className="my-0 rounded-full py-0"

View File

@@ -432,10 +432,7 @@ export const MarkdownFormField = ({
</div>
{isPreviewOpen ? (
<div className="p-4 border-t border-background-300">
<ReactMarkdown
className="prose dark:prose-invert"
remarkPlugins={[remarkGfm]}
>
<ReactMarkdown className="prose" remarkPlugins={[remarkGfm]}>
{field.value}
</ReactMarkdown>
</div>

View File

@@ -9,7 +9,7 @@ export default function BlurBackground({
<div
onClick={onClick}
className={`
desktop:hidden w-full h-full fixed inset-0 bg-neutral-700 bg-opacity-50 backdrop-blur-sm z-30 transition-opacity duration-300 ease-in-out ${
desktop:hidden w-full h-full fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-30 transition-opacity duration-300 ease-in-out ${
visible
? "opacity-100 pointer-events-auto"
: "opacity-0 pointer-events-none"

View File

@@ -35,7 +35,7 @@ export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({
return (
<ReactMarkdown
className={`w-full text-wrap break-word prose dark:prose-invert ${className}`}
className={`w-full text-wrap break-word ${className}`}
components={markdownComponents}
remarkPlugins={[remarkGfm]}
>

View File

@@ -78,7 +78,7 @@ export function getUniqueIcons(docs: OnyxDocument[]): JSX.Element[] {
for (const doc of docs) {
// If it's a web source, we check domain uniqueness
if ((doc.is_internet || doc.source_type === ValidSources.Web) && doc.link) {
if (doc.source_type === ValidSources.Web && doc.link) {
const domain = getDomainFromUrl(doc.link);
if (domain && !seenDomains.has(domain)) {
seenDomains.add(domain);
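
The change above tightens the web-source check; the surrounding logic keeps at most one icon per domain. A self-contained sketch of that dedup step, using a Set keyed by hostname (uniqueDomains and the bare URL parsing are stand-ins for the repo's getDomainFromUrl helper):

// Keep the first link seen for each hostname, preserving order.
function uniqueDomains(links: string[]): string[] {
  const seen = new Set<string>();
  const out: string[] = [];
  for (const link of links) {
    let domain: string | null = null;
    try {
      domain = new URL(link).hostname;
    } catch {
      // ignore malformed URLs
    }
    if (domain && !seen.has(domain)) {
      seen.add(domain);
      out.push(link);
    }
  }
  return out;
}
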

View File

@@ -47,7 +47,7 @@ export default function LogoWithText({
className="flex gap-x-2 items-center ml-0 cursor-pointer desktop:hidden "
>
{!toggled ? (
<Logo className="desktop:hidden" height={24} width={24} />
<Logo className="desktop:hidden -my-2" height={24} width={24} />
) : (
<LogoComponent
show={toggled}

View File

@@ -23,11 +23,8 @@ import { AllUsersResponse } from "./types";
import { Credential } from "./connectors/credentials";
import { SettingsContext } from "@/components/settings/SettingsProvider";
import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
import {
isAnthropic,
LLMProviderDescriptor,
} from "@/app/admin/configuration/llm/interfaces";
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
import { getSourceMetadata } from "./sources";
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
import { useUser } from "@/components/user/UserProvider";

View File

@@ -79,18 +79,12 @@ export async function* handleStream<T extends NonEmptyObject>(
}
export async function* handleSSEStream<T extends PacketType>(
streamingResponse: Response,
signal?: AbortSignal
streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
const reader = streamingResponse.body?.getReader();
const decoder = new TextDecoder();
let buffer = "";
if (signal) {
signal.addEventListener("abort", () => {
console.log("aborting");
reader?.cancel();
});
}
while (true) {
const rawChunk = await reader?.read();
if (!rawChunk) {
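
The removed lines wired an AbortSignal into handleSSEStream so an in-flight stream could be cancelled by cancelling the response reader. A minimal sketch of that cancellation pattern for any streamed Response (readWithAbort is an illustrative name, not the repo's helper):

// Read a streaming Response chunk by chunk; cancelling the reader when the signal
// fires makes the pending read resolve with done: true, ending the loop.
async function readWithAbort(response: Response, signal?: AbortSignal): Promise<string> {
  const reader = response.body?.getReader();
  if (!reader) return "";
  if (signal) {
    signal.addEventListener("abort", () => {
      void reader.cancel();
    });
  }
  const decoder = new TextDecoder();
  let text = "";
  while (true) {
    const { done, value } = await reader.read();
    if (done) break;
    text += decoder.decode(value, { stream: true });
  }
  return text;
}
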

View File

@@ -21,6 +21,7 @@ module.exports = {
transitionProperty: {
spacing: "margin, padding",
},
keyframes: {
"subtle-pulse": {
"0%, 100%": { opacity: 0.9 },
@@ -147,6 +148,7 @@ module.exports = {
"text-mobile-sidebar": "var(--text-800)",
"background-search-filter": "var(--neutral-100-border-light)",
"background-search-filter-dropdown": "var(--neutral-100-border-light)",
"tw-prose-bold": "var(--text-800)",
"user-bubble": "var(--off-white)",