Mirror of https://github.com/onyx-dot-app/onyx.git (synced 2026-02-18 08:15:48 +00:00)

Compare commits: `nit_error...paralleliz` (17 commits)
| Author | SHA1 | Date |
|---|---|---|
| | 4d1d5bdcfe | |
| | 0d468b49a1 | |
| | 67b87ced39 | |
| | 8b4e4a6c80 | |
| | e26bcf5a05 | |
| | 435959cf90 | |
| | fcbe305dc0 | |
| | 6f13d44564 | |
| | c1810a35cd | |
| | 4003e7346a | |
| | 8057f1eb0d | |
| | 7eebd3cff1 | |
| | bac2aeb8b7 | |
| | 9831697acc | |
| | 5da766dd3b | |
| | 180608694a | |
| | 96b92edfdb | |
.github/workflows/pr-integration-tests-parallel.yml (vendored, new file, 153 additions, `@@ -0,0 +1,153 @@`)

```yaml
name: Run Integration Tests v3
concurrency:
  group: Run-Integration-Tests-Parallel-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
  cancel-in-progress: true

on:
  merge_group:
  pull_request:
    branches:
      - main
      - "release/**"

env:
  OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
  SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
  CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
  CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
  CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}

jobs:
  integration-tests:
    # See https://runs-on.com/runners/linux/
    runs-on:
      [runs-on, runner=32cpu-linux-x64, ram=64, "run-id=${{ github.run_id }}"]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4

      - name: Set up Docker Buildx
        uses: docker/setup-buildx-action@v3

      - name: Login to Docker Hub
        uses: docker/login-action@v3
        with:
          username: ${{ secrets.DOCKER_USERNAME }}
          password: ${{ secrets.DOCKER_TOKEN }}

      - name: Build integration test Docker image
        uses: ./.github/actions/custom-build-and-push
        with:
          context: ./backend
          file: ./backend/tests/integration/Dockerfile
          platforms: linux/amd64
          tags: danswer/danswer-integration:test
          push: false
          load: true
          cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
          cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max

      - name: Run Standard Integration Tests
        run: |
          # Print a message indicating that tests are starting
          echo "Running integration tests..."

          # Create a directory for test logs that will be mounted into the container
          mkdir -p ${{ github.workspace }}/test_logs
          chmod 777 ${{ github.workspace }}/test_logs

          # Run the integration tests in a Docker container
          # Mount the Docker socket to allow Docker-in-Docker (DinD)
          # Mount the test_logs directory to capture logs
          # Use host network for easier communication with other services
          docker run \
            -v /var/run/docker.sock:/var/run/docker.sock \
            -v ${{ github.workspace }}/test_logs:/tmp \
            --network host \
            -e OPENAI_API_KEY=${OPENAI_API_KEY} \
            -e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
            -e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
            -e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
            -e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
            -e TEST_WEB_HOSTNAME=test-runner \
            -e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
            -e MOCK_CONNECTOR_SERVER_PORT=8001 \
            danswer/danswer-integration:test \
            python /app/tests/integration/run.py
        continue-on-error: true
        id: run_tests

      - name: Check test results
        run: |
          if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
            echo "Integration tests failed. Exiting with error."
            exit 1
          else
            echo "All integration tests passed successfully."
          fi

      - name: Collect log files
        if: success() || failure()
        run: |
          # Create a directory for logs
          mkdir -p ${{ github.workspace }}/logs
          mkdir -p ${{ github.workspace }}/logs/shared_services

          # Copy all relevant log files from the mounted directory
          cp ${{ github.workspace }}/test_logs/api_server_*.txt ${{ github.workspace }}/logs/ || true
          cp ${{ github.workspace }}/test_logs/background_*.txt ${{ github.workspace }}/logs/ || true
          cp ${{ github.workspace }}/test_logs/shared_model_server.txt ${{ github.workspace }}/logs/ || true

          # Collect logs from shared services (Docker containers)
          # Note: using a wildcard for the UUID part of the stack name
          docker ps -a --filter "name=base-onyx-" --format "{{.Names}}" | while read container; do
            echo "Collecting logs from $container"
            docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
          done

          # Also collect Redis container logs
          docker ps -a --filter "name=redis-onyx-" --format "{{.Names}}" | while read container; do
            echo "Collecting logs from $container"
            docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
          done

          # List collected logs
          echo "Collected log files:"
          ls -l ${{ github.workspace }}/logs/
          echo "Collected shared services logs:"
          ls -l ${{ github.workspace }}/logs/shared_services/

      - name: Upload logs
        if: success() || failure()
        uses: actions/upload-artifact@v4
        with:
          name: integration-test-logs
          path: |
            ${{ github.workspace }}/logs/
            ${{ github.workspace }}/logs/shared_services/
          retention-days: 5

      # save before stopping the containers so the logs can be captured
      # - name: Save Docker logs
      #   if: success() || failure()
      #   run: |
      #     cd deployment/docker_compose
      #     docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
      #     mv docker-compose.log ${{ github.workspace }}/docker-compose.log

      # - name: Stop Docker containers
      #   run: |
      #     cd deployment/docker_compose
      #     docker compose -f docker-compose.dev.yml -p danswer-stack down -v

      # - name: Upload logs
      #   if: success() || failure()
      #   uses: actions/upload-artifact@v4
      #   with:
      #     name: docker-logs
      #     path: ${{ github.workspace }}/docker-compose.log

      # - name: Stop Docker containers
      #   run: |
      #     cd deployment/docker_compose
      #     docker compose -f docker-compose.dev.yml -p danswer-stack down -v
```
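For local debugging it can help to reproduce the containerized test run outside CI. The sketch below is not from the repository; it mirrors the `docker run` invocation from the step above using only Python's standard library, with an arbitrary local log directory standing in for `${{ github.workspace }}/test_logs`.

```python
import os
import subprocess
from pathlib import Path

# Local stand-in for ${{ github.workspace }}/test_logs (arbitrary path choice).
log_dir = Path.home() / "onyx_test_logs"
log_dir.mkdir(parents=True, exist_ok=True)

# Environment variables forwarded into the container, as in the workflow step.
forwarded = [
    "OPENAI_API_KEY",
    "SLACK_BOT_TOKEN",
    "CONFLUENCE_TEST_SPACE_URL",
    "CONFLUENCE_USER_NAME",
    "CONFLUENCE_ACCESS_TOKEN",
]

cmd = [
    "docker", "run",
    # Docker-in-Docker: hand the host's Docker socket to the test container.
    "-v", "/var/run/docker.sock:/var/run/docker.sock",
    # Logs written to /tmp inside the container land in log_dir on the host.
    "-v", f"{log_dir}:/tmp",
    "--network", "host",
]
for name in forwarded:
    cmd += ["-e", f"{name}={os.environ.get(name, '')}"]
cmd += [
    "-e", "TEST_WEB_HOSTNAME=test-runner",
    "-e", "MOCK_CONNECTOR_SERVER_HOST=mock_connector_server",
    "-e", "MOCK_CONNECTOR_SERVER_PORT=8001",
    "danswer/danswer-integration:test",
    "python", "/app/tests/integration/run.py",
]

# Mirror the workflow's continue-on-error plus explicit outcome check.
result = subprocess.run(cmd)
print("tests passed" if result.returncode == 0 else "tests failed")
```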
.github/workflows/pr-integration-tests.yml (vendored, 8 changes)

```diff
@@ -5,10 +5,10 @@ concurrency:

 on:
   merge_group:
-  pull_request:
-    branches:
-      - main
-      - "release/**"
+  # pull_request:
+  #   branches:
+  #     - main
+  #     - "release/**"

 env:
   OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
```
```diff
@@ -28,11 +28,11 @@ RUN apt-get update && \
         curl \
         zip \
         ca-certificates \
-        libgnutls30 \
-        libblkid1 \
-        libmount1 \
-        libsmartcols1 \
-        libuuid1 \
+        libgnutls30=3.7.9-2+deb12u3 \
+        libblkid1=2.38.1-5+deb12u1 \
+        libmount1=2.38.1-5+deb12u1 \
+        libsmartcols1=2.38.1-5+deb12u1 \
+        libuuid1=2.38.1-5+deb12u1 \
         libxmlsec1-dev \
         pkg-config \
         gcc \
```
```diff
@@ -1,6 +1,6 @@
 from typing import Any, Literal
-from onyx.db.engine import get_iam_auth_token
-from onyx.configs.app_configs import USE_IAM_AUTH
+from onyx.db.engine import SYNC_DB_API, get_iam_auth_token
+from onyx.configs.app_configs import POSTGRES_DB, USE_IAM_AUTH
 from onyx.configs.app_configs import POSTGRES_HOST
 from onyx.configs.app_configs import POSTGRES_PORT
 from onyx.configs.app_configs import POSTGRES_USER
@@ -13,12 +13,11 @@ from sqlalchemy import text
 from sqlalchemy.engine.base import Connection
 import os
 import ssl
-import asyncio
 import logging
 from logging.config import fileConfig

 from alembic import context
-from sqlalchemy.ext.asyncio import create_async_engine
+from sqlalchemy import create_engine
 from sqlalchemy.sql.schema import SchemaItem
 from onyx.configs.constants import SSL_CERT_FILE
 from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
@@ -133,17 +132,32 @@ def provide_iam_token_for_alembic(
         cparams["ssl"] = ssl_context


-async def run_async_migrations() -> None:
+def run_migrations() -> None:
     schema_name, create_schema, upgrade_all_tenants = get_schema_options()

-    engine = create_async_engine(
-        build_connection_string(),
+    # Get any environment variables passed through alembic config
+    env_vars = context.config.attributes.get("env_vars", {})
+
+    # Use env vars if provided, otherwise fall back to defaults
+    postgres_host = env_vars.get("POSTGRES_HOST", POSTGRES_HOST)
+    postgres_port = env_vars.get("POSTGRES_PORT", POSTGRES_PORT)
+    postgres_user = env_vars.get("POSTGRES_USER", POSTGRES_USER)
+    postgres_db = env_vars.get("POSTGRES_DB", POSTGRES_DB)
+
+    engine = create_engine(
+        build_connection_string(
+            db=postgres_db,
+            user=postgres_user,
+            host=postgres_host,
+            port=postgres_port,
+            db_api=SYNC_DB_API,
+        ),
         poolclass=pool.NullPool,
     )

     if USE_IAM_AUTH:

-        @event.listens_for(engine.sync_engine, "do_connect")
+        @event.listens_for(engine, "do_connect")
         def event_provide_iam_token_for_alembic(
             dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
         ) -> None:
@@ -152,31 +166,26 @@
     if upgrade_all_tenants:
         tenant_schemas = get_all_tenant_ids()
         for schema in tenant_schemas:
             if schema is None:
                 continue

             try:
                 logger.info(f"Migrating schema: {schema}")
-                async with engine.connect() as connection:
-                    await connection.run_sync(
-                        do_run_migrations,
-                        schema_name=schema,
-                        create_schema=create_schema,
-                    )
+                with engine.connect() as connection:
+                    do_run_migrations(connection, schema, create_schema)
             except Exception as e:
                 logger.error(f"Error migrating schema {schema}: {e}")
                 raise
     else:
         try:
             logger.info(f"Migrating schema: {schema_name}")
-            async with engine.connect() as connection:
-                await connection.run_sync(
-                    do_run_migrations,
-                    schema_name=schema_name,
-                    create_schema=create_schema,
-                )
+            with engine.connect() as connection:
+                do_run_migrations(connection, schema_name, create_schema)
         except Exception as e:
             logger.error(f"Error migrating schema {schema_name}: {e}")
             raise

-    await engine.dispose()
+    engine.dispose()


 def run_migrations_offline() -> None:
@@ -184,18 +193,18 @@ def run_migrations_offline() -> None:
     url = build_connection_string()

     if upgrade_all_tenants:
-        engine = create_async_engine(url)
+        engine = create_engine(url)

         if USE_IAM_AUTH:

-            @event.listens_for(engine.sync_engine, "do_connect")
+            @event.listens_for(engine, "do_connect")
             def event_provide_iam_token_for_alembic_offline(
                 dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
             ) -> None:
                 provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)

         tenant_schemas = get_all_tenant_ids()
-        engine.sync_engine.dispose()
+        engine.dispose()

         for schema in tenant_schemas:
             logger.info(f"Migrating schema: {schema}")
@@ -230,7 +239,7 @@ def run_migrations_offline() -> None:


 def run_migrations_online() -> None:
-    asyncio.run(run_async_migrations())
+    run_migrations()


 if context.is_offline_mode():
```
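The net effect of these hunks (which appear to come from the Alembic migration environment, `env.py`) is that the migration entrypoint no longer builds an async engine and bridges into sync code via `run_sync`; it drives migrations synchronously. Below is a minimal standalone sketch of that shape, with a hypothetical `do_run_migrations` and plain SQLAlchemy rather than Onyx's actual helpers:

```python
import logging

from sqlalchemy import create_engine, pool, text
from sqlalchemy.engine import Connection

logger = logging.getLogger(__name__)


def do_run_migrations(connection: Connection, schema: str, create_schema: bool) -> None:
    # Hypothetical stand-in for the real Alembic migration runner: the
    # real one would set the search_path and call context.run_migrations().
    if create_schema:
        connection.execute(text(f'CREATE SCHEMA IF NOT EXISTS "{schema}"'))
    connection.execute(text(f'SET search_path TO "{schema}"'))
    connection.commit()


def run_migrations(url: str, schemas: list[str]) -> None:
    # NullPool mirrors the diff: one short-lived connection per migration run.
    engine = create_engine(url, poolclass=pool.NullPool)
    for schema in schemas:
        logger.info("Migrating schema: %s", schema)
        # A plain sync context manager replaces `async with engine.connect()`
        # plus `await connection.run_sync(...)` from the removed code.
        with engine.connect() as connection:
            do_run_migrations(connection, schema, create_schema=False)
    engine.dispose()
```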
```diff
@@ -5,14 +5,14 @@ from langgraph.graph import StateGraph
 from onyx.agents.agent_search.basic.states import BasicInput
 from onyx.agents.agent_search.basic.states import BasicOutput
 from onyx.agents.agent_search.basic.states import BasicState
-from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
-from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
+from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
+from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -33,13 +33,13 @@ def basic_graph_builder() -> StateGraph:
     )

     graph.add_node(
-        node="choose_tool",
-        action=choose_tool,
+        node="llm_tool_choice",
+        action=llm_tool_choice,
     )

     graph.add_node(
-        node="call_tool",
-        action=call_tool,
+        node="tool_call",
+        action=tool_call,
     )

     graph.add_node(
@@ -51,12 +51,12 @@ def basic_graph_builder() -> StateGraph:

     graph.add_edge(start_key=START, end_key="prepare_tool_input")

-    graph.add_edge(start_key="prepare_tool_input", end_key="choose_tool")
+    graph.add_edge(start_key="prepare_tool_input", end_key="llm_tool_choice")

-    graph.add_conditional_edges("choose_tool", should_continue, ["call_tool", END])
+    graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])

     graph.add_edge(
-        start_key="call_tool",
+        start_key="tool_call",
         end_key="basic_use_tool_response",
     )
@@ -73,7 +73,7 @@ def should_continue(state: BasicState) -> str:
         # If there are no tool calls, basic graph already streamed the answer
         END
         if state.tool_choice is None
-        else "call_tool"
+        else "tool_call"
     )
```
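For context, the renamed nodes slot into a LangGraph `StateGraph` in the usual way. Here is a self-contained toy graph with the same wiring; the state shape and node bodies are hypothetical, not Onyx's:

```python
from typing import TypedDict

from langgraph.graph import END, START, StateGraph


class State(TypedDict):
    tool_choice: str | None
    result: str


def prepare_tool_input(state: State) -> State:
    return state


def llm_tool_choice(state: State) -> State:
    # A real node would ask the LLM whether a tool call is needed.
    return {**state, "tool_choice": "search"}


def tool_call(state: State) -> State:
    return {**state, "result": "tool output"}


def basic_use_tool_response(state: State) -> State:
    return state


def should_continue(state: State) -> str:
    # If there is no tool call, the answer was already streamed.
    return END if state["tool_choice"] is None else "tool_call"


graph = StateGraph(State)
graph.add_node("prepare_tool_input", prepare_tool_input)
graph.add_node("llm_tool_choice", llm_tool_choice)
graph.add_node("tool_call", tool_call)
graph.add_node("basic_use_tool_response", basic_use_tool_response)
graph.add_edge(START, "prepare_tool_input")
graph.add_edge("prepare_tool_input", "llm_tool_choice")
graph.add_conditional_edges("llm_tool_choice", should_continue, ["tool_call", END])
graph.add_edge("tool_call", "basic_use_tool_response")
graph.add_edge("basic_use_tool_response", END)

compiled = graph.compile()
print(compiled.invoke({"tool_choice": None, "result": ""}))
```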
```diff
@@ -31,14 +31,12 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
-from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
-from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_CHECK
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import SUB_ANSWER_CHECK_PROMPT
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -87,11 +85,9 @@ def check_sub_answer(
     agent_error: AgentErrorLog | None = None
     response: BaseMessage | None = None
     try:
-        response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_SUBANSWER_CHECK,
-            fast_llm.invoke,
+        response = fast_llm.invoke(
             prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK,
         )

         quality_str: str = cast(str, response.content)
@@ -100,7 +96,7 @@ def check_sub_answer(
         )
         log_result = f"Answer quality: {quality_str}"

-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
```
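The same refactor recurs throughout the hunks that follow: an outer `run_with_timeout(...)` thread wrapper around each LLM call is dropped in favor of a single per-request `timeout_override`, so only `LLMTimeoutError` needs handling instead of both exception types. A minimal sketch of the before/after shape, with a hypothetical `invoke` stub (Onyx's real interfaces differ):

```python
import concurrent.futures
from typing import Callable, TypeVar

T = TypeVar("T")


class LLMTimeoutError(Exception):
    """Raised by the client itself when a request exceeds its own timeout."""


def run_with_timeout(timeout: float, fn: Callable[..., T], **kwargs) -> T:
    # Old pattern: run the call in a worker thread and give up after `timeout`.
    # On Python 3.11+ this raises the built-in TimeoutError, *in addition* to
    # whatever LLMTimeoutError the client may raise on its own.
    with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
        return pool.submit(fn, **kwargs).result(timeout=timeout)


def invoke(prompt: str, timeout_override: float | None = None) -> str:
    # Hypothetical client call; a real one would raise LLMTimeoutError if the
    # request runs past timeout_override.
    return f"answer to {prompt!r} (timeout={timeout_override})"


# Old: two timeout layers, two exception types to catch.
try:
    answer = run_with_timeout(8, invoke, prompt="q", timeout_override=2)
except (LLMTimeoutError, TimeoutError):
    answer = "timed out"

# New: one timeout layer owned by the client, one exception type.
try:
    answer = invoke(prompt="q", timeout_override=8)
except LLMTimeoutError:
    answer = "timed out"

print(answer)
```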
```diff
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any
 from typing import cast

 from langchain_core.messages import merge_message_runs
@@ -46,13 +47,11 @@ from onyx.chat.models import StreamStopInfo
 from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
-from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
-from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import NO_RECOVERED_DOCS
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -111,14 +110,15 @@ def generate_sub_answer(
         config=fast_llm.config,
     )

+    response: list[str | list[str | dict[str, Any]]] = []
     dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None
-    response: list[str] = []
-
-    def stream_sub_answer() -> list[str]:
-        try:
-            for message in fast_llm.stream(
-                prompt=msg,
-                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION,
-            ):
+
+    try:
+        for message in fast_llm.stream(
+            prompt=msg,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION,
+        ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
@@ -142,15 +142,8 @@ def generate_sub_answer(
                 (end_stream_token - start_stream_token).microseconds
             )
             response.append(content)
-        return response
-
-    try:
-        response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION,
-            stream_sub_answer,
-        )
-
-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
```
```diff
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any
 from typing import cast

 from langchain_core.messages import HumanMessage
@@ -59,15 +60,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
-from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_INITIAL_ANSWER
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
-from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
-)
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
 )
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
@@ -80,7 +77,6 @@ from onyx.prompts.agent_search import (
 )
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 _llm_node_error_strings = LLMNodeErrorStrings(
@@ -234,11 +230,7 @@ def generate_initial_answer(

     sub_questions = all_sub_questions  # Replace the original assignment

-    model = (
-        graph_config.tooling.fast_llm
-        if AGENT_ANSWER_GENERATION_BY_FAST_LLM
-        else graph_config.tooling.primary_llm
-    )
+    model = graph_config.tooling.fast_llm

     doc_context = format_docs(answer_generation_documents.context_documents)
     doc_context = trim_prompt_piece(
@@ -268,16 +260,15 @@ def generate_initial_answer(
         )
     ]

-    streamed_tokens: list[str] = [""]
+    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []

     agent_error: AgentErrorLog | None = None

-    def stream_initial_answer() -> list[str]:
-        response: list[str] = []
-        try:
-            for message in model.stream(
-                msg,
-                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION,
-            ):
+    try:
+        for message in model.stream(
+            msg,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION,
+        ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
@@ -301,16 +292,9 @@ def generate_initial_answer(
             dispatch_timings.append(
                 (end_stream_token - start_stream_token).microseconds
             )
-            response.append(content)
-        return response
+            streamed_tokens.append(content)

-    try:
-        streamed_tokens = run_with_timeout(
-            AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION,
-            stream_initial_answer,
-        )
-
-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
```
```diff
@@ -36,10 +36,7 @@ from onyx.chat.models import StreamType
 from onyx.chat.models import SubQuestionPiece
 from onyx.configs.agent_configs import AGENT_NUM_DOCS_FOR_DECOMPOSITION
-from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
-)
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
 )
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
@@ -50,7 +47,6 @@ from onyx.prompts.agent_search import (
     INITIAL_QUESTION_DECOMPOSITION_PROMPT_ASSUMING_REFINEMENT,
 )
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -135,12 +131,10 @@ def decompose_orig_question(
     streamed_tokens: list[BaseMessage_Content] = []

     try:
-        streamed_tokens = run_with_timeout(
-            AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION,
-            dispatch_separated,
+        streamed_tokens = dispatch_separated(
             model.stream(
                 msg,
-                timeout_override=AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION,
+                timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION,
             ),
             dispatch_subquestion(0, writer),
             sep_callback=dispatch_subquestion_sep(0, writer),
@@ -160,7 +154,7 @@ def decompose_orig_question(
         )
         write_custom_event("stream_finished", stop_event, writer)

-    except (LLMTimeoutError, TimeoutError) as e:
+    except LLMTimeoutError as e:
         logger.error("LLM Timeout Error - decompose orig question")
         raise e  # fail loudly on this critical step
     except LLMRateLimitError as e:
```
```diff
@@ -25,7 +25,7 @@ logger = setup_logger()

 def route_initial_tool_choice(
     state: MainState, config: RunnableConfig
-) -> Literal["call_tool", "start_agent_search", "logging_node"]:
+) -> Literal["tool_call", "start_agent_search", "logging_node"]:
     """
     LangGraph edge to route to agent search.
     """
@@ -38,7 +38,7 @@ def route_initial_tool_choice(
         ):
             return "start_agent_search"
         else:
-            return "call_tool"
+            return "tool_call"
     else:
         return "logging_node"
```
```diff
@@ -43,14 +43,14 @@ from onyx.agents.agent_search.deep_search.main.states import MainState
 from onyx.agents.agent_search.deep_search.refinement.consolidate_sub_answers.graph_builder import (
     answer_refined_query_graph_builder,
 )
-from onyx.agents.agent_search.orchestration.nodes.call_tool import call_tool
-from onyx.agents.agent_search.orchestration.nodes.choose_tool import choose_tool
-from onyx.agents.agent_search.orchestration.nodes.basic_use_tool_response import (
-    basic_use_tool_response,
-)
+from onyx.agents.agent_search.orchestration.nodes.llm_tool_choice import llm_tool_choice
 from onyx.agents.agent_search.orchestration.nodes.prepare_tool_input import (
     prepare_tool_input,
 )
+from onyx.agents.agent_search.orchestration.nodes.use_tool_response import (
+    basic_use_tool_response,
+)
+from onyx.agents.agent_search.orchestration.nodes.tool_call import tool_call
 from onyx.agents.agent_search.shared_graph_utils.utils import get_test_config
 from onyx.utils.logger import setup_logger
@@ -77,13 +77,13 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     # Choose the initial tool
     graph.add_node(
         node="initial_tool_choice",
-        action=choose_tool,
+        action=llm_tool_choice,
     )

     # Call the tool, if required
     graph.add_node(
-        node="call_tool",
-        action=call_tool,
+        node="tool_call",
+        action=tool_call,
     )

     # Use the tool response
@@ -168,11 +168,11 @@ def main_graph_builder(test_mode: bool = False) -> StateGraph:
     graph.add_conditional_edges(
         "initial_tool_choice",
         route_initial_tool_choice,
-        ["call_tool", "start_agent_search", "logging_node"],
+        ["tool_call", "start_agent_search", "logging_node"],
     )

     graph.add_edge(
-        start_key="call_tool",
+        start_key="tool_call",
         end_key="basic_use_tool_response",
     )
     graph.add_edge(
```
```diff
@@ -33,15 +33,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 )
 from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import RefinedAnswerImprovement
-from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
-from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_COMPARE_ANSWERS
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     INITIAL_REFINED_ANSWER_COMPARISON_PROMPT,
 )
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -107,14 +105,11 @@ def compare_answers(
     refined_answer_improvement: bool | None = None
     # no need to stream this
     try:
-        resp = run_with_timeout(
-            AGENT_TIMEOUT_LLM_COMPARE_ANSWERS,
-            model.invoke,
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS,
+        resp = model.invoke(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
         )

-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
```
||||
@@ -44,10 +44,7 @@ from onyx.agents.agent_search.shared_graph_utils.utils import make_question_id
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
|
||||
from onyx.chat.models import StreamingError
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
)
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
@@ -56,7 +53,6 @@ from onyx.prompts.agent_search import (
|
||||
)
|
||||
from onyx.tools.models import ToolCallKickoff
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -138,17 +134,15 @@ def create_refined_sub_questions(
|
||||
agent_error: AgentErrorLog | None = None
|
||||
streamed_tokens: list[BaseMessage_Content] = []
|
||||
try:
|
||||
streamed_tokens = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
dispatch_separated,
|
||||
streamed_tokens = dispatch_separated(
|
||||
model.stream(
|
||||
msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION,
|
||||
),
|
||||
dispatch_subquestion(1, writer),
|
||||
sep_callback=dispatch_subquestion_sep(1, writer),
|
||||
)
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
````diff
@@ -22,17 +22,11 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
-from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
-)
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
 )
 from onyx.configs.constants import NUM_EXPLORATORY_DOCS
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT
 from onyx.prompts.agent_search import ENTITY_TERM_EXTRACTION_PROMPT_JSON_EXAMPLE
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time
@@ -90,42 +84,30 @@ def extract_entities_terms(
     ]
     fast_llm = graph_config.tooling.fast_llm
     # Grader
+    llm_response = fast_llm.invoke(
+        prompt=msg,
+        timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION,
+    )
+
+    cleaned_response = (
+        str(llm_response.content).replace("```json\n", "").replace("\n```", "")
+    )
+    first_bracket = cleaned_response.find("{")
+    last_bracket = cleaned_response.rfind("}")
+    cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
+
     try:
-        llm_response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION,
-            fast_llm.invoke,
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION,
+        entity_extraction_result = EntityExtractionResult.model_validate_json(
+            cleaned_response
         )
-
-        cleaned_response = (
-            str(llm_response.content).replace("```json\n", "").replace("\n```", "")
-        )
-        first_bracket = cleaned_response.find("{")
-        last_bracket = cleaned_response.rfind("}")
-        cleaned_response = cleaned_response[first_bracket : last_bracket + 1]
-
-        try:
-            entity_extraction_result = EntityExtractionResult.model_validate_json(
-                cleaned_response
-            )
-        except ValueError:
-            logger.error(
-                "Failed to parse LLM response as JSON in Entity-Term Extraction"
-            )
-            entity_extraction_result = EntityExtractionResult(
-                retrieved_entities_relationships=EntityRelationshipTermExtraction(),
-            )
-    except (LLMTimeoutError, TimeoutError):
-        logger.error("LLM Timeout Error - extract entities terms")
+    except ValueError:
+        logger.error("Failed to parse LLM response as JSON in Entity-Term Extraction")
         entity_extraction_result = EntityExtractionResult(
-            retrieved_entities_relationships=EntityRelationshipTermExtraction(),
+            retrieved_entities_relationships=EntityRelationshipTermExtraction(
+                entities=[],
+                relationships=[],
+                terms=[],
+            ),
         )
-
-    except LLMRateLimitError:
-        logger.error("LLM Rate Limit Error - extract entities terms")
-        entity_extraction_result = EntityExtractionResult(
-            retrieved_entities_relationships=EntityRelationshipTermExtraction(),
-        )

     return EntityTermExtractionUpdate(
````
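The parsing logic survives the refactor: strip any markdown code fence the model wrapped around its JSON, slice from the first `{` to the last `}`, then validate with Pydantic. A standalone sketch of that idea with a simplified stand-in model (the real `EntityExtractionResult` has more structure):

````python
from pydantic import BaseModel


class EntityExtractionResult(BaseModel):
    # Simplified stand-in for Onyx's real result model.
    entities: list[str] = []
    relationships: list[str] = []
    terms: list[str] = []


def parse_llm_json(raw: str) -> EntityExtractionResult:
    # Remove a markdown fence if the model added one...
    cleaned = raw.replace("```json\n", "").replace("\n```", "")
    # ...then keep only the outermost JSON object.
    first_bracket = cleaned.find("{")
    last_bracket = cleaned.rfind("}")
    cleaned = cleaned[first_bracket : last_bracket + 1]
    try:
        # pydantic's ValidationError subclasses ValueError, so this catch works.
        return EntityExtractionResult.model_validate_json(cleaned)
    except ValueError:
        # Mirrors the diff: fall back to an explicitly empty result.
        return EntityExtractionResult(entities=[], relationships=[], terms=[])


raw_response = '```json\n{"entities": ["Onyx"], "relationships": [], "terms": ["RAG"]}\n```'
print(parse_llm_json(raw_response))
````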
```diff
@@ -1,4 +1,5 @@
 from datetime import datetime
+from typing import Any
 from typing import cast

 from langchain_core.messages import HumanMessage
@@ -65,21 +66,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import write_custom_event
 from onyx.chat.models import AgentAnswerPiece
 from onyx.chat.models import ExtendedToolResponse
 from onyx.chat.models import StreamingError
-from onyx.configs.agent_configs import AGENT_ANSWER_GENERATION_BY_FAST_LLM
 from onyx.configs.agent_configs import AGENT_MAX_ANSWER_CONTEXT_DOCS
 from onyx.configs.agent_configs import AGENT_MAX_STREAMED_DOCS_FOR_REFINED_ANSWER
 from onyx.configs.agent_configs import AGENT_MIN_ORIG_QUESTION_DOCS
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION,
 )
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
-)
-from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
-)
-from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION,
 )
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
@@ -98,7 +92,6 @@ from onyx.prompts.agent_search import (
 from onyx.prompts.agent_search import UNKNOWN_ANSWER
 from onyx.tools.tool_implementations.search.search_tool import yield_search_responses
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -260,12 +253,7 @@ def generate_validate_refined_answer(
         else REFINED_ANSWER_PROMPT_WO_SUB_QUESTIONS
     )

-    model = (
-        graph_config.tooling.fast_llm
-        if AGENT_ANSWER_GENERATION_BY_FAST_LLM
-        else graph_config.tooling.primary_llm
-    )
-
+    model = graph_config.tooling.fast_llm
     relevant_docs_str = format_docs(answer_generation_documents.context_documents)
     relevant_docs_str = trim_prompt_piece(
         model.config,
@@ -296,13 +284,13 @@ def generate_validate_refined_answer(
         )
     ]

-    streamed_tokens: list[str] = [""]
+    streamed_tokens: list[str | list[str | dict[str, Any]]] = [""]
     dispatch_timings: list[float] = []
     agent_error: AgentErrorLog | None = None

-    def stream_refined_answer() -> list[str]:
-        try:
-            for message in model.stream(
-                msg, timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
+    try:
+        for message in model.stream(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
         ):
             # TODO: in principle, the answer here COULD contain images, but we don't support that yet
             content = message.content
@@ -327,15 +315,8 @@ def generate_validate_refined_answer(
                 (end_stream_token - start_stream_token).microseconds
             )
             streamed_tokens.append(content)
-        return streamed_tokens
-
-    try:
-        streamed_tokens = run_with_timeout(
-            AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION,
-            stream_refined_answer,
-        )
-
-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         agent_error = AgentErrorLog(
             error_type=AgentLLMErrorType.TIMEOUT,
             error_message=AGENT_LLM_TIMEOUT_MESSAGE,
@@ -402,20 +383,16 @@ def generate_validate_refined_answer(
         )
     ]

-    validation_model = graph_config.tooling.fast_llm
     try:
-        validation_response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION,
-            validation_model.invoke,
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION,
+        validation_response = model.invoke(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
         )
         refined_answer_quality = binary_string_test_after_answer_separator(
             text=cast(str, validation_response.content),
             positive_value=AGENT_POSITIVE_VALUE_STR,
             separator=AGENT_ANSWER_SEPARATOR,
         )
-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         refined_answer_quality = True
         logger.error("LLM Timeout Error - validate refined answer")
```
||||
@@ -34,16 +34,14 @@ from onyx.agents.agent_search.shared_graph_utils.utils import (
|
||||
)
|
||||
from onyx.agents.agent_search.shared_graph_utils.utils import parse_question_id
|
||||
from onyx.configs.agent_configs import (
|
||||
AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
)
|
||||
from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
from onyx.llm.chat_llm import LLMRateLimitError
|
||||
from onyx.llm.chat_llm import LLMTimeoutError
|
||||
from onyx.prompts.agent_search import (
|
||||
QUERY_REWRITING_PROMPT,
|
||||
)
|
||||
from onyx.utils.logger import setup_logger
|
||||
from onyx.utils.threadpool_concurrency import run_with_timeout
|
||||
from onyx.utils.timing import log_function_time
|
||||
|
||||
logger = setup_logger()
|
||||
@@ -71,7 +69,7 @@ def expand_queries(
|
||||
node_start_time = datetime.now()
|
||||
question = state.question
|
||||
|
||||
model = graph_config.tooling.fast_llm
|
||||
llm = graph_config.tooling.fast_llm
|
||||
sub_question_id = state.sub_question_id
|
||||
if sub_question_id is None:
|
||||
level, question_num = 0, 0
|
||||
@@ -90,12 +88,10 @@ def expand_queries(
|
||||
rewritten_queries = []
|
||||
|
||||
try:
|
||||
llm_response_list = run_with_timeout(
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION,
|
||||
dispatch_separated,
|
||||
model.stream(
|
||||
llm_response_list = dispatch_separated(
|
||||
llm.stream(
|
||||
prompt=msg,
|
||||
timeout_override=AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION,
|
||||
timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION,
|
||||
),
|
||||
dispatch_subquery(level, question_num, writer),
|
||||
)
|
||||
@@ -105,7 +101,7 @@ def expand_queries(
|
||||
rewritten_queries = llm_response.split("\n")
|
||||
log_result = f"Number of expanded queries: {len(rewritten_queries)}"
|
||||
|
||||
except (LLMTimeoutError, TimeoutError):
|
||||
except LLMTimeoutError:
|
||||
agent_error = AgentErrorLog(
|
||||
error_type=AgentLLMErrorType.TIMEOUT,
|
||||
error_message=AGENT_LLM_TIMEOUT_MESSAGE,
|
||||
|
||||
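`dispatch_separated` keeps its role here: it consumes the token stream, forwards each token to a callback, and treats separator tokens as boundaries between items so each rewritten query can be streamed as it forms. A rough standalone sketch of that idea, inferred from the call sites rather than Onyx's actual implementation:

```python
from collections.abc import Callable, Iterator


def dispatch_separated(
    token_stream: Iterator[str],
    dispatch: Callable[[str, int], None],
    sep: str = "\n",
) -> list[str]:
    """Forward tokens to `dispatch` with the index of the item they belong to,
    bumping the index whenever a separator token is seen; return all tokens."""
    collected: list[str] = []
    item_index = 0
    for token in token_stream:
        if sep in token:
            item_index += 1
        else:
            dispatch(token, item_index)
        collected.append(token)
    return collected


def fake_stream() -> Iterator[str]:
    # Stand-in for llm.stream(...); yields tokens for two queries.
    yield from ["what is", " onyx", "\n", "how does", " reranking work"]


tokens = dispatch_separated(
    fake_stream(),
    lambda tok, idx: print(f"query {idx}: {tok!r}"),
)
print("".join(tokens).split("\n"))
```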
```diff
@@ -55,7 +55,6 @@ def rerank_documents(

     # Note that these are passed in values from the API and are overrides which are typically None
     rerank_settings = graph_config.inputs.search_request.rerank_settings
-    allow_agent_reranking = graph_config.behavior.allow_agent_reranking

     if rerank_settings is None:
         with get_session_context_manager() as db_session:
@@ -63,31 +62,23 @@ def rerank_documents(
             if not search_settings.disable_rerank_for_streaming:
                 rerank_settings = RerankingDetails.from_db_model(search_settings)

-    # Initial default: no reranking. Will be overwritten below if reranking is warranted
-    reranked_documents = verified_documents
-
     if should_rerank(rerank_settings) and len(verified_documents) > 0:
-        if not allow_agent_reranking:
-            logger.info("Use of local rerank model without GPU, skipping reranking")
-            # No reranking, stay with verified_documents as default
-
-        else:
-            # Reranking is warranted, use the rerank_sections functon
-            reranked_documents = rerank_sections(
-                query_str=question,
-                # if runnable, then rerank_settings is not None
-                rerank_settings=cast(RerankingDetails, rerank_settings),
-                sections_to_rerank=verified_documents,
-            )
+        if len(verified_documents) > 1:
+            reranked_documents = rerank_sections(
+                query_str=question,
+                # if runnable, then rerank_settings is not None
+                rerank_settings=cast(RerankingDetails, rerank_settings),
+                sections_to_rerank=verified_documents,
+            )
+        else:
+            logger.warning(
+                f"{len(verified_documents)} verified document(s) found, skipping reranking"
+            )
+            # No reranking, stay with verified_documents as default
+            reranked_documents = verified_documents
+    else:
+        logger.warning("No reranking settings found, using unranked documents")
+        # No reranking, stay with verified_documents as default
+        reranked_documents = verified_documents

     if AGENT_RERANKING_STATS:
         fit_scores = get_fit_scores(verified_documents, reranked_documents)
     else:
```
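With the `allow_agent_reranking` flag gone, the gate reduces to "rerank settings exist and there is more than one document". A condensed sketch of the new control flow, with simplified types and a hypothetical scorer standing in for the real rerank model:

```python
from dataclasses import dataclass


@dataclass
class RerankingDetails:
    # Simplified stand-in for the real settings model.
    model_name: str


def should_rerank(settings: RerankingDetails | None) -> bool:
    return settings is not None


def rerank_sections(
    query_str: str,
    rerank_settings: RerankingDetails,
    sections_to_rerank: list[str],
) -> list[str]:
    # Hypothetical scorer; a real implementation calls the rerank model.
    return sorted(sections_to_rerank, key=lambda s: query_str in s, reverse=True)


def rerank_documents(
    question: str,
    verified_documents: list[str],
    rerank_settings: RerankingDetails | None,
) -> list[str]:
    if should_rerank(rerank_settings) and len(verified_documents) > 0:
        if len(verified_documents) > 1:
            # Only here is reranking actually worth the latency.
            return rerank_sections(question, rerank_settings, verified_documents)
        # A single document cannot be re-ordered; keep it as-is.
        return verified_documents
    # No settings: stay with the verified documents, unranked.
    return verified_documents
```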
```diff
@@ -25,15 +25,13 @@ from onyx.agents.agent_search.shared_graph_utils.models import LLMNodeErrorStrings
 from onyx.agents.agent_search.shared_graph_utils.utils import (
     get_langgraph_node_log_string,
 )
-from onyx.configs.agent_configs import AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
-from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
+from onyx.configs.agent_configs import AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
 from onyx.llm.chat_llm import LLMRateLimitError
 from onyx.llm.chat_llm import LLMTimeoutError
 from onyx.prompts.agent_search import (
     DOCUMENT_VERIFICATION_PROMPT,
 )
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout
 from onyx.utils.timing import log_function_time

 logger = setup_logger()
@@ -88,11 +86,8 @@ def verify_documents(
     ]  # default is to treat document as relevant

     try:
-        response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION,
-            fast_llm.invoke,
-            prompt=msg,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION,
+        response = fast_llm.invoke(
+            msg, timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
         )

         assert isinstance(response.content, str)
@@ -101,7 +96,7 @@ def verify_documents(
     ):
         verified_documents = []

-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         # In this case, we decide to continue and don't raise an error, as
         # little harm in letting some docs through that are less relevant.
         logger.error("LLM Timeout Error - verify documents")
```
```diff
@@ -67,7 +67,6 @@ class GraphSearchConfig(BaseModel):
     # Whether to allow creation of refinement questions (and entity extraction, etc.)
     allow_refinement: bool = True
     skip_gen_ai_answer_generation: bool = False
-    allow_agent_reranking: bool = False


 class GraphConfig(BaseModel):
```
```diff
@@ -25,7 +25,7 @@ logger = setup_logger()
 # and a function that handles extracting the necessary fields
 # from the state and config
 # TODO: fan-out to multiple tool call nodes? Make this configurable?
-def choose_tool(
+def llm_tool_choice(
     state: ToolChoiceState,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,
```

```diff
@@ -28,7 +28,7 @@ def emit_packet(packet: AnswerPacket, writer: StreamWriter) -> None:
     write_custom_event("basic_response", packet, writer)


-def call_tool(
+def tool_call(
     state: ToolChoiceUpdate,
     config: RunnableConfig,
     writer: StreamWriter = lambda _: None,
```
```diff
@@ -43,9 +43,8 @@ from onyx.chat.models import StreamStopReason
 from onyx.chat.models import StreamType
 from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
 from onyx.configs.agent_configs import (
-    AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
+    AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
 )
-from onyx.configs.agent_configs import AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
 from onyx.configs.chat_configs import CHAT_TARGET_CHUNK_PERCENTAGE
 from onyx.configs.chat_configs import MAX_CHUNKS_FED_TO_CHAT
 from onyx.configs.constants import DEFAULT_PERSONA_ID
@@ -81,7 +80,6 @@ from onyx.tools.tool_implementations.search.search_tool import SearchResponseSummary
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
 from onyx.utils.logger import setup_logger
-from onyx.utils.threadpool_concurrency import run_with_timeout

 logger = setup_logger()
@@ -397,13 +395,11 @@ def summarize_history(
     )

     try:
-        history_response = run_with_timeout(
-            AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION,
-            llm.invoke,
+        history_response = llm.invoke(
             history_context_prompt,
-            timeout_override=AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION,
+            timeout_override=AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION,
         )
-    except (LLMTimeoutError, TimeoutError):
+    except LLMTimeoutError:
         logger.error("LLM Timeout Error - summarize history")
         return (
             history  # this is what is done at this point anyway, so we default to this
```
```diff
@@ -94,7 +94,6 @@ from onyx.db.models import User
 from onyx.db.users import get_user_by_email
 from onyx.redis.redis_pool import get_async_redis_connection
 from onyx.redis.redis_pool import get_redis_client
-from onyx.server.utils import BasicAuthenticationError
 from onyx.utils.logger import setup_logger
 from onyx.utils.telemetry import create_milestone_and_report
 from onyx.utils.telemetry import optional_telemetry
@@ -108,6 +107,11 @@ from shared_configs.contextvars import CURRENT_TENANT_ID_CONTEXTVAR
 logger = setup_logger()


+class BasicAuthenticationError(HTTPException):
+    def __init__(self, detail: str):
+        super().__init__(status_code=status.HTTP_403_FORBIDDEN, detail=detail)
+
+
 def is_user_admin(user: User | None) -> bool:
     if AUTH_TYPE == AuthType.DISABLED:
         return True
```
```diff
@@ -27,10 +27,8 @@ from onyx.file_store.utils import InMemoryChatFile
 from onyx.llm.interfaces import LLM
 from onyx.tools.force import ForceUseTool
 from onyx.tools.tool import Tool
-from onyx.tools.tool_implementations.search.search_tool import QUERY_FIELD
 from onyx.tools.tool_implementations.search.search_tool import SearchTool
 from onyx.tools.utils import explicit_tool_calling_supported
-from onyx.utils.gpu_utils import gpu_status_request
 from onyx.utils.logger import setup_logger

 logger = setup_logger()
@@ -82,26 +80,6 @@ class Answer:
             and not skip_explicit_tool_calling
         )

-        rerank_settings = search_request.rerank_settings
-
-        using_cloud_reranking = (
-            rerank_settings is not None
-            and rerank_settings.rerank_provider_type is not None
-        )
-        allow_agent_reranking = gpu_status_request() or using_cloud_reranking
-
-        # TODO: this is a hack to force the query to be used for the search tool
-        # this should be removed once we fully unify graph inputs (i.e.
-        # remove SearchQuery entirely)
-        if (
-            force_use_tool.force_use
-            and search_tool
-            and force_use_tool.args
-            and force_use_tool.tool_name == search_tool.name
-            and QUERY_FIELD in force_use_tool.args
-        ):
-            search_request.query = force_use_tool.args[QUERY_FIELD]
-
         self.graph_inputs = GraphInputs(
             search_request=search_request,
             prompt_builder=prompt_builder,
@@ -116,6 +94,7 @@ class Answer:
             force_use_tool=force_use_tool,
             using_tool_calling_llm=using_tool_calling_llm,
         )
+        assert db_session, "db_session must be provided for agentic persistence"
         self.graph_persistence = GraphPersistence(
             db_session=db_session,
             chat_session_id=chat_session_id,
@@ -125,7 +104,6 @@ class Answer:
             use_agentic_search=use_agentic_search,
             skip_gen_ai_answer_generation=skip_gen_ai_answer_generation,
             allow_refinement=True,
-            allow_agent_reranking=allow_agent_reranking,
         )
         self.graph_config = GraphConfig(
             inputs=self.graph_inputs,
```
```diff
@@ -7,7 +7,7 @@ from typing import cast

 from sqlalchemy.orm import Session

-from onyx.agents.agent_search.orchestration.nodes.call_tool import ToolCallException
+from onyx.agents.agent_search.orchestration.nodes.tool_call import ToolCallException
 from onyx.chat.answer import Answer
 from onyx.chat.chat_utils import create_chat_chain
 from onyx.chat.chat_utils import create_temporary_persona
```
```diff
@@ -31,9 +31,22 @@ AGENT_DEFAULT_MIN_ORIG_QUESTION_DOCS = 3
 AGENT_DEFAULT_MAX_ANSWER_CONTEXT_DOCS = 10
 AGENT_DEFAULT_MAX_STATIC_HISTORY_WORD_LENGTH = 2000

-AGENT_ANSWER_GENERATION_BY_FAST_LLM = (
-    os.environ.get("AGENT_ANSWER_GENERATION_BY_FAST_LLM", "").lower() == "true"
-)
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = 30  # in seconds
+
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = 10  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = 25  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = 4  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = 1  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = 3  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = 12  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = 8  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds
+
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = 6  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = 25  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = 8  # in seconds
+AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = 8  # in seconds


 AGENT_RETRIEVAL_STATS = (
     not os.environ.get("AGENT_RETRIEVAL_STATS") == "False"
@@ -165,172 +178,80 @@ AGENT_MAX_STATIC_HISTORY_WORD_LENGTH = int(
 )  # 2000


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = 10  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_ENTITY_TERM_EXTRACTION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = 30  # in seconds
-AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_ENTITY_TERM_EXTRACTION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_ENTITY_TERM_EXTRACTION
+)  # 25


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = 2  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_DOCUMENT_VERIFICATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = 4  # in seconds
-AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_DOCUMENT_VERIFICATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_DOCUMENT_VERIFICATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_DOCUMENT_VERIFICATION
+)  # 3


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = 5  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_GENERAL_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION = 30  # in seconds
-AGENT_TIMEOUT_LLM_GENERAL_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_GENERAL_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_GENERAL_GENERATION
+)  # 30


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = 2  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBQUESTION_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION = 5  # in seconds
-AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_SUBQUESTION_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_SUBQUESTION_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBQUESTION_GENERATION
+)  # 8


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = 3  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION = 30  # in seconds
-AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_GENERATION
+)  # 12


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = 5  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_INITIAL_ANSWER_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = 25  # in seconds
-AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_INITIAL_ANSWER_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_INITIAL_ANSWER_GENERATION
+)  # 25


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = 5  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = 30  # in seconds
-AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_GENERATION
+)  # 25


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = 2  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_SUBANSWER_CHECK
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK = 8  # in seconds
-AGENT_TIMEOUT_LLM_SUBANSWER_CHECK = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_SUBANSWER_CHECK")
-    or AGENT_DEFAULT_TIMEOUT_LLM_SUBANSWER_CHECK
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_SUBANSWER_CHECK
+)  # 8


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = 3  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_SUBQUESTION_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = 8  # in seconds
-AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_SUBQUESTION_GENERATION
-)
+AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_SUBQUESTION_GENERATION
+)  # 6


-AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = 1  # in seconds
-AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION = int(
-    os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION")
-    or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_QUERY_REWRITING_GENERATION
-)
-
-AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = 3  # in seconds
+AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION = int(
+    os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION")
+    or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_QUERY_REWRITING_GENERATION
+)  # 1
```
|
||||
AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_QUERY_REWRITING_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_HISTORY_SUMMARY_GENERATION
|
||||
) # 4
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = 5 # in seconds
|
||||
AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_HISTORY_SUMMARY_GENERATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_COMPARE_ANSWERS
|
||||
) # 8
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_COMPARE_ANSWERS = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_COMPARE_ANSWERS")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_COMPARE_ANSWERS
|
||||
)
|
||||
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = 2 # in seconds
|
||||
AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_CONNECT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
|
||||
AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = 8 # in seconds
|
||||
AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_LLM_REFINED_ANSWER_VALIDATION
|
||||
)
|
||||
AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION = int(
|
||||
os.environ.get("AGENT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION")
|
||||
or AGENT_DEFAULT_TIMEOUT_OVERRIDE_LLM_REFINED_ANSWER_VALIDATION
|
||||
) # 8
|
||||
|
||||
GRAPH_VERSION_NAME: str = "a"
|
||||
|
||||
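
# Illustrative sketch (not from the diff): the `os.environ.get(...) or DEFAULT`
# idiom above treats an empty-string env var as unset (empty strings are falsy),
# while a non-numeric value raises ValueError at import time.
import os

os.environ["AGENT_TIMEOUT_LLM_GENERAL_GENERATION"] = ""
assert int(os.environ.get("AGENT_TIMEOUT_LLM_GENERAL_GENERATION") or 30) == 30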
6
backend/onyx/configs/integration_test_configs.py
Normal file
@@ -0,0 +1,6 @@
import os


SKIP_CONNECTION_POOL_WARM_UP = (
    os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"
)
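
# Illustrative sketch (not from the diff): only the literal string "true",
# case-insensitively, enables the skip; unset, empty, or other values such as
# "1" or "yes" leave connection-pool warm-up on.
import os

os.environ["SKIP_CONNECTION_POOL_WARM_UP"] = "TRUE"
assert os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"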
@@ -628,7 +628,7 @@ def create_new_chat_message(
    commit: bool = True,
    reserved_message_id: int | None = None,
    overridden_model: str | None = None,
    refined_answer_improvement: bool | None = None,
    refined_answer_improvement: bool = True,
) -> ChatMessage:
    if reserved_message_id is not None:
        # Edit existing message
@@ -191,9 +191,16 @@ class SqlEngine:
    _app_name: str = POSTGRES_UNKNOWN_APP_NAME

    @classmethod
    def _init_engine(cls, **engine_kwargs: Any) -> Engine:
    def _init_engine(
        cls, host: str, port: str, db: str, **engine_kwargs: Any
    ) -> Engine:
        connection_string = build_connection_string(
            db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
            db_api=SYNC_DB_API,
            host=host,
            port=port,
            db=db,
            app_name=cls._app_name + "_sync",
            use_iam=USE_IAM_AUTH,
        )

        # Start with base kwargs that are valid for all pool types
@@ -231,15 +238,19 @@ class SqlEngine:
    def init_engine(cls, **engine_kwargs: Any) -> None:
        with cls._lock:
            if not cls._engine:
                cls._engine = cls._init_engine(**engine_kwargs)
                cls._engine = cls._init_engine(
                    host=engine_kwargs.get("host", POSTGRES_HOST),
                    port=engine_kwargs.get("port", POSTGRES_PORT),
                    db=engine_kwargs.get("db", POSTGRES_DB),
                    **engine_kwargs,
                )

    @classmethod
    def get_engine(cls) -> Engine:
        if not cls._engine:
            with cls._lock:
                if not cls._engine:
                    cls._engine = cls._init_engine()
            return cls._engine
        cls.init_engine()

        return cls._engine  # type: ignore

    @classmethod
    def set_app_name(cls, app_name: str) -> None:
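
# Illustrative sketch (not from the diff): the reset helpers later in this
# change re-point the shared engine at a per-instance database via the new
# signature. Host/port/db values here are hypothetical.
with SqlEngine._lock:
    SqlEngine._engine = SqlEngine._init_engine(
        host="localhost", port="5433", db="onyx_1"
    )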
@@ -6,6 +6,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY


def get_default_document_index(
@@ -23,14 +24,27 @@ def get_default_document_index(
        secondary_index_name = secondary_search_settings.index_name
        secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled

    # modify index names for integration tests so that we can run many tests
    # using the same Vespa instance w/o having them collide
    primary_index_name = search_settings.index_name
    if VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY:
        primary_index_name = (
            f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{primary_index_name}"
        )
        if secondary_index_name:
            secondary_index_name = f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{secondary_index_name}"

    # Currently only supporting Vespa
    return VespaIndex(
        index_name=search_settings.index_name,
        index_name=primary_index_name,
        secondary_index_name=secondary_index_name,
        large_chunks_enabled=search_settings.large_chunks_enabled,
        secondary_large_chunks_enabled=secondary_large_chunks_enabled,
        multitenant=MULTI_TENANT,
        httpx_client=httpx_client,
        preserve_existing_indices=bool(
            VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
        ),
    )
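
# Illustrative sketch (not from the diff): how the prefix rewrites index names.
# The prefix value is hypothetical; the schema name matches the default used by
# the kickoff script further down.
prefix = "test_instance_3"  # hypothetical prefix value
index_name = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
assert (
    f"{prefix}_{index_name}"
    == "test_instance_3_danswer_chunk_nomic_ai_nomic_embed_text_v1"
)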
@@ -136,6 +136,7 @@ class VespaIndex(DocumentIndex):
        secondary_large_chunks_enabled: bool | None,
        multitenant: bool = False,
        httpx_client: httpx.Client | None = None,
        preserve_existing_indices: bool = False,
    ) -> None:
        self.index_name = index_name
        self.secondary_index_name = secondary_index_name
@@ -161,18 +162,18 @@ class VespaIndex(DocumentIndex):
                secondary_index_name
            ] = secondary_large_chunks_enabled

    def ensure_indices_exist(
        self,
        index_embedding_dim: int,
        secondary_index_embedding_dim: int | None,
    ) -> None:
        if MULTI_TENANT:
            logger.info(
                "Skipping Vespa index setup for multitenant (would wipe all indices)"
            )
            return None
        self.preserve_existing_indices = preserve_existing_indices

        deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
    @classmethod
    def create_indices(
        cls,
        indices: list[tuple[str, int, bool]],
        application_endpoint: str = VESPA_APPLICATION_ENDPOINT,
    ) -> None:
        """
        Create indices in Vespa based on the passed in configuration(s).
        """
        deploy_url = f"{application_endpoint}/tenant/default/prepareandactivate"
        logger.notice(f"Deploying Vespa application package to {deploy_url}")

        vespa_schema_path = os.path.join(
@@ -185,7 +186,7 @@ class VespaIndex(DocumentIndex):
        with open(services_file, "r") as services_f:
            services_template = services_f.read()

        schema_names = [self.index_name, self.secondary_index_name]
        schema_names = [index_name for (index_name, _, _) in indices]

        doc_lines = _create_document_xml_lines(schema_names)
        services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
@@ -193,14 +194,6 @@ class VespaIndex(DocumentIndex):
            SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
        )

        kv_store = get_kv_store()

        needs_reindexing = False
        try:
            needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
        except Exception:
            logger.debug("Could not load the reindexing flag. Using ngrams")

        with open(overrides_file, "r") as overrides_f:
            overrides_template = overrides_f.read()

@@ -221,29 +214,63 @@ class VespaIndex(DocumentIndex):
            schema_template = schema_f.read()
        schema_template = schema_template.replace(TENANT_ID_PAT, "")

        schema = schema_template.replace(
            DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
        ).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
        for index_name, index_embedding_dim, needs_reindexing in indices:
            schema = schema_template.replace(
                DANSWER_CHUNK_REPLACEMENT_PAT, index_name
            ).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))

        schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
        schema = schema.replace(TENANT_ID_PAT, "")
        zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")

        if self.secondary_index_name:
            upcoming_schema = schema_template.replace(
                DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
            ).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
            zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
            schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
            schema = schema.replace(TENANT_ID_PAT, "")
            logger.info(
                f"Creating index: {index_name} with embedding "
                f"dimension: {index_embedding_dim}. Schema:\n\n {schema}"
            )
            zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")

        zip_file = in_memory_zip_from_file_bytes(zip_dict)

        headers = {"Content-Type": "application/zip"}
        response = requests.post(deploy_url, headers=headers, data=zip_file)
        if response.status_code != 200:
            logger.error(f"Failed to create Vespa indices: {response.text}")
            raise RuntimeError(
                f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
            )

    def ensure_indices_exist(
        self,
        index_embedding_dim: int,
        secondary_index_embedding_dim: int | None,
    ) -> None:
        if self.multitenant or MULTI_TENANT:  # be extra safe here
            logger.info(
                "Skipping Vespa index setup for multitenant (would wipe all indices)"
            )
            return None

        # Used in IT
        # NOTE: this means that we can't switch embedding models
        if self.preserve_existing_indices:
            logger.info("Preserving existing indices")
            return None

        kv_store = get_kv_store()
        primary_needs_reindexing = False
        try:
            primary_needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
        except Exception:
            logger.debug("Could not load the reindexing flag. Using ngrams")

        indices = [
            (self.index_name, index_embedding_dim, primary_needs_reindexing),
        ]
        if self.secondary_index_name and secondary_index_embedding_dim:
            indices.append(
                (self.secondary_index_name, secondary_index_embedding_dim, False)
            )

        self.create_indices(indices)

    @staticmethod
    def register_multitenant_indices(
        indices: list[str],
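
# Illustrative sketch (not from the diff): calling the new classmethod directly,
# mirroring prepare_vespa in the kickoff script further down. Each tuple is
# (index_name, embedding_dim, needs_reindexing); values are hypothetical.
VespaIndex.create_indices(
    [("test_instance_1_danswer_chunk_nomic_ai_nomic_embed_text_v1", 768, False)],
    application_endpoint="http://localhost:19071/application/v2",
)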
@@ -409,6 +409,10 @@ class DefaultMultiLLM(LLM):
        self._record_call(processed_prompt)

        try:
            print(
                "model is",
                f"{self.config.model_provider}/{self.config.deployment_name or self.config.model_name}",
            )
            return litellm.completion(
                mock_response=MOCK_LLM_RESPONSE,
                # model choice
@@ -43,6 +43,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.configs.integration_test_configs import SKIP_CONNECTION_POOL_WARM_UP
from onyx.db.engine import SqlEngine
from onyx.db.engine import warm_up_connections
from onyx.server.api_key.api import router as api_key_router
@@ -208,8 +209,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    if DISABLE_GENERATIVE_AI:
        logger.notice("Generative AI Q&A disabled")

    # fill up Postgres connection pools
    await warm_up_connections()
    # only used for IT. Need to skip since it overloads postgres when we have 50+
    # instances running
    if not SKIP_CONNECTION_POOL_WARM_UP:
        # fill up Postgres connection pools
        await warm_up_connections()

    if not MULTI_TENANT:
        # We cache this at the beginning so there is no delay in the first telemetry
@@ -146,10 +146,10 @@ class RedisPool:
            cls._instance._init_pools()
        return cls._instance

    def _init_pools(self) -> None:
        self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
    def _init_pools(self, redis_port: int = REDIS_PORT) -> None:
        self._pool = RedisPool.create_pool(port=redis_port, ssl=REDIS_SSL)
        self._replica_pool = RedisPool.create_pool(
            host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
            host=REDIS_REPLICA_HOST, port=redis_port, ssl=REDIS_SSL
        )

    def get_client(self, tenant_id: str | None) -> Redis:
@@ -213,6 +213,8 @@ def get_chat_session(
        # we need the tool call objs anyways, so just fetch them in a single call
        prefetch_tool_calls=True,
    )
    for message in session_messages:
        translate_db_message_to_chat_message_detail(message)

    return ChatSessionDetailResponse(
        chat_session_id=session_id,
@@ -251,7 +251,8 @@ def setup_vespa(

        logger.notice("Vespa setup complete.")
        return True
    except Exception:
    except Exception as e:
        logger.debug(f"Error creating Vespa indices: {e}")
        logger.notice(
            f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
        )
@@ -58,7 +58,6 @@ SEARCH_RESPONSE_SUMMARY_ID = "search_response_summary"
SEARCH_DOC_CONTENT_ID = "search_doc_content"
SECTION_RELEVANCE_LIST_ID = "section_relevance_list"
SEARCH_EVALUATION_ID = "llm_doc_eval"
QUERY_FIELD = "query"


class SearchResponseSummary(SearchQueryInfo):
@@ -180,12 +179,12 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
                "parameters": {
                    "type": "object",
                    "properties": {
                        QUERY_FIELD: {
                        "query": {
                            "type": "string",
                            "description": "What to search for",
                        },
                    },
                    "required": [QUERY_FIELD],
                    "required": ["query"],
                },
            },
        }
@@ -224,7 +223,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
        rephrased_query = history_based_query_rephrase(
            query=query, history=history, llm=llm
        )
        return {QUERY_FIELD: rephrased_query}
        return {"query": rephrased_query}

    """Actual tool execution"""

@@ -280,7 +279,7 @@ class SearchTool(Tool[SearchToolOverrideKwargs]):
    def run(
        self, override_kwargs: SearchToolOverrideKwargs | None = None, **llm_kwargs: Any
    ) -> Generator[ToolResponse, None, None]:
        query = cast(str, llm_kwargs[QUERY_FIELD])
        query = cast(str, llm_kwargs["query"])
        force_no_rerank = False
        alternate_db_session = None
        retrieved_sections_callback = None
@@ -1,4 +1,3 @@
import threading
import uuid
from collections.abc import Callable
from concurrent.futures import as_completed
@@ -14,10 +13,6 @@ logger = setup_logger()
R = TypeVar("R")


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_tuples_in_parallel(
    functions_with_args: list[tuple[Callable, tuple]],
    allow_failures: bool = False,
@@ -83,10 +78,6 @@ class FunctionCall(Generic[R]):
        return self.func(*self.args, **self.kwargs)


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_functions_in_parallel(
    function_calls: list[FunctionCall],
    allow_failures: bool = False,
@@ -118,49 +109,3 @@ def run_functions_in_parallel(
            raise

    return results


class TimeoutThread(threading.Thread):
    def __init__(
        self, timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
    ):
        super().__init__()
        self.timeout = timeout
        self.func = func
        self.args = args
        self.kwargs = kwargs
        self.exception: Exception | None = None

    def run(self) -> None:
        try:
            self.result = self.func(*self.args, **self.kwargs)
        except Exception as e:
            self.exception = e

    def end(self) -> None:
        raise TimeoutError(
            f"Function {self.func.__name__} timed out after {self.timeout} seconds"
        )


# WARNING: it is not currently well understood whether we lose access to contextvars when functions are
# executed through this wrapper. Do NOT try to acquire a db session in a function run through this unless
# you have heavily tested that multi-tenancy is respected. If/when we know for sure that it is or
# is not safe, update this comment.
def run_with_timeout(
    timeout: float, func: Callable[..., R], *args: Any, **kwargs: Any
) -> R:
    """
    Executes a function with a timeout. If the function doesn't complete within the specified
    timeout, raises TimeoutError.
    """
    task = TimeoutThread(timeout, func, *args, **kwargs)
    task.start()
    task.join(timeout)

    if task.exception is not None:
        raise task.exception
    if task.is_alive():
        task.end()

    return task.result
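
# Illustrative sketch (not from the diff): run_with_timeout unblocks the caller
# with TimeoutError; note the tradeoff visible above: the worker thread is never
# killed and keeps running to completion in the background.
import time

def slow_task(n: int) -> int:
    time.sleep(10)
    return n * 2

result = run_with_timeout(1.0, slow_task, 21)  # raises TimeoutError after ~1s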
@@ -270,3 +270,10 @@ SUPPORTED_EMBEDDING_MODELS = [
        index_name="danswer_chunk_intfloat_multilingual_e5_small",
    ),
]

"""
INTEGRATION TEST ONLY SETTINGS
"""
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY = os.getenv(
    "VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY"
)
@@ -3,6 +3,19 @@ FROM python:3.11.7-slim-bookworm
# Currently needs all dependencies, since the ITs use some of the Onyx
# backend code.

# Add Docker's official GPG key and repository for Debian
RUN apt-get update && \
    apt-get install -y ca-certificates curl && \
    install -m 0755 -d /etc/apt/keyrings && \
    curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
    chmod a+r /etc/apt/keyrings/docker.asc && \
    echo \
      "deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \
      $(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
      tee /etc/apt/sources.list.d/docker.list > /dev/null && \
    apt-get update


# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -15,6 +28,9 @@ RUN apt-get update && \
        curl \
        zip \
        ca-certificates \
        postgresql-client \
        # Install Docker for DinD
        docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin \
        libgnutls30=3.7.9-2+deb12u3 \
        libblkid1=2.38.1-5+deb12u1 \
        libmount1=2.38.1-5+deb12u1 \
@@ -29,37 +45,19 @@ RUN apt-get update && \
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/model_server.txt /tmp/model_server-requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
    --retries 5 \
    --timeout 30 \
    -r /tmp/requirements.txt \
    -r /tmp/model_server-requirements.txt \
    -r /tmp/ee-requirements.txt && \
    pip uninstall -y py && \
    playwright install chromium && \
    playwright install-deps chromium && \
    ln -s /usr/local/bin/supervisord /usr/bin/supervisord

# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
    apt-get remove -y --allow-remove-essential \
        perl-base \
        xserver-common \
        xvfb \
        cmake \
        libldap-2.5-0 \
        libxmlsec1-dev \
        pkg-config \
        gcc && \
    apt-get install -y libxmlsec1-openssl && \
    apt-get autoremove -y && \
    rm -rf /var/lib/apt/lists/* && \
    rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key

# Set up application files
WORKDIR /app

@@ -76,6 +74,9 @@ COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf

# need to copy over model server as well, since we're running it in the same container
COPY ./model_server /app/model_server

# Integration test stuff
COPY ./requirements/dev.txt /tmp/dev-requirements.txt
RUN pip install --no-cache-dir --upgrade \
@@ -84,5 +85,6 @@ COPY ./tests/integration /app/tests/integration

ENV PYTHONPATH=/app

ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
ENTRYPOINT []
# let caller specify the command
CMD ["tail", "-f", "/dev/null"]
@@ -1,6 +1,7 @@
import os

ADMIN_USER_NAME = "admin_user"
GUARANTEED_FRESH_SETUP = os.getenv("GUARANTEED_FRESH_SETUP") == "true"

API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"
@@ -1,5 +1,9 @@
import contextlib
import io
import logging
import sys
import time
from logging import Logger
from types import SimpleNamespace

import psycopg2
@@ -11,10 +15,12 @@ from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
@@ -22,17 +28,20 @@ from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import redis_pool
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout

logger = setup_logger()


def _run_migrations(
    database_url: str,
    database: str,
    config_name: str,
    postgres_host: str,
    postgres_port: str,
    redis_port: int,
    direction: str = "upgrade",
    revision: str = "head",
    schema: str = "public",
@@ -46,9 +55,28 @@ def _run_migrations(
    alembic_cfg.attributes["configure_logger"] = False
    alembic_cfg.config_ini_section = config_name

    # Add environment variables to the config attributes
    alembic_cfg.attributes["env_vars"] = {
        "POSTGRES_HOST": postgres_host,
        "POSTGRES_PORT": postgres_port,
        "POSTGRES_DB": database,
        # some migrations call redis directly, so we need to pass the port
        "REDIS_PORT": str(redis_port),
    }

    alembic_cfg.cmd_opts = SimpleNamespace()  # type: ignore
    alembic_cfg.cmd_opts.x = [f"schema={schema}"]  # type: ignore

    # Build the database URL
    database_url = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=postgres_host,
        port=postgres_port,
        db_api=SYNC_DB_API,
    )

    # Set the SQLAlchemy URL in the Alembic configuration
    alembic_cfg.set_main_option("sqlalchemy.url", database_url)

@@ -71,6 +99,9 @@ def downgrade_postgres(
    config_name: str = "alembic",
    revision: str = "base",
    clear_data: bool = False,
    postgres_host: str = POSTGRES_HOST,
    postgres_port: str = POSTGRES_PORT,
    redis_port: int = REDIS_PORT,
) -> None:
    """Downgrade Postgres database to base state."""
    if clear_data:
@@ -81,8 +112,8 @@ def downgrade_postgres(
            dbname=database,
            user=POSTGRES_USER,
            password=POSTGRES_PASSWORD,
            host=POSTGRES_HOST,
            port=POSTGRES_PORT,
            host=postgres_host,
            port=postgres_port,
        )
        conn.autocommit = True  # Need autocommit for dropping schema
        cur = conn.cursor()
@@ -112,37 +143,32 @@ def downgrade_postgres(
        return

    # Downgrade to base
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
    )
    _run_migrations(
        conn_str,
        config_name,
        database=database,
        config_name=config_name,
        postgres_host=postgres_host,
        postgres_port=postgres_port,
        redis_port=redis_port,
        direction="downgrade",
        revision=revision,
    )


def upgrade_postgres(
    database: str = "postgres", config_name: str = "alembic", revision: str = "head"
    database: str = "postgres",
    config_name: str = "alembic",
    revision: str = "head",
    postgres_host: str = POSTGRES_HOST,
    postgres_port: str = POSTGRES_PORT,
    redis_port: int = REDIS_PORT,
) -> None:
    """Upgrade Postgres database to latest version."""
    conn_str = build_connection_string(
        db=database,
        user=POSTGRES_USER,
        password=POSTGRES_PASSWORD,
        host=POSTGRES_HOST,
        port=POSTGRES_PORT,
        db_api=SYNC_DB_API,
    )
    _run_migrations(
        conn_str,
        config_name,
        database=database,
        config_name=config_name,
        postgres_host=postgres_host,
        postgres_port=postgres_port,
        redis_port=redis_port,
        direction="upgrade",
        revision=revision,
    )
@@ -152,46 +178,44 @@ def reset_postgres(
    database: str = "postgres",
    config_name: str = "alembic",
    setup_onyx: bool = True,
    postgres_host: str = POSTGRES_HOST,
    postgres_port: str = POSTGRES_PORT,
    redis_port: int = REDIS_PORT,
) -> None:
    """Reset the Postgres database."""
    # this seems to hang due to locking issues, so run with a timeout with a few retries
    NUM_TRIES = 10
    TIMEOUT = 10
    success = False
    for _ in range(NUM_TRIES):
        logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
        try:
            run_with_timeout(
                downgrade_postgres,
                TIMEOUT,
                kwargs={
                    "database": database,
                    "config_name": config_name,
                    "revision": "base",
                    "clear_data": True,
                },
            )
            success = True
            break
        except TimeoutError:
            logger.warning(
                f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
            )

    if not success:
        raise RuntimeError("Postgres downgrade failed after 10 timeouts.")

    logger.info("Upgrading Postgres...")
    upgrade_postgres(database=database, config_name=config_name, revision="head")
    downgrade_postgres(
        database=database,
        config_name=config_name,
        revision="base",
        clear_data=True,
        postgres_host=postgres_host,
        postgres_port=postgres_port,
        redis_port=redis_port,
    )
    upgrade_postgres(
        database=database,
        config_name=config_name,
        revision="head",
        postgres_host=postgres_host,
        postgres_port=postgres_port,
        redis_port=redis_port,
    )
    if setup_onyx:
        logger.info("Setting up Postgres...")
        with get_session_context_manager() as db_session:
            setup_postgres(db_session)


def reset_vespa() -> None:
    """Wipe all data from the Vespa index."""
def reset_vespa(
    skip_creating_indices: bool, document_id_endpoint: str = DOCUMENT_ID_ENDPOINT
) -> None:
    """Wipe all data from the Vespa index.

    Args:
        skip_creating_indices: If True, the indices will not be recreated.
            This is useful if the indices already exist and you do not want to
            recreate them (e.g. when running parallel tests).
    """
    with get_session_context_manager() as db_session:
        # swap to the correct default model
        check_index_swap(db_session)
@@ -200,18 +224,21 @@ def reset_vespa() -> None:
        multipass_config = get_multipass_config(search_settings)
        index_name = search_settings.index_name

    success = setup_vespa(
        document_index=VespaIndex(
            index_name=index_name,
            secondary_index_name=None,
            large_chunks_enabled=multipass_config.enable_large_chunks,
            secondary_large_chunks_enabled=None,
        ),
        index_setting=IndexingSetting.from_db_model(search_settings),
        secondary_index_setting=None,
    )
    if not success:
        raise RuntimeError("Could not connect to Vespa within the specified timeout.")
    if not skip_creating_indices:
        success = setup_vespa(
            document_index=VespaIndex(
                index_name=index_name,
                secondary_index_name=None,
                large_chunks_enabled=multipass_config.enable_large_chunks,
                secondary_large_chunks_enabled=None,
            ),
            index_setting=IndexingSetting.from_db_model(search_settings),
            secondary_index_setting=None,
        )
        if not success:
            raise RuntimeError(
                "Could not connect to Vespa within the specified timeout."
            )

    for _ in range(5):
        try:
@@ -222,7 +249,7 @@ def reset_vespa() -> None:
            if continuation:
                params = {**params, "continuation": continuation}
            response = requests.delete(
                DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
                document_id_endpoint.format(index_name=index_name), params=params
            )
            response.raise_for_status()

@@ -336,11 +363,99 @@ def reset_vespa_multitenant() -> None:
        time.sleep(5)


def reset_all() -> None:
def reset_all(
    database: str = "postgres",
    postgres_host: str = POSTGRES_HOST,
    postgres_port: str = POSTGRES_PORT,
    redis_port: int = REDIS_PORT,
    silence_logs: bool = False,
    skip_creating_indices: bool = False,
    document_id_endpoint: str = DOCUMENT_ID_ENDPOINT,
) -> None:
    if not silence_logs:
        with contextlib.redirect_stdout(sys.stdout), contextlib.redirect_stderr(
            sys.stderr
        ):
            _do_reset(
                database,
                postgres_host,
                postgres_port,
                redis_port,
                skip_creating_indices,
                document_id_endpoint,
            )
        return

    # Store original logging levels
    loggers_to_silence: list[Logger] = [
        logging.getLogger(),  # Root logger
        logging.getLogger("alembic"),
        logger.logger,  # Our custom logger
    ]
    original_levels = [logger.level for logger in loggers_to_silence]

    # Temporarily set all loggers to ERROR level
    for log in loggers_to_silence:
        log.setLevel(logging.ERROR)

    stdout_redirect = io.StringIO()
    stderr_redirect = io.StringIO()
    try:
        with contextlib.redirect_stdout(stdout_redirect), contextlib.redirect_stderr(
            stderr_redirect
        ):
            _do_reset(
                database,
                postgres_host,
                postgres_port,
                redis_port,
                skip_creating_indices,
                document_id_endpoint,
            )
    except Exception as e:
        print(stdout_redirect.getvalue(), file=sys.stdout)
        print(stderr_redirect.getvalue(), file=sys.stderr)
        raise e
    finally:
        # Restore original logging levels
        for logger_, level in zip(loggers_to_silence, original_levels):
            logger_.setLevel(level)


def _do_reset(
    database: str,
    postgres_host: str,
    postgres_port: str,
    redis_port: int,
    skip_creating_indices: bool,
    document_id_endpoint: str,
) -> None:
    """NOTE: should only be running in one worker/thread at a time."""

    # force re-create the engine to allow for the same worker to reset
    # different databases
    with SqlEngine._lock:
        SqlEngine._engine = SqlEngine._init_engine(
            host=postgres_host,
            port=postgres_port,
            db=database,
        )

    # same with redis
    redis_pool._init_pools(redis_port=redis_port)

    logger.info("Resetting Postgres...")
    reset_postgres()
    reset_postgres(
        database=database,
        postgres_host=postgres_host,
        postgres_port=postgres_port,
        redis_port=redis_port,
    )
    logger.info("Resetting Vespa...")
    reset_vespa()
    reset_vespa(
        skip_creating_indices=skip_creating_indices,
        document_id_endpoint=document_id_endpoint,
    )


def reset_all_multitenant() -> None:
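
# Illustrative sketch (not from the diff): resetting one parallel-test
# instance's state with the extended signature. Database name and ports are
# hypothetical.
reset_all(
    database="onyx_1",
    postgres_host="localhost",
    postgres_port="5433",
    redis_port=6380,
    silence_logs=True,
    skip_creating_indices=True,
)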
@@ -7,6 +7,7 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import GUARANTEED_FRESH_SETUP
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
@@ -14,25 +15,11 @@ from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture


def load_env_vars(env_file: str = ".env") -> None:
    current_dir = os.path.dirname(os.path.abspath(__file__))
    env_path = os.path.join(current_dir, env_file)
    try:
        with open(env_path, "r") as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#"):
                    key, value = line.split("=", 1)
                    os.environ[key] = value.strip()
        print("Successfully loaded environment variables")
    except FileNotFoundError:
        print(f"File {env_file} not found")
from tests.integration.introspection import load_env_vars


# Load environment variables at the module level
load_env_vars()
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))


"""NOTE: for some reason using this seems to lead to misc
@@ -57,6 +44,10 @@ def vespa_client() -> vespa_fixture:

@pytest.fixture
def reset() -> None:
    if GUARANTEED_FRESH_SETUP:
        print("GUARANTEED_FRESH_SETUP is true, skipping reset")
        return None

    reset_all()


@@ -118,6 +118,7 @@ def test_google_permission_sync(
        GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
    ],
) -> None:
    print("Running test_google_permission_sync")
    (
        drive_service,
        drive_id,
71
backend/tests/integration/introspection.py
Normal file
@@ -0,0 +1,71 @@
import os
from pathlib import Path

import pytest
from _pytest.nodes import Item


def list_all_tests(directory: str | Path = ".") -> list[str]:
    """
    List all pytest test functions under the specified directory.

    Args:
        directory: Directory path to search for tests (defaults to current directory)

    Returns:
        List of test function names with their module paths
    """
    directory = Path(directory).absolute()
    print(f"Searching for tests in: {directory}")

    class TestCollector:
        def __init__(self) -> None:
            self.collected: list[str] = []

        def pytest_collection_modifyitems(self, items: list[Item]) -> None:
            for item in items:
                if isinstance(item, Item):
                    # Get the relative path from the test file to the directory we're searching from
                    rel_path = Path(item.fspath).relative_to(directory)
                    # Remove the .py extension
                    module_path = str(rel_path.with_suffix(""))
                    # Replace directory separators with dots
                    module_path = module_path.replace("/", ".")
                    test_name = item.name
                    self.collected.append(f"{module_path}::{test_name}")

    collector = TestCollector()

    # Run pytest in collection-only mode
    pytest.main(
        [
            str(directory),
            "--collect-only",
            "-q",  # quiet mode
        ],
        plugins=[collector],
    )

    return sorted(collector.collected)


def load_env_vars(env_file: str = ".env") -> None:
    current_dir = os.path.dirname(os.path.abspath(__file__))
    env_path = os.path.join(current_dir, env_file)
    try:
        with open(env_path, "r") as f:
            for line in f:
                line = line.strip()
                if line and not line.startswith("#"):
                    key, value = line.split("=", 1)
                    os.environ[key] = value.strip()
        print("Successfully loaded environment variables")
    except FileNotFoundError:
        print(f"File {env_file} not found")


if __name__ == "__main__":
    tests = list_all_tests()
    print("\nFound tests:")
    for test in tests:
        print(f"- {test}")
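
# Illustrative sketch (not from the diff): collecting test IDs. The directory
# and the sample ID in the comment are hypothetical.
from tests.integration.introspection import list_all_tests

for test_id in list_all_tests("/app/tests/integration"):
    print(test_id)  # e.g. "some_module.test_file::test_name" (hypothetical)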
637
backend/tests/integration/kickoff.py
Normal file
@@ -0,0 +1,637 @@
#!/usr/bin/env python3
import atexit
import os
import random
import signal
import socket
import subprocess
import sys
import time
import uuid
from collections.abc import Callable
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import NamedTuple

import requests
import yaml

from onyx.configs.constants import AuthType
from onyx.document_index.vespa.index import VespaIndex


BACKEND_DIR_PATH = Path(__file__).parent.parent.parent
COMPOSE_DIR_PATH = BACKEND_DIR_PATH.parent / "deployment/docker_compose"

DEFAULT_EMBEDDING_DIMENSION = 768
DEFAULT_SCHEMA_NAME = "danswer_chunk_nomic_ai_nomic_embed_text_v1"


class DeploymentConfig(NamedTuple):
    instance_num: int
    api_port: int
    web_port: int
    nginx_port: int
    redis_port: int
    postgres_db: str


class SharedServicesConfig(NamedTuple):
    run_id: uuid.UUID
    postgres_port: int
    vespa_port: int
    vespa_tenant_port: int
    model_server_port: int


def get_random_port() -> int:
    """Find a random available port."""
    while True:
        port = random.randint(10000, 65535)
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
            if sock.connect_ex(("localhost", port)) != 0:
                return port


def cleanup_pid(pid: int) -> None:
    """Cleanup a specific PID."""
    print(f"Killing process {pid}")
    try:
        os.kill(pid, signal.SIGTERM)
    except ProcessLookupError:
        print(f"Process {pid} not found")


def get_shared_services_stack_name(run_id: uuid.UUID) -> str:
    return f"base-onyx-{run_id}"
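
# Illustrative sketch (not from the diff): get_random_port checks with
# connect_ex and binds later, which is racy; another process can take the port
# in between. An alternative is to bind to port 0 and let the OS pick:
def get_os_assigned_port() -> int:
    # Bind to port 0; the kernel assigns an unused port immediately.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
        sock.bind(("localhost", 0))
        return sock.getsockname()[1]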
def get_db_name(instance_num: int) -> str:
    """Get the database name for a given instance number."""
    return f"onyx_{instance_num}"


def get_vector_db_prefix(instance_num: int) -> str:
    """Get the vector DB prefix for a given instance number."""
    return f"test_instance_{instance_num}"


def setup_db(
    instance_num: int,
    postgres_port: int,
) -> str:
    env = os.environ.copy()

    # Wait for postgres to be ready
    max_attempts = 10
    for attempt in range(max_attempts):
        try:
            subprocess.run(
                [
                    "psql",
                    "-h",
                    "localhost",
                    "-p",
                    str(postgres_port),
                    "-U",
                    "postgres",
                    "-c",
                    "SELECT 1",
                ],
                env={**env, "PGPASSWORD": "password"},
                check=True,
                stdout=subprocess.DEVNULL,
                stderr=subprocess.DEVNULL,
            )
            break
        except subprocess.CalledProcessError:
            if attempt == max_attempts - 1:
                raise RuntimeError("Postgres failed to become ready within timeout")
            time.sleep(1)

    db_name = get_db_name(instance_num)
    # Create the database first
    subprocess.run(
        [
            "psql",
            "-h",
            "localhost",
            "-p",
            str(postgres_port),
            "-U",
            "postgres",
            "-c",
            f"CREATE DATABASE {db_name}",
        ],
        env={**env, "PGPASSWORD": "password"},
        check=True,
    )

    # NEW: Stamp this brand-new DB at 'base' so Alembic doesn't fail
    subprocess.run(
        [
            "alembic",
            "stamp",
            "base",
        ],
        env={
            **env,
            "PGPASSWORD": "password",
            "POSTGRES_HOST": "localhost",
            "POSTGRES_PORT": str(postgres_port),
            "POSTGRES_DB": db_name,
        },
        check=True,
        cwd=str(BACKEND_DIR_PATH),
    )

    # Run alembic upgrade to create tables
    max_attempts = 3
    for attempt in range(max_attempts):
        try:
            subprocess.run(
                ["alembic", "upgrade", "head"],
                env={
                    **env,
                    "POSTGRES_HOST": "localhost",
                    "POSTGRES_PORT": str(postgres_port),
                    "POSTGRES_DB": db_name,
                },
                check=True,
                cwd=str(BACKEND_DIR_PATH),
            )
            break
        except subprocess.CalledProcessError:
            if attempt == max_attempts - 1:
                raise
            print("Alembic upgrade failed, retrying in 5 seconds...")
            time.sleep(5)

    return db_name
def start_api_server(
    instance_num: int,
    model_server_port: int,
    postgres_port: int,
    vespa_port: int,
    vespa_tenant_port: int,
    redis_port: int,
    register_process: Callable[[subprocess.Popen], None],
) -> int:
    """Start the API server.

    NOTE: assumes that Postgres is all set up (database exists, migrations ran)
    """
    print("Starting API server...")
    db_name = get_db_name(instance_num)
    vector_db_prefix = get_vector_db_prefix(instance_num)

    env = os.environ.copy()
    env.update(
        {
            "POSTGRES_HOST": "localhost",
            "POSTGRES_PORT": str(postgres_port),
            "POSTGRES_DB": db_name,
            "REDIS_HOST": "localhost",
            "REDIS_PORT": str(redis_port),
            "VESPA_HOST": "localhost",
            "VESPA_PORT": str(vespa_port),
            "VESPA_TENANT_PORT": str(vespa_tenant_port),
            "MODEL_SERVER_PORT": str(model_server_port),
            "VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": vector_db_prefix,
            "LOG_LEVEL": "debug",
            "AUTH_TYPE": AuthType.BASIC,
        }
    )

    port = get_random_port()

    # Open log file for API server in /tmp
    log_file = open(f"/tmp/api_server_{instance_num}.txt", "w")

    process = subprocess.Popen(
        [
            "uvicorn",
            "onyx.main:app",
            "--host",
            "localhost",
            "--port",
            str(port),
        ],
        env=env,
        cwd=str(BACKEND_DIR_PATH),
        stdout=log_file,
        stderr=subprocess.STDOUT,
    )
    register_process(process)

    return port
def start_background(
    instance_num: int,
    postgres_port: int,
    vespa_port: int,
    vespa_tenant_port: int,
    redis_port: int,
    register_process: Callable[[subprocess.Popen], None],
) -> None:
    """Start the background process."""
    print("Starting background process...")
    env = os.environ.copy()
    env.update(
        {
            "POSTGRES_HOST": "localhost",
            "POSTGRES_PORT": str(postgres_port),
            "POSTGRES_DB": get_db_name(instance_num),
            "REDIS_HOST": "localhost",
            "REDIS_PORT": str(redis_port),
            "VESPA_HOST": "localhost",
            "VESPA_PORT": str(vespa_port),
            "VESPA_TENANT_PORT": str(vespa_tenant_port),
            "VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": get_vector_db_prefix(
                instance_num
            ),
            "LOG_LEVEL": "debug",
        }
    )

    str(Path(__file__).parent / "backend")

    # Open log file for background process in /tmp
    log_file = open(f"/tmp/background_{instance_num}.txt", "w")

    process = subprocess.Popen(
        ["supervisord", "-n", "-c", "./supervisord.conf"],
        env=env,
        cwd=str(BACKEND_DIR_PATH),
        stdout=log_file,
        stderr=subprocess.STDOUT,
    )
    register_process(process)
def start_shared_services(run_id: uuid.UUID) -> SharedServicesConfig:
    """Start Postgres and Vespa using docker-compose.
    Returns (postgres_port, vespa_port, vespa_tenant_port, model_server_port)
    """
    print("Starting database services...")

    postgres_port = get_random_port()
    vespa_port = get_random_port()
    vespa_tenant_port = get_random_port()
    model_server_port = get_random_port()

    minimal_compose = {
        "services": {
            "relational_db": {
                "image": "postgres:15.2-alpine",
                "command": "-c 'max_connections=1000'",
                "environment": {
                    "POSTGRES_USER": os.getenv("POSTGRES_USER", "postgres"),
                    "POSTGRES_PASSWORD": os.getenv("POSTGRES_PASSWORD", "password"),
                },
                "ports": [f"{postgres_port}:5432"],
            },
            "index": {
                "image": "vespaengine/vespa:8.277.17",
                "ports": [
                    f"{vespa_port}:8081",  # Main Vespa port
                    f"{vespa_tenant_port}:19071",  # Tenant port
                ],
            },
        },
    }

    # Write the minimal compose file
    temp_compose = Path("/tmp/docker-compose.minimal.yml")
    with open(temp_compose, "w") as f:
        yaml.dump(minimal_compose, f)

    # Start the services
    subprocess.run(
        [
            "docker",
            "compose",
            "-f",
            str(temp_compose),
            "-p",
            get_shared_services_stack_name(run_id),
            "up",
            "-d",
        ],
        check=True,
    )

    # Start the shared model server
    env = os.environ.copy()
    env.update(
        {
            "POSTGRES_HOST": "localhost",
            "POSTGRES_PORT": str(postgres_port),
            "VESPA_HOST": "localhost",
            "VESPA_PORT": str(vespa_port),
            "VESPA_TENANT_PORT": str(vespa_tenant_port),
            "LOG_LEVEL": "debug",
        }
    )

    # Open log file for shared model server in /tmp
    log_file = open("/tmp/shared_model_server.txt", "w")

    process = subprocess.Popen(
        [
            "uvicorn",
            "model_server.main:app",
            "--host",
            "0.0.0.0",
            "--port",
            str(model_server_port),
        ],
        env=env,
        cwd=str(BACKEND_DIR_PATH),
        stdout=log_file,
        stderr=subprocess.STDOUT,
    )
    atexit.register(cleanup_pid, process.pid)

    shared_services_config = SharedServicesConfig(
        run_id, postgres_port, vespa_port, vespa_tenant_port, model_server_port
    )
    print(f"Shared services config: {shared_services_config}")
    return shared_services_config
def prepare_vespa(instance_ids: list[int], vespa_tenant_port: int) -> None:
    schema_names = [
        (
            f"{get_vector_db_prefix(instance_id)}_{DEFAULT_SCHEMA_NAME}",
            DEFAULT_EMBEDDING_DIMENSION,
            False,
        )
        for instance_id in instance_ids
    ]
    print(f"Creating indices: {schema_names}")
    for _ in range(7):
        try:
            VespaIndex.create_indices(
                schema_names, f"http://localhost:{vespa_tenant_port}/application/v2"
            )
            return
        except Exception as e:
            print(f"Error creating indices: {e}. Trying again in 5 seconds...")
            time.sleep(5)

    raise RuntimeError("Failed to create indices in Vespa")
def start_redis(
    instance_num: int,
    register_process: Callable[[subprocess.Popen], None],
) -> int:
    """Start a Redis instance for a specific deployment."""
    print(f"Starting Redis for instance {instance_num}...")

    redis_port = get_random_port()
    container_name = f"redis-onyx-{instance_num}"

    # Start Redis using docker run
    subprocess.run(
        [
            "docker",
            "run",
            "-d",
            "--name",
            container_name,
            "-p",
            f"{redis_port}:6379",
            "redis:7.4-alpine",
            "redis-server",
            "--save",
            '""',
            "--appendonly",
            "no",
        ],
        check=True,
    )

    return redis_port
def launch_instance(
    instance_num: int,
    postgres_port: int,
    vespa_port: int,
    vespa_tenant_port: int,
    model_server_port: int,
    register_process: Callable[[subprocess.Popen], None],
) -> DeploymentConfig:
    """Launch a Docker Compose instance with custom ports."""
    api_port = get_random_port()
    web_port = get_random_port()
    nginx_port = get_random_port()

    # Start Redis for this instance
    redis_port = start_redis(instance_num, register_process)

    try:
        db_name = setup_db(instance_num, postgres_port)
        api_port = start_api_server(
            instance_num,
            model_server_port,  # Use shared model server port
            postgres_port,
            vespa_port,
            vespa_tenant_port,
            redis_port,
            register_process,
        )
        start_background(
            instance_num,
            postgres_port,
            vespa_port,
            vespa_tenant_port,
            redis_port,
            register_process,
        )
    except Exception as e:
        print(f"Failed to start API server for instance {instance_num}: {e}")
        raise

    return DeploymentConfig(
        instance_num, api_port, web_port, nginx_port, redis_port, db_name
    )
def wait_for_instance(
    ports: DeploymentConfig, max_attempts: int = 120, wait_seconds: int = 2
) -> None:
    """Wait for an instance to be healthy."""
    print(f"Waiting for instance {ports.instance_num} to be ready...")

    for attempt in range(1, max_attempts + 1):
        try:
            response = requests.get(f"http://localhost:{ports.api_port}/health")
            if response.status_code == 200:
                print(
                    f"Instance {ports.instance_num} is ready on port {ports.api_port}"
                )
                return
            raise ConnectionError(
                f"Health check returned status {response.status_code}"
            )
        except (requests.RequestException, ConnectionError):
            if attempt == max_attempts:
                raise TimeoutError(
                    f"Timeout waiting for instance {ports.instance_num} "
                    f"on port {ports.api_port}"
                )
            print(
                f"Waiting for instance {ports.instance_num} on port "
                f"{ports.api_port}... ({attempt}/{max_attempts})"
            )
            time.sleep(wait_seconds)

def cleanup_instance(instance_num: int) -> None:
    """Cleanup a specific instance."""
    print(f"Cleaning up instance {instance_num}...")
    temp_compose = Path(f"/tmp/docker-compose.dev.instance{instance_num}.yml")

    try:
        subprocess.run(
            [
                "docker",
                "compose",
                "-f",
                str(temp_compose),
                "-p",
                f"onyx-stack-{instance_num}",
                "down",
            ],
            check=True,
        )
        print(f"Instance {instance_num} cleaned up successfully")
    except subprocess.CalledProcessError:
        print(f"Error cleaning up instance {instance_num}")
    except FileNotFoundError:
        print(f"No compose file found for instance {instance_num}")
    finally:
        # Clean up the temporary compose file if it exists
        if temp_compose.exists():
            temp_compose.unlink()
            print(f"Removed temporary compose file for instance {instance_num}")

def run_x_instances(
    num_instances: int,
) -> tuple[SharedServicesConfig, list[DeploymentConfig]]:
    """Start x instances of the application and return their configurations."""
    run_id = uuid.uuid4()
    instance_ids = list(range(1, num_instances + 1))
    _pids: list[int] = []

    def register_process(process: subprocess.Popen) -> None:
        _pids.append(process.pid)

    def cleanup_all_instances() -> None:
        """Cleanup all instances."""
        print("Cleaning up all instances...")

        # Stop the database services
        subprocess.run(
            [
                "docker",
                "compose",
                "-p",
                get_shared_services_stack_name(run_id),
                "-f",
                "/tmp/docker-compose.minimal.yml",
                "down",
            ],
            check=True,
        )

        # Stop and remove all Redis containers
        for instance_id in range(1, num_instances + 1):
            container_name = f"redis-onyx-{instance_id}"
            try:
                subprocess.run(["docker", "rm", "-f", container_name], check=True)
            except subprocess.CalledProcessError:
                print(f"Error cleaning up Redis container {container_name}")

        for pid in _pids:
            cleanup_pid(pid)

    # Register cleanup handler
    atexit.register(cleanup_all_instances)

    # Start database services first
    print("Starting shared services...")
    shared_services_config = start_shared_services(run_id)

    # Create the Vespa indices before any instance needs them
    print("Creating indices in Vespa...")
    prepare_vespa(instance_ids, shared_services_config.vespa_tenant_port)

    # Use ThreadPool to launch instances in parallel and collect results
    # NOTE: pool size currently matches the instance count; cap it (e.g. at 10)
    # if launching everything at once overwhelms the system
    print("Launching instances...")
    with ThreadPool(processes=len(instance_ids)) as pool:
        # Create list of arguments for each instance
        launch_args = [
            (
                i,
                shared_services_config.postgres_port,
                shared_services_config.vespa_port,
                shared_services_config.vespa_tenant_port,
                shared_services_config.model_server_port,
                register_process,
            )
            for i in instance_ids
        ]

        # Launch instances and get results
        port_configs = pool.starmap(launch_instance, launch_args)

    # Wait for all instances to be healthy
    print("Waiting for instances to be healthy...")
    with ThreadPool(processes=len(port_configs)) as pool:
        pool.map(wait_for_instance, port_configs)

    print("All instances launched!")
    print("Database Services:")
    print(f"Postgres port: {shared_services_config.postgres_port}")
    print(f"Vespa main port: {shared_services_config.vespa_port}")
    print(f"Vespa tenant port: {shared_services_config.vespa_tenant_port}")
    print("\nApplication Instances:")
    for ports in port_configs:
        print(
            f"Instance {ports.instance_num}: "
            f"API={ports.api_port}, Web={ports.web_port}, Nginx={ports.nginx_port}"
        )

    return shared_services_config, port_configs

def main() -> None:
    shared_services_config, port_configs = run_x_instances(1)

    # Run pytest with the API server port set
    api_port = port_configs[0].api_port  # Use first instance's API port
    try:
        subprocess.run(
            ["pytest", "tests/integration/openai_assistants_api"],
            env={**os.environ, "API_SERVER_PORT": str(api_port)},
            cwd=str(BACKEND_DIR_PATH),
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"Tests failed with exit code {e.returncode}")
        sys.exit(e.returncode)

    time.sleep(5)


if __name__ == "__main__":
    main()
266
backend/tests/integration/run.py
Normal file
@@ -0,0 +1,266 @@
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from pathlib import Path

from tests.integration.common_utils.reset import reset_all
from tests.integration.introspection import list_all_tests
from tests.integration.introspection import load_env_vars
from tests.integration.kickoff import BACKEND_DIR_PATH
from tests.integration.kickoff import DeploymentConfig
from tests.integration.kickoff import run_x_instances
from tests.integration.kickoff import SharedServicesConfig

@dataclass
class TestResult:
    test_name: str
    success: bool
    output: str
    error: str | None = None

def run_single_test(
    test_name: str,
    deployment_config: DeploymentConfig,
    shared_services_config: SharedServicesConfig,
    result_queue: multiprocessing.Queue,
) -> None:
    """Run a single test with the given API port."""
    test_path, test_name = test_name.split("::")
    processed_test_name = f"{test_path.replace('.', '/')}.py::{test_name}"
    print(f"Running test: {processed_test_name}", flush=True)
    try:
        env = {
            **os.environ,
            "API_SERVER_PORT": str(deployment_config.api_port),
            "PYTHONPATH": ".",
            "GUARANTEED_FRESH_SETUP": "true",
            "POSTGRES_PORT": str(shared_services_config.postgres_port),
            "POSTGRES_DB": deployment_config.postgres_db,
            "REDIS_PORT": str(deployment_config.redis_port),
            "VESPA_PORT": str(shared_services_config.vespa_port),
            "VESPA_TENANT_PORT": str(shared_services_config.vespa_tenant_port),
        }
        result = subprocess.run(
            ["pytest", processed_test_name, "-v"],
            env=env,
            cwd=str(BACKEND_DIR_PATH),
            capture_output=True,
            text=True,
        )
        result_queue.put(
            TestResult(
                test_name=test_name,
                success=result.returncode == 0,
                output=result.stdout,
                error=result.stderr if result.returncode != 0 else None,
            )
        )
    except Exception as e:
        result_queue.put(
            TestResult(
                test_name=test_name,
                success=False,
                output="",
                error=str(e),
            )
        )

def worker(
    test_queue: queue.Queue[str],
    instance_queue: queue.Queue[int],
    result_queue: multiprocessing.Queue,
    shared_services_config: SharedServicesConfig,
    deployment_configs: list[DeploymentConfig],
    reset_lock: LockType,
) -> None:
    """Worker thread that runs tests on available instances."""
    while True:
        # Get the next test from the queue
        try:
            test = test_queue.get(block=False)
        except queue.Empty:
            break

        # Get an available instance
        instance_idx = instance_queue.get()
        deployment_config = deployment_configs[
            instance_idx - 1
        ]  # Convert to 0-based index

        try:
            # Run the test
            run_single_test(
                test, deployment_config, shared_services_config, result_queue
            )
            # Get the instance ready for the next test
            print(
                f"Resetting instance for the next test. DB: {deployment_config.postgres_db}, "
                f"Port: {shared_services_config.postgres_port}"
            )
            # alembic is NOT thread-safe, so make sure only one worker resets at a time
            with reset_lock:
                reset_all(
                    database=deployment_config.postgres_db,
                    postgres_port=str(shared_services_config.postgres_port),
                    redis_port=deployment_config.redis_port,
                    silence_logs=True,
                    # indices are created during the kickoff process, no need to recreate them
                    skip_creating_indices=True,
                    # use the special vespa port
                    document_id_endpoint=(
                        f"http://localhost:{shared_services_config.vespa_port}"
                        "/document/v1/default/{{index_name}}/docid"
                    ),
                )
        except Exception as e:
            # Log the error and put it in the result queue
            error_msg = f"Critical error in worker thread for test {test}: {str(e)}"
            print(error_msg, file=sys.stderr)
            result_queue.put(
                TestResult(
                    test_name=test,
                    success=False,
                    output="",
                    error=error_msg,
                )
            )
            # Re-raise to stop the worker
            raise
        finally:
            # Put the instance back in the queue
            instance_queue.put(instance_idx)
            test_queue.task_done()
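

# Aside: the get / try / finally-put shape above is the classic resource-
# checkout pattern -- each worker borrows one deployment, runs one test, and
# always returns the deployment, even on failure. A stripped-down,
# self-contained illustration of the same idea (all names here are
# illustrative, not part of the test suite):
import queue as _queue
import threading as _threading


def _run_on_free_instance(free: "_queue.Queue[int]", job: str) -> None:
    instance = free.get()  # blocks until some instance is handed back
    try:
        print(f"{job} running on instance {instance}")
    finally:
        free.put(instance)  # always return the instance, even on failure


def _checkout_demo() -> None:
    free: "_queue.Queue[int]" = _queue.Queue()
    for i in (1, 2):
        free.put(i)
    threads = [
        _threading.Thread(target=_run_on_free_instance, args=(free, f"job-{n}"))
        for n in range(4)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()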


def main() -> None:
    NUM_INSTANCES = 7

    # Get all tests
    prefixes = ["tests", "connector_job_tests"]
    tests = []
    for prefix in prefixes:
        tests += [
            f"tests/integration/{prefix}/{test_path}"
            for test_path in list_all_tests(Path(__file__).parent / prefix)
        ]
    print(f"Found {len(tests)} tests to run")

    # Load env vars which will be passed into the tests
    load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))

    # For debugging
    # tests = [test for test in tests if "openai_assistants_api" in test]
    # tests = tests[:2]
    print(f"Running {len(tests)} tests")

    # Start all instances at once
    shared_services_config, deployment_configs = run_x_instances(NUM_INSTANCES)

    # Create queues and lock
    test_queue: queue.Queue[str] = queue.Queue()
    instance_queue: queue.Queue[int] = queue.Queue()
    result_queue: multiprocessing.Queue = multiprocessing.Queue()
    reset_lock: LockType = multiprocessing.Lock()

    # Fill the instance queue with available instance numbers
    for i in range(1, NUM_INSTANCES + 1):
        instance_queue.put(i)

    # Fill the test queue with all tests
    for test in tests:
        test_queue.put(test)

    # Start worker threads
    workers = []
    for _ in range(NUM_INSTANCES):
        worker_thread = threading.Thread(
            target=worker,
            args=(
                test_queue,
                instance_queue,
                result_queue,
                shared_services_config,
                deployment_configs,
                reset_lock,
            ),
        )
        worker_thread.start()
        workers.append(worker_thread)

    # Monitor workers and fail fast if any die
    try:
        while any(w.is_alive() for w in workers):
            # Check if all tests are done
            if test_queue.empty() and all(not w.is_alive() for w in workers):
                break

            # Check for dead workers that died with unfinished tests
            if not test_queue.empty() and any(not w.is_alive() for w in workers):
                print(
                    "\nCritical: Worker thread(s) died with tests remaining!",
                    file=sys.stderr,
                )
                sys.exit(1)

            time.sleep(0.1)  # Avoid busy waiting

        # Collect results
        print("Collecting results")
        results: list[TestResult] = []
        while not result_queue.empty():
            results.append(result_queue.get())

        # Print results
        print("\nTest Results:")
        failed = False
        failed_tests: list[str] = []
        total_tests = len(results)
        passed_tests = 0

        for result in results:
            status = "✅ PASSED" if result.success else "❌ FAILED"
            print(f"{status} - {result.test_name}")
            if result.success:
                passed_tests += 1
            else:
                failed = True
                failed_tests.append(result.test_name)
                print("Error output:")
                print(result.error)
                print("Test output:")
                print(result.output)
            print("-" * 80)

        # Print summary
        print("\nTest Summary:")
        print(f"Total Tests: {total_tests}")
        print(f"Passed: {passed_tests}")
        print(f"Failed: {len(failed_tests)}")

        if failed_tests:
            print("\nFailed Tests:")
            for test_name in failed_tests:
                print(f"❌ {test_name}")
            print()

        if failed:
            sys.exit(1)

    except KeyboardInterrupt:
        print("\nTest run interrupted by user", file=sys.stderr)
        sys.exit(130)  # Standard exit code for SIGINT
    except Exception as e:
        print(f"\nCritical error during result collection: {str(e)}", file=sys.stderr)
        sys.exit(1)


if __name__ == "__main__":
    main()
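One helper this runner leans on is not part of the diff: list_all_tests from
tests.integration.introspection. Judging from how run_single_test splits each
entry on "::" and converts dots to slashes, it yields strings shaped like
"dotted.module.path::test_name". A hypothetical sketch of such a collector,
under that assumed contract only:

import ast
from pathlib import Path

def list_all_tests(root: Path) -> list[str]:
    # Emit "dotted.module.path::test_name" for every top-level test function
    # found under `root`.
    found: list[str] = []
    for test_file in sorted(root.rglob("test_*.py")):
        module = ".".join(test_file.relative_to(root).with_suffix("").parts)
        tree = ast.parse(test_file.read_text())
        for node in tree.body:
            if isinstance(node, ast.FunctionDef) and node.name.startswith("test_"):
                found.append(f"{module}::{node.name}")
    return found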
@@ -11,7 +11,6 @@ from langchain_core.messages import HumanMessage
from langchain_core.messages import SystemMessage
from langchain_core.messages import ToolCall
from langchain_core.messages import ToolCallChunk
from pytest_mock import MockerFixture
from sqlalchemy.orm import Session

from onyx.chat.answer import Answer
@@ -26,7 +25,6 @@ from onyx.chat.models import StreamStopReason
from onyx.chat.prompt_builder.answer_prompt_builder import AnswerPromptBuilder
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_system_message
from onyx.chat.prompt_builder.answer_prompt_builder import default_build_user_message
from onyx.context.search.models import RerankingDetails
from onyx.context.search.models import SearchRequest
from onyx.llm.interfaces import LLM
from onyx.tools.force import ForceUseTool
@@ -37,7 +35,6 @@ from onyx.tools.tool_implementations.search.search_tool import SEARCH_DOC_CONTEN
from onyx.tools.tool_implementations.search_like_tool_utils import (
    FINAL_CONTEXT_DOCUMENTS_ID,
)
from shared_configs.enums import RerankerProvider
from tests.unit.onyx.chat.conftest import DEFAULT_SEARCH_ARGS
from tests.unit.onyx.chat.conftest import QUERY

@@ -47,20 +44,6 @@ def answer_instance(
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
) -> Answer:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        return_value=True,
    )
    return _answer_fixture_impl(mock_llm, answer_style_config, prompt_config)


def _answer_fixture_impl(
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    rerank_settings: RerankingDetails | None = None,
) -> Answer:
    return Answer(
        prompt_builder=AnswerPromptBuilder(
@@ -81,13 +64,13 @@ def _answer_fixture_impl(
        llm=mock_llm,
        fast_llm=mock_llm,
        force_use_tool=ForceUseTool(force_use=False, tool_name="", args=None),
        search_request=SearchRequest(query=QUERY, rerank_settings=rerank_settings),
        search_request=SearchRequest(query=QUERY),
        chat_session_id=UUID("123e4567-e89b-12d3-a456-426614174000"),
        current_agent_message_id=0,
    )


def test_basic_answer(answer_instance: Answer, mocker: MockerFixture) -> None:
def test_basic_answer(answer_instance: Answer) -> None:
    mock_llm = cast(Mock, answer_instance.graph_config.tooling.primary_llm)
    mock_llm.stream.return_value = [
        AIMessageChunk(content="This is a "),
@@ -380,49 +363,3 @@ def test_is_cancelled(answer_instance: Answer) -> None:

    # Verify LLM calls
    mock_llm.stream.assert_called_once()


@pytest.mark.parametrize(
    "gpu_enabled,is_local_model",
    [
        (True, False),
        (False, True),
        (True, True),
        (False, False),
    ],
)
def test_no_slow_reranking(
    gpu_enabled: bool,
    is_local_model: bool,
    mock_llm: LLM,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
) -> None:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        return_value=gpu_enabled,
    )
    rerank_settings = (
        None
        if is_local_model
        else RerankingDetails(
            rerank_model_name="test_model",
            rerank_api_url="test_url",
            rerank_api_key="test_key",
            num_rerank=10,
            rerank_provider_type=RerankerProvider.COHERE,
        )
    )
    answer_instance = _answer_fixture_impl(
        mock_llm, answer_style_config, prompt_config, rerank_settings=rerank_settings
    )

    assert (
        answer_instance.graph_config.inputs.search_request.rerank_settings
        == rerank_settings
    )
    assert (
        answer_instance.graph_config.behavior.allow_agent_reranking == gpu_enabled
        or not is_local_model
    )

@@ -36,12 +36,7 @@ def test_skip_gen_ai_answer_generation_flag(
    mock_search_tool: SearchTool,
    answer_style_config: AnswerStyleConfig,
    prompt_config: PromptConfig,
    mocker: MockerFixture,
) -> None:
    mocker.patch(
        "onyx.chat.answer.gpu_status_request",
        return_value=True,
    )
    question = config["question"]
    skip_gen_ai_answer_generation = config["skip_gen_ai_answer_generation"]

@@ -1,61 +0,0 @@
import time

import pytest

from onyx.utils.threadpool_concurrency import run_with_timeout


def test_run_with_timeout_completes() -> None:
    """Test that a function that completes within timeout works correctly"""

    def quick_function(x: int) -> int:
        return x * 2

    result = run_with_timeout(1.0, quick_function, x=21)
    assert result == 42


@pytest.mark.parametrize("slow,timeout", [(1, 0.1), (0.3, 0.2)])
def test_run_with_timeout_raises_on_timeout(slow: float, timeout: float) -> None:
    """Test that a function that exceeds timeout raises TimeoutError"""

    def slow_function() -> None:
        time.sleep(slow)  # Sleep for `slow` seconds

    with pytest.raises(TimeoutError) as exc_info:
        start = time.time()
        run_with_timeout(timeout, slow_function)  # Times out after `timeout` seconds
    end = time.time()
    assert end - start >= timeout
    assert end - start < (slow + timeout) / 2
    assert f"timed out after {timeout} seconds" in str(exc_info.value)


@pytest.mark.filterwarnings("ignore::pytest.PytestUnhandledThreadExceptionWarning")
def test_run_with_timeout_propagates_exceptions() -> None:
    """Test that other exceptions from the function are propagated properly"""

    def error_function() -> None:
        raise ValueError("Test error")

    with pytest.raises(ValueError) as exc_info:
        run_with_timeout(1.0, error_function)

    assert "Test error" in str(exc_info.value)


def test_run_with_timeout_with_args_and_kwargs() -> None:
    """Test that args and kwargs are properly passed to the function"""

    def complex_function(x: int, y: int, multiply: bool = False) -> int:
        if multiply:
            return x * y
        return x + y

    # Test with keyword args
    result1 = run_with_timeout(1.0, complex_function, x=5, y=3)
    assert result1 == 8

    # Test with the multiply flag set
    result2 = run_with_timeout(1.0, complex_function, x=5, y=3, multiply=True)
    assert result2 == 15
@@ -1,11 +1,6 @@
"use client";

import {
  redirect,
  usePathname,
  useRouter,
  useSearchParams,
} from "next/navigation";
import { redirect, useRouter, useSearchParams } from "next/navigation";
import {
  BackendChatSession,
  BackendMessage,
@@ -135,7 +130,6 @@ import {
} from "@/lib/browserUtilities";
import { Button } from "@/components/ui/button";
import { ConfirmEntityModal } from "@/components/modals/ConfirmEntityModal";
import { MessageChannel } from "node:worker_threads";

const TEMP_USER_MESSAGE_ID = -1;
const TEMP_ASSISTANT_MESSAGE_ID = -2;
@@ -1151,7 +1145,6 @@ export function ChatPage({
    regenerationRequest?: RegenerationRequest | null;
    overrideFileDescriptors?: FileDescriptor[];
  } = {}) => {
    navigatingAway.current = false;
    let frozenSessionId = currentSessionId();
    updateCanContinue(false, frozenSessionId);

@@ -1274,6 +1267,7 @@
    let stackTrace: string | null = null;

    let sub_questions: SubQuestionDetail[] = [];
    let second_level_sub_questions: SubQuestionDetail[] = [];
    let is_generating: boolean = false;
    let second_level_generating: boolean = false;
    let finalMessage: BackendMessage | null = null;
@@ -1297,7 +1291,7 @@
      const stack = new CurrentMessageFIFO();
      updateCurrentMessageFIFO(stack, {
        signal: controller.signal,
        signal: controller.signal, // Add this line
        message: currMessage,
        alternateAssistantId: currentAssistantId,
        fileDescriptors: overrideFileDescriptors || currentMessageFiles,
@@ -1718,10 +1712,7 @@
        const newUrl = buildChatUrl(searchParams, currChatSessionId, null);
        // newUrl is like /chat?chatId=10
        // current page is like /chat

        if (pathname == "/chat" && !navigatingAway.current) {
          router.push(newUrl, { scroll: false });
        }
        router.push(newUrl, { scroll: false });
      }
    }
    if (
@@ -2095,31 +2086,6 @@
    llmOverrideManager.updateImageFilesPresent(imageFileInMessageHistory);
  }, [imageFileInMessageHistory]);

  const pathname = usePathname();
  useEffect(() => {
    return () => {
      // Cleanup which only runs when the component unmounts (i.e. when you navigate away).
      const currentSession = currentSessionId();
      const controller = abortControllersRef.current.get(currentSession);
      if (controller) {
        controller.abort();
        navigatingAway.current = true;
        setAbortControllers((prev) => {
          const newControllers = new Map(prev);
          newControllers.delete(currentSession);
          return newControllers;
        });
      }
    };
  }, [pathname]);

  const navigatingAway = useRef(false);
  // Keep a ref to abortControllers to ensure we always have the latest value
  const abortControllersRef = useRef(abortControllers);
  useEffect(() => {
    abortControllersRef.current = abortControllers;
  }, [abortControllers]);

  useSidebarShortcut(router, toggleSidebar);

  const [sharedChatSession, setSharedChatSession] =
@@ -2334,7 +2300,7 @@
          fixed
          left-0
          z-40
          bg-neutral-200
          bg-background-100
          h-screen
          transition-all
          bg-opacity-80
@@ -2591,21 +2557,12 @@
                  ) {
                    return <></>;
                  }
                  const nextMessage =
                    messageHistory.length > i + 1
                      ? messageHistory[i + 1]
                      : null;
                  return (
                    <div
                      id={`message-${message.messageId}`}
                      key={messageReactComponentKey}
                    >
                      <HumanMessage
                        disableSwitchingForStreaming={
                          (nextMessage &&
                            nextMessage.is_generating) ||
                          false
                        }
                        stopGenerating={stopGenerating}
                        content={message.message}
                        files={message.files}
@@ -94,7 +94,7 @@ export function AgenticToggle({
            Agent Search (BETA)
          </h3>
        </div>
        <p className="text-xs text-neutral-600 dark:text-neutral-700 mb-2">
        <p className="text-xs text-neutarl-600 dark:text-neutral-700 mb-2">
          Use AI agents to break down questions and run deep iterative
          research through promising pathways. Gives more thorough and
          accurate responses but takes slightly longer.
@@ -113,7 +113,7 @@ export default function LLMPopover({
    <Popover open={isOpen} onOpenChange={setIsOpen}>
      <PopoverTrigger asChild>
        <button
          className="dark:text-[#fff] text-[#000] focus:outline-none"
          className="focus:outline-none"
          data-testid="llm-popover-trigger"
        >
          <ChatInputOption
@@ -250,7 +250,7 @@ export async function* sendMessage({
    throw new Error(`HTTP error! status: ${response.status}`);
  }

  yield* handleSSEStream<PacketType>(response, signal);
  yield* handleSSEStream<PacketType>(response);
}

export async function nameChatSession(chatSessionId: string) {
@@ -9,12 +9,6 @@ import React, {
  useMemo,
  useState,
} from "react";
import {
  Tooltip,
  TooltipContent,
  TooltipProvider,
  TooltipTrigger,
} from "@/components/ui/tooltip";
import ReactMarkdown from "react-markdown";
import { OnyxDocument, FilteredOnyxDocument } from "@/lib/search/interfaces";
import remarkGfm from "remark-gfm";
@@ -314,7 +308,7 @@ export const AgenticMessage = ({
  const renderedAlternativeMarkdown = useMemo(() => {
    return (
      <ReactMarkdown
        className="prose dark:prose-invert max-w-full text-base"
        className="prose max-w-full text-base"
        components={{
          ...markdownComponents,
          code: ({ node, className, children }: any) => {
@@ -341,7 +335,7 @@ export const AgenticMessage = ({
  const renderedMarkdown = useMemo(() => {
    return (
      <ReactMarkdown
        className="prose dark:prose-invert max-w-full text-base"
        className="prose max-w-full text-base"
        components={markdownComponents}
        remarkPlugins={[remarkGfm, remarkMath]}
        rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
@@ -536,7 +530,6 @@ export const AgenticMessage = ({
                  {includeMessageSwitcher && (
                    <div className="-mx-1 mr-auto">
                      <MessageSwitcher
                        disableForStreaming={!isComplete}
                        currentPage={currentMessageInd + 1}
                        totalPages={otherMessagesCanSwitchTo.length}
                        handlePrevious={() => {
@@ -623,7 +616,6 @@ export const AgenticMessage = ({
                  {includeMessageSwitcher && (
                    <div className="-mx-1 mr-auto">
                      <MessageSwitcher
                        disableForStreaming={!isComplete}
                        currentPage={currentMessageInd + 1}
                        totalPages={otherMessagesCanSwitchTo.length}
                        handlePrevious={() => {
@@ -702,52 +694,27 @@ function MessageSwitcher({
  totalPages,
  handlePrevious,
  handleNext,
  disableForStreaming,
}: {
  currentPage: number;
  totalPages: number;
  handlePrevious: () => void;
  handleNext: () => void;
  disableForStreaming?: boolean;
}) {
  return (
    <div className="flex items-center text-sm space-x-0.5">
      <TooltipProvider>
        <Tooltip>
          <TooltipTrigger asChild>
            <div>
              <Hoverable
                icon={FiChevronLeft}
                onClick={currentPage === 1 ? undefined : handlePrevious}
              />
            </div>
          </TooltipTrigger>
          <TooltipContent>
            {disableForStreaming ? "Disabled" : "Previous"}
          </TooltipContent>
        </Tooltip>
      </TooltipProvider>
      <Hoverable
        icon={FiChevronLeft}
        onClick={currentPage === 1 ? undefined : handlePrevious}
      />

      <span className="text-text-darker select-none">
        {currentPage} / {totalPages}
        {disableForStreaming ? "Complete" : "Generating"}
      </span>

      <TooltipProvider>
        <Tooltip>
          <TooltipTrigger asChild>
            <div>
              <Hoverable
                icon={FiChevronRight}
                onClick={currentPage === totalPages ? undefined : handleNext}
              />
            </div>
          </TooltipTrigger>
          <TooltipContent>
            {disableForStreaming ? "Disabled" : "Next"}
          </TooltipContent>
        </Tooltip>
      </TooltipProvider>
      <Hoverable
        icon={FiChevronRight}
        onClick={currentPage === totalPages ? undefined : handleNext}
      />
    </div>
  );
}

@@ -383,7 +383,7 @@ export const AIMessage = ({
            dangerouslySetInnerHTML={{ __html: htmlContent }}
          />
          <ReactMarkdown
            className="prose dark:prose-invert max-w-full text-base"
            className="prose max-w-full text-base"
            components={markdownComponents}
            remarkPlugins={[remarkGfm, remarkMath]}
            rehypePlugins={[[rehypePrism, { ignoreMissing: true }], rehypeKatex]}
@@ -495,10 +495,7 @@ export const AIMessage = ({
              {docs && docs.length > 0 && (
                <div
                  className={`mobile:hidden ${
                    (query ||
                      toolCall?.tool_name ===
                        INTERNET_SEARCH_TOOL_NAME) &&
                    "mt-2"
                    query && "mt-2"
                  } -mx-8 w-full mb-4 flex relative`}
                >
                  <div className="w-full">
@@ -798,67 +795,27 @@ function MessageSwitcher({
  totalPages,
  handlePrevious,
  handleNext,
  disableForStreaming,
}: {
  currentPage: number;
  totalPages: number;
  handlePrevious: () => void;
  handleNext: () => void;
  disableForStreaming?: boolean;
}) {
  return (
    <div className="flex items-center text-sm space-x-0.5">
      <TooltipProvider>
        <Tooltip>
          <TooltipTrigger asChild>
            <div>
              <Hoverable
                icon={FiChevronLeft}
                onClick={
                  disableForStreaming
                    ? () => null
                    : currentPage === 1
                      ? undefined
                      : handlePrevious
                }
              />
            </div>
          </TooltipTrigger>
          <TooltipContent>
            {disableForStreaming
              ? "Wait for agent message to complete"
              : "Previous"}
          </TooltipContent>
        </Tooltip>
      </TooltipProvider>
      <Hoverable
        icon={FiChevronLeft}
        onClick={currentPage === 1 ? undefined : handlePrevious}
      />

      <span className="text-text-darker select-none">
        {currentPage} / {totalPages}
      </span>

      <TooltipProvider>
        <Tooltip>
          <TooltipTrigger>
            <div>
              <Hoverable
                icon={FiChevronRight}
                onClick={
                  disableForStreaming
                    ? () => null
                    : currentPage === totalPages
                      ? undefined
                      : handleNext
                }
              />
            </div>
          </TooltipTrigger>
          <TooltipContent>
            {disableForStreaming
              ? "Wait for agent message to complete"
              : "Next"}
          </TooltipContent>
        </Tooltip>
      </TooltipProvider>
      <Hoverable
        icon={FiChevronRight}
        onClick={currentPage === totalPages ? undefined : handleNext}
      />
    </div>
  );
}
@@ -872,7 +829,6 @@ export const HumanMessage = ({
  onMessageSelection,
  shared,
  stopGenerating = () => null,
  disableSwitchingForStreaming = false,
}: {
  shared?: boolean;
  content: string;
@@ -882,7 +838,6 @@ export const HumanMessage = ({
  onEdit?: (editedContent: string) => void;
  onMessageSelection?: (messageId: number) => void;
  stopGenerating?: () => void;
  disableSwitchingForStreaming?: boolean;
}) => {
  const textareaRef = useRef<HTMLTextAreaElement>(null);

@@ -1112,7 +1067,6 @@ export const HumanMessage = ({
            otherMessagesCanSwitchTo.length > 1 && (
              <div className="ml-auto mr-3">
                <MessageSwitcher
                  disableForStreaming={disableSwitchingForStreaming}
                  currentPage={currentMessageInd + 1}
                  totalPages={otherMessagesCanSwitchTo.length}
                  handlePrevious={() => {
@@ -294,7 +294,7 @@ const SubQuestionDisplay: React.FC<{
  const renderedMarkdown = useMemo(() => {
    return (
      <ReactMarkdown
        className="prose dark:prose-invert max-w-full text-base"
        className="prose max-w-full text-base"
        components={markdownComponents}
        remarkPlugins={[remarkGfm, remarkMath]}
        rehypePlugins={[rehypeKatex]}
@@ -340,7 +340,7 @@ const SubQuestionDisplay: React.FC<{
          {subQuestion?.question || temporaryDisplay?.question}
        </div>
        <ChevronDown
          className={`mt-0.5 flex-none text-text-darker transition-transform duration-500 ease-in-out ${
          className={`mt-0.5 text-text-darker transition-transform duration-500 ease-in-out ${
            toggled ? "" : "-rotate-90"
          }`}
          size={20}
@@ -632,7 +632,9 @@ const SubQuestionsDisplay: React.FC<SubQuestionsDisplayProps> = ({
        }
      `}</style>
      <div className="relative">
        {/* {subQuestions.map((subQuestion, index) => ( */}
        {memoizedSubQuestions.map((subQuestion, index) => (
          // {dynamicSubQuestions.map((subQuestion, index) => (
          <SubQuestionDisplay
            currentlyOpen={
              currentlyOpenQuestion?.level === subQuestion.level &&
@@ -131,7 +131,7 @@ const StandardAnswersTableRow = ({
    />,
    <ReactMarkdown
      key={`answer-${standardAnswer.id}`}
      className="prose dark:prose-invert"
      className="prose"
      remarkPlugins={[remarkGfm]}
    >
      {standardAnswer.answer}
@@ -562,7 +562,6 @@ body {
.prose :where(pre):not(:where([class~="not-prose"], [class~="not-prose"] *)) {
  background-color: theme("colors.code-bg");
  font-size: theme("fontSize.code-sm");
  color: #fff;
}

pre[class*="language-"],
@@ -656,3 +655,16 @@ ul > li > p {
  display: inline;
  /* Make paragraphs inline to reduce vertical space */
}

.dark strong {
  color: white;
}

.prose.dark li,
.prose.dark h1,
.prose.dark h2,
.prose.dark h3,
.prose.dark h4,
.prose.dark h5 {
  color: #e5e5e5;
}
@@ -17,7 +17,7 @@ export const Hoverable: React.FC<{
    <div className="flex items-center">
      <Icon
        size={size}
        className="dark:text-[#B4B4B4] text-neutral-600 rounded h-fit cursor-pointer"
        className="hover:bg-background-chat-hover dark:text-[#B4B4B4] text-neutral-600 rounded h-fit cursor-pointer"
      />
      {hoverText && (
        <div className="max-w-0 leading-none whitespace-nowrap overflow-hidden transition-all duration-300 ease-in-out group-hover:max-w-xs group-hover:ml-2">
@@ -50,7 +50,7 @@ export function SearchResultIcon({ url }: { url: string }) {
    return <SourceIcon sourceType={ValidSources.Web} iconSize={18} />;
  }
  if (url.includes("docs.onyx.app")) {
    return <OnyxIcon size={18} className="dark:text-[#fff] text-[#000]" />;
    return <OnyxIcon size={18} />;
  }

  return (
@@ -23,7 +23,7 @@ export function WebResultIcon({
  return (
    <>
      {hostname == "docs.onyx.app" ? (
        <OnyxIcon size={size} className="dark:text-[#fff] text-[#000]" />
        <OnyxIcon size={size} />
      ) : !error ? (
        <img
          className="my-0 rounded-full py-0"
@@ -432,10 +432,7 @@ export const MarkdownFormField = ({
        </div>
        {isPreviewOpen ? (
          <div className="p-4 border-t border-background-300">
            <ReactMarkdown
              className="prose dark:prose-invert"
              remarkPlugins={[remarkGfm]}
            >
            <ReactMarkdown className="prose" remarkPlugins={[remarkGfm]}>
              {field.value}
            </ReactMarkdown>
          </div>
@@ -9,7 +9,7 @@ export default function BlurBackground({
    <div
      onClick={onClick}
      className={`
        desktop:hidden w-full h-full fixed inset-0 bg-neutral-700 bg-opacity-50 backdrop-blur-sm z-30 transition-opacity duration-300 ease-in-out ${
        desktop:hidden w-full h-full fixed inset-0 bg-black bg-opacity-50 backdrop-blur-sm z-30 transition-opacity duration-300 ease-in-out ${
          visible
            ? "opacity-100 pointer-events-auto"
            : "opacity-0 pointer-events-none"
@@ -35,7 +35,7 @@ export const MinimalMarkdown: React.FC<MinimalMarkdownProps> = ({

  return (
    <ReactMarkdown
      className={`w-full text-wrap break-word prose dark:prose-invert ${className}`}
      className={`w-full text-wrap break-word ${className}`}
      components={markdownComponents}
      remarkPlugins={[remarkGfm]}
    >
@@ -78,7 +78,7 @@ export function getUniqueIcons(docs: OnyxDocument[]): JSX.Element[] {

  for (const doc of docs) {
    // If it's a web source, we check domain uniqueness
    if ((doc.is_internet || doc.source_type === ValidSources.Web) && doc.link) {
    if (doc.source_type === ValidSources.Web && doc.link) {
      const domain = getDomainFromUrl(doc.link);
      if (domain && !seenDomains.has(domain)) {
        seenDomains.add(domain);
@@ -47,7 +47,7 @@ export default function LogoWithText({
      className="flex gap-x-2 items-center ml-0 cursor-pointer desktop:hidden "
    >
      {!toggled ? (
        <Logo className="desktop:hidden" height={24} width={24} />
        <Logo className="desktop:hidden -my-2" height={24} width={24} />
      ) : (
        <LogoComponent
          show={toggled}
@@ -23,11 +23,8 @@ import { AllUsersResponse } from "./types";
|
||||
import { Credential } from "./connectors/credentials";
|
||||
import { SettingsContext } from "@/components/settings/SettingsProvider";
|
||||
import { Persona, PersonaLabel } from "@/app/admin/assistants/interfaces";
|
||||
import {
|
||||
isAnthropic,
|
||||
LLMProviderDescriptor,
|
||||
} from "@/app/admin/configuration/llm/interfaces";
|
||||
|
||||
import { LLMProviderDescriptor } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { isAnthropic } from "@/app/admin/configuration/llm/interfaces";
|
||||
import { getSourceMetadata } from "./sources";
|
||||
import { AuthType, NEXT_PUBLIC_CLOUD_ENABLED } from "./constants";
|
||||
import { useUser } from "@/components/user/UserProvider";
|
||||
|
||||
@@ -79,18 +79,12 @@ export async function* handleStream<T extends NonEmptyObject>(
}

export async function* handleSSEStream<T extends PacketType>(
  streamingResponse: Response,
  signal?: AbortSignal
  streamingResponse: Response
): AsyncGenerator<T, void, unknown> {
  const reader = streamingResponse.body?.getReader();
  const decoder = new TextDecoder();
  let buffer = "";
  if (signal) {
    signal.addEventListener("abort", () => {
      console.log("aborting");
      reader?.cancel();
    });
  }

  while (true) {
    const rawChunk = await reader?.read();
    if (!rawChunk) {
@@ -21,6 +21,7 @@ module.exports = {
      transitionProperty: {
        spacing: "margin, padding",
      },

      keyframes: {
        "subtle-pulse": {
          "0%, 100%": { opacity: 0.9 },
@@ -147,6 +148,7 @@ module.exports = {
        "text-mobile-sidebar": "var(--text-800)",
        "background-search-filter": "var(--neutral-100-border-light)",
        "background-search-filter-dropdown": "var(--neutral-100-border-light)",
        "tw-prose-bold": "var(--text-800)",

        "user-bubble": "var(--off-white)",