Compare commits

...

17 Commits

Author SHA1 Message Date
Weves 4d1d5bdcfe tweak worker count 2025-02-16 14:51:10 -08:00
Weves 0d468b49a1 Final fixes 2025-02-16 14:50:51 -08:00
Weves 67b87ced39 fix paths 2025-02-16 14:23:51 -08:00
Weves 8b4e4a6c80 Fix paths 2025-02-16 14:22:12 -08:00
Weves e26bcf5a05 Misc fixes 2025-02-16 14:02:16 -08:00
Weves 435959cf90 test 2025-02-16 14:02:16 -08:00
Weves fcbe305dc0 test 2025-02-16 14:02:16 -08:00
Weves 6f13d44564 move 2025-02-16 14:02:16 -08:00
Weves c1810a35cd Fix redis port 2025-02-16 14:02:16 -08:00
Weves 4003e7346a test 2025-02-16 14:02:16 -08:00
Weves 8057f1eb0d Add logging 2025-02-16 14:02:16 -08:00
Weves 7eebd3cff1 test not removing files 2025-02-16 14:02:16 -08:00
Weves bac2aeb8b7 Fix 2025-02-16 14:02:16 -08:00
Weves 9831697acc Make migrations work 2025-02-16 14:02:15 -08:00
Weves 5da766dd3b testing 2025-02-16 14:01:50 -08:00
Weves 180608694a improvements 2025-02-16 14:01:49 -08:00
Weves 96b92edfdb Parallelize IT
Parallelization

Full draft of first pass

Adjust test name

test

test

Fix

Update cmd

test

Fix

test

Test with all tests

Resource bump + limit num parallel runs

Add retries
2025-02-16 14:01:12 -08:00
19 changed files with 1501 additions and 185 deletions

View File

@@ -0,0 +1,153 @@
name: Run Integration Tests v3
concurrency:
group: Run-Integration-Tests-Parallel-${{ github.workflow }}-${{ github.head_ref || github.event.workflow_run.head_branch || github.run_id }}
cancel-in-progress: true
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}
SLACK_BOT_TOKEN: ${{ secrets.SLACK_BOT_TOKEN }}
CONFLUENCE_TEST_SPACE_URL: ${{ secrets.CONFLUENCE_TEST_SPACE_URL }}
CONFLUENCE_USER_NAME: ${{ secrets.CONFLUENCE_USER_NAME }}
CONFLUENCE_ACCESS_TOKEN: ${{ secrets.CONFLUENCE_ACCESS_TOKEN }}
jobs:
integration-tests:
# See https://runs-on.com/runners/linux/
runs-on:
[runs-on, runner=32cpu-linux-x64, ram=64, "run-id=${{ github.run_id }}"]
steps:
- name: Checkout code
uses: actions/checkout@v4
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
- name: Login to Docker Hub
uses: docker/login-action@v3
with:
username: ${{ secrets.DOCKER_USERNAME }}
password: ${{ secrets.DOCKER_TOKEN }}
- name: Build integration test Docker image
uses: ./.github/actions/custom-build-and-push
with:
context: ./backend
file: ./backend/tests/integration/Dockerfile
platforms: linux/amd64
tags: danswer/danswer-integration:test
push: false
load: true
cache-from: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }}
cache-to: type=s3,prefix=cache/${{ github.repository }}/integration-tests/integration-parallel/,region=${{ env.RUNS_ON_AWS_REGION }},bucket=${{ env.RUNS_ON_S3_BUCKET_CACHE }},mode=max
- name: Run Standard Integration Tests
run: |
# Print a message indicating that tests are starting
echo "Running integration tests..."
# Create a directory for test logs that will be mounted into the container
mkdir -p ${{ github.workspace }}/test_logs
chmod 777 ${{ github.workspace }}/test_logs
# Run the integration tests in a Docker container
# Mount the Docker socket to allow Docker-in-Docker (DinD)
# Mount the test_logs directory to capture logs
# Use host network for easier communication with other services
docker run \
-v /var/run/docker.sock:/var/run/docker.sock \
-v ${{ github.workspace }}/test_logs:/tmp \
--network host \
-e OPENAI_API_KEY=${OPENAI_API_KEY} \
-e SLACK_BOT_TOKEN=${SLACK_BOT_TOKEN} \
-e CONFLUENCE_TEST_SPACE_URL=${CONFLUENCE_TEST_SPACE_URL} \
-e CONFLUENCE_USER_NAME=${CONFLUENCE_USER_NAME} \
-e CONFLUENCE_ACCESS_TOKEN=${CONFLUENCE_ACCESS_TOKEN} \
-e TEST_WEB_HOSTNAME=test-runner \
-e MOCK_CONNECTOR_SERVER_HOST=mock_connector_server \
-e MOCK_CONNECTOR_SERVER_PORT=8001 \
danswer/danswer-integration:test \
python /app/tests/integration/run.py
continue-on-error: true
id: run_tests
- name: Check test results
run: |
if [ ${{ steps.run_tests.outcome }} == 'failure' ]; then
echo "Integration tests failed. Exiting with error."
exit 1
else
echo "All integration tests passed successfully."
fi
- name: Collect log files
if: success() || failure()
run: |
# Create a directory for logs
mkdir -p ${{ github.workspace }}/logs
mkdir -p ${{ github.workspace }}/logs/shared_services
# Copy all relevant log files from the mounted directory
cp ${{ github.workspace }}/test_logs/api_server_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/background_*.txt ${{ github.workspace }}/logs/ || true
cp ${{ github.workspace }}/test_logs/shared_model_server.txt ${{ github.workspace }}/logs/ || true
# Collect logs from shared services (Docker containers)
# Note: using a wildcard for the UUID part of the stack name
docker ps -a --filter "name=base-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# Also collect Redis container logs
docker ps -a --filter "name=redis-onyx-" --format "{{.Names}}" | while read container; do
echo "Collecting logs from $container"
docker logs $container > "${{ github.workspace }}/logs/shared_services/${container}.log" 2>&1 || true
done
# List collected logs
echo "Collected log files:"
ls -l ${{ github.workspace }}/logs/
echo "Collected shared services logs:"
ls -l ${{ github.workspace }}/logs/shared_services/
- name: Upload logs
if: success() || failure()
uses: actions/upload-artifact@v4
with:
name: integration-test-logs
path: |
${{ github.workspace }}/logs/
${{ github.workspace }}/logs/shared_services/
retention-days: 5
# save before stopping the containers so the logs can be captured
# - name: Save Docker logs
# if: success() || failure()
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack logs > docker-compose.log
# mv docker-compose.log ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v
# - name: Upload logs
# if: success() || failure()
# uses: actions/upload-artifact@v4
# with:
# name: docker-logs
# path: ${{ github.workspace }}/docker-compose.log
# - name: Stop Docker containers
# run: |
# cd deployment/docker_compose
# docker compose -f docker-compose.dev.yml -p danswer-stack down -v

View File

@@ -5,10 +5,10 @@ concurrency:
on:
merge_group:
pull_request:
branches:
- main
- "release/**"
# pull_request:
# branches:
# - main
# - "release/**"
env:
OPENAI_API_KEY: ${{ secrets.OPENAI_API_KEY }}

View File

@@ -1,6 +1,6 @@
from typing import Any, Literal
from onyx.db.engine import get_iam_auth_token
from onyx.configs.app_configs import USE_IAM_AUTH
from onyx.db.engine import SYNC_DB_API, get_iam_auth_token
from onyx.configs.app_configs import POSTGRES_DB, USE_IAM_AUTH
from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
@@ -13,12 +13,11 @@ from sqlalchemy import text
from sqlalchemy.engine.base import Connection
import os
import ssl
import asyncio
import logging
from logging.config import fileConfig
from alembic import context
from sqlalchemy.ext.asyncio import create_async_engine
from sqlalchemy import create_engine
from sqlalchemy.sql.schema import SchemaItem
from onyx.configs.constants import SSL_CERT_FILE
from shared_configs.configs import MULTI_TENANT, POSTGRES_DEFAULT_SCHEMA
@@ -133,17 +132,32 @@ def provide_iam_token_for_alembic(
cparams["ssl"] = ssl_context
async def run_async_migrations() -> None:
def run_migrations() -> None:
schema_name, create_schema, upgrade_all_tenants = get_schema_options()
engine = create_async_engine(
build_connection_string(),
# Get any environment variables passed through alembic config
env_vars = context.config.attributes.get("env_vars", {})
# Use env vars if provided, otherwise fall back to defaults
postgres_host = env_vars.get("POSTGRES_HOST", POSTGRES_HOST)
postgres_port = env_vars.get("POSTGRES_PORT", POSTGRES_PORT)
postgres_user = env_vars.get("POSTGRES_USER", POSTGRES_USER)
postgres_db = env_vars.get("POSTGRES_DB", POSTGRES_DB)
engine = create_engine(
build_connection_string(
db=postgres_db,
user=postgres_user,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
),
poolclass=pool.NullPool,
)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
@@ -152,31 +166,26 @@ async def run_async_migrations() -> None:
if upgrade_all_tenants:
tenant_schemas = get_all_tenant_ids()
for schema in tenant_schemas:
if schema is None:
continue
try:
logger.info(f"Migrating schema: {schema}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema}: {e}")
raise
else:
try:
logger.info(f"Migrating schema: {schema_name}")
async with engine.connect() as connection:
await connection.run_sync(
do_run_migrations,
schema_name=schema_name,
create_schema=create_schema,
)
with engine.connect() as connection:
do_run_migrations(connection, schema_name, create_schema)
except Exception as e:
logger.error(f"Error migrating schema {schema_name}: {e}")
raise
await engine.dispose()
engine.dispose()
def run_migrations_offline() -> None:
@@ -184,18 +193,18 @@ def run_migrations_offline() -> None:
url = build_connection_string()
if upgrade_all_tenants:
engine = create_async_engine(url)
engine = create_engine(url)
if USE_IAM_AUTH:
@event.listens_for(engine.sync_engine, "do_connect")
@event.listens_for(engine, "do_connect")
def event_provide_iam_token_for_alembic_offline(
dialect: Any, conn_rec: Any, cargs: Any, cparams: Any
) -> None:
provide_iam_token_for_alembic(dialect, conn_rec, cargs, cparams)
tenant_schemas = get_all_tenant_ids()
engine.sync_engine.dispose()
engine.dispose()
for schema in tenant_schemas:
logger.info(f"Migrating schema: {schema}")
@@ -230,7 +239,7 @@ def run_migrations_offline() -> None:
def run_migrations_online() -> None:
asyncio.run(run_async_migrations())
run_migrations()
if context.is_offline_mode():

View File

@@ -0,0 +1,6 @@
import os
SKIP_CONNECTION_POOL_WARM_UP = (
os.environ.get("SKIP_CONNECTION_POOL_WARM_UP", "").lower() == "true"
)

View File

@@ -191,9 +191,16 @@ class SqlEngine:
_app_name: str = POSTGRES_UNKNOWN_APP_NAME
@classmethod
def _init_engine(cls, **engine_kwargs: Any) -> Engine:
def _init_engine(
cls, host: str, port: str, db: str, **engine_kwargs: Any
) -> Engine:
connection_string = build_connection_string(
db_api=SYNC_DB_API, app_name=cls._app_name + "_sync", use_iam=USE_IAM_AUTH
db_api=SYNC_DB_API,
host=host,
port=port,
db=db,
app_name=cls._app_name + "_sync",
use_iam=USE_IAM_AUTH,
)
# Start with base kwargs that are valid for all pool types
@@ -231,15 +238,19 @@ class SqlEngine:
def init_engine(cls, **engine_kwargs: Any) -> None:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine(**engine_kwargs)
cls._engine = cls._init_engine(
host=engine_kwargs.get("host", POSTGRES_HOST),
port=engine_kwargs.get("port", POSTGRES_PORT),
db=engine_kwargs.get("db", POSTGRES_DB),
**engine_kwargs,
)
@classmethod
def get_engine(cls) -> Engine:
if not cls._engine:
with cls._lock:
if not cls._engine:
cls._engine = cls._init_engine()
return cls._engine
cls.init_engine()
return cls._engine # type: ignore
@classmethod
def set_app_name(cls, app_name: str) -> None:

View File

@@ -6,6 +6,7 @@ from onyx.db.search_settings import get_current_search_settings
from onyx.document_index.interfaces import DocumentIndex
from onyx.document_index.vespa.index import VespaIndex
from shared_configs.configs import MULTI_TENANT
from shared_configs.configs import VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
def get_default_document_index(
@@ -23,14 +24,27 @@ def get_default_document_index(
secondary_index_name = secondary_search_settings.index_name
secondary_large_chunks_enabled = secondary_search_settings.large_chunks_enabled
# modify index names for integration tests so that we can run many tests
# using the same Vespa instance w/o having them collide
primary_index_name = search_settings.index_name
if VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY:
primary_index_name = (
f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{primary_index_name}"
)
if secondary_index_name:
secondary_index_name = f"{VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY}_{secondary_index_name}"
# Currently only supporting Vespa
return VespaIndex(
index_name=search_settings.index_name,
index_name=primary_index_name,
secondary_index_name=secondary_index_name,
large_chunks_enabled=search_settings.large_chunks_enabled,
secondary_large_chunks_enabled=secondary_large_chunks_enabled,
multitenant=MULTI_TENANT,
httpx_client=httpx_client,
preserve_existing_indices=bool(
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
),
)
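
For reference, a minimal sketch of the naming behavior added above, using an assumed prefix value of "test_instance_3" (the shape produced by get_vector_db_prefix in the new kickoff.py further down in this diff):

# Illustrative only: mirrors the prefixing logic in the factory above.
prefix = "test_instance_3"  # assumed example value of VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY
index_name = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
primary_index_name = f"{prefix}_{index_name}" if prefix else index_name
print(primary_index_name)  # -> test_instance_3_danswer_chunk_nomic_ai_nomic_embed_text_v1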

View File

@@ -136,6 +136,7 @@ class VespaIndex(DocumentIndex):
secondary_large_chunks_enabled: bool | None,
multitenant: bool = False,
httpx_client: httpx.Client | None = None,
preserve_existing_indices: bool = False,
) -> None:
self.index_name = index_name
self.secondary_index_name = secondary_index_name
@@ -161,18 +162,18 @@ class VespaIndex(DocumentIndex):
secondary_index_name
] = secondary_large_chunks_enabled
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if MULTI_TENANT:
logger.info(
"Skipping Vespa index seup for multitenant (would wipe all indices)"
)
return None
self.preserve_existing_indices = preserve_existing_indices
deploy_url = f"{VESPA_APPLICATION_ENDPOINT}/tenant/default/prepareandactivate"
@classmethod
def create_indices(
cls,
indices: list[tuple[str, int, bool]],
application_endpoint: str = VESPA_APPLICATION_ENDPOINT,
) -> None:
"""
Create indices in Vespa based on the passed in configuration(s).
"""
deploy_url = f"{application_endpoint}/tenant/default/prepareandactivate"
logger.notice(f"Deploying Vespa application package to {deploy_url}")
vespa_schema_path = os.path.join(
@@ -185,7 +186,7 @@ class VespaIndex(DocumentIndex):
with open(services_file, "r") as services_f:
services_template = services_f.read()
schema_names = [self.index_name, self.secondary_index_name]
schema_names = [index_name for (index_name, _, _) in indices]
doc_lines = _create_document_xml_lines(schema_names)
services = services_template.replace(DOCUMENT_REPLACEMENT_PAT, doc_lines)
@@ -193,14 +194,6 @@ class VespaIndex(DocumentIndex):
SEARCH_THREAD_NUMBER_PAT, str(VESPA_SEARCHER_THREADS)
)
kv_store = get_kv_store()
needs_reindexing = False
try:
needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
with open(overrides_file, "r") as overrides_f:
overrides_template = overrides_f.read()
@@ -221,29 +214,63 @@ class VespaIndex(DocumentIndex):
schema_template = schema_f.read()
schema_template = schema_template.replace(TENANT_ID_PAT, "")
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
for index_name, index_embedding_dim, needs_reindexing in indices:
schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(index_embedding_dim))
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
zip_dict[f"schemas/{schema_names[0]}.sd"] = schema.encode("utf-8")
if self.secondary_index_name:
upcoming_schema = schema_template.replace(
DANSWER_CHUNK_REPLACEMENT_PAT, self.secondary_index_name
).replace(VESPA_DIM_REPLACEMENT_PAT, str(secondary_index_embedding_dim))
zip_dict[f"schemas/{schema_names[1]}.sd"] = upcoming_schema.encode("utf-8")
schema = add_ngrams_to_schema(schema) if needs_reindexing else schema
schema = schema.replace(TENANT_ID_PAT, "")
logger.info(
f"Creating index: {index_name} with embedding "
f"dimension: {index_embedding_dim}. Schema:\n\n {schema}"
)
zip_dict[f"schemas/{index_name}.sd"] = schema.encode("utf-8")
zip_file = in_memory_zip_from_file_bytes(zip_dict)
headers = {"Content-Type": "application/zip"}
response = requests.post(deploy_url, headers=headers, data=zip_file)
if response.status_code != 200:
logger.error(f"Failed to create Vespa indices: {response.text}")
raise RuntimeError(
f"Failed to prepare Vespa Onyx Index. Response: {response.text}"
)
def ensure_indices_exist(
self,
index_embedding_dim: int,
secondary_index_embedding_dim: int | None,
) -> None:
if self.multitenant or MULTI_TENANT: # be extra safe here
logger.info(
"Skipping Vespa index setup for multitenant (would wipe all indices)"
)
return None
# Used in IT
# NOTE: this means that we can't switch embedding models
if self.preserve_existing_indices:
logger.info("Preserving existing indices")
return None
kv_store = get_kv_store()
primary_needs_reindexing = False
try:
primary_needs_reindexing = cast(bool, kv_store.load(KV_REINDEX_KEY))
except Exception:
logger.debug("Could not load the reindexing flag. Using ngrams")
indices = [
(self.index_name, index_embedding_dim, primary_needs_reindexing),
]
if self.secondary_index_name and secondary_index_embedding_dim:
indices.append(
(self.secondary_index_name, secondary_index_embedding_dim, False)
)
self.create_indices(indices)
@staticmethod
def register_multitenant_indices(
indices: list[str],
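
A minimal usage sketch of the new create_indices classmethod, mirroring the prepare_vespa call in kickoff.py later in this diff; the schema names, embedding dimension, and tenant port are example values:

from onyx.document_index.vespa.index import VespaIndex

# Each tuple is (index_name, embedding_dim, needs_reindexing), per the signature above.
indices = [
    ("test_instance_1_danswer_chunk_nomic_ai_nomic_embed_text_v1", 768, False),
    ("test_instance_2_danswer_chunk_nomic_ai_nomic_embed_text_v1", 768, False),
]
# 19071 is the Vespa tenant/config port assumed here; kickoff.py maps a random host port to it.
VespaIndex.create_indices(indices, "http://localhost:19071/application/v2")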

View File

@@ -43,6 +43,7 @@ from onyx.configs.app_configs import USER_AUTH_SECRET
from onyx.configs.app_configs import WEB_DOMAIN
from onyx.configs.constants import AuthType
from onyx.configs.constants import POSTGRES_WEB_APP_NAME
from onyx.configs.integration_test_configs import SKIP_CONNECTION_POOL_WARM_UP
from onyx.db.engine import SqlEngine
from onyx.db.engine import warm_up_connections
from onyx.server.api_key.api import router as api_key_router
@@ -208,8 +209,11 @@ async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
if DISABLE_GENERATIVE_AI:
logger.notice("Generative AI Q&A disabled")
# fill up Postgres connection pools
await warm_up_connections()
# only used for IT. Need to skip since it overloads postgres when we have 50+
# instances running
if not SKIP_CONNECTION_POOL_WARM_UP:
# fill up Postgres connection pools
await warm_up_connections()
if not MULTI_TENANT:
# We cache this at the beginning so there is no delay in the first telemetry
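
A minimal sketch of how the flag could be passed to a locally launched API server; the port is an assumed example, and only the string "true" (case-insensitive) disables the warm-up per the new integration_test_configs module:

import os
import subprocess

# Hypothetical local launch: skip the connection-pool warm-up so that many
# parallel API server instances do not exhaust Postgres connections.
env = {**os.environ, "SKIP_CONNECTION_POOL_WARM_UP": "true"}
subprocess.Popen(
    ["uvicorn", "onyx.main:app", "--host", "localhost", "--port", "8080"],
    env=env,
)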

View File

@@ -146,10 +146,10 @@ class RedisPool:
cls._instance._init_pools()
return cls._instance
def _init_pools(self) -> None:
self._pool = RedisPool.create_pool(ssl=REDIS_SSL)
def _init_pools(self, redis_port: int = REDIS_PORT) -> None:
self._pool = RedisPool.create_pool(port=redis_port, ssl=REDIS_SSL)
self._replica_pool = RedisPool.create_pool(
host=REDIS_REPLICA_HOST, ssl=REDIS_SSL
host=REDIS_REPLICA_HOST, port=redis_port, ssl=REDIS_SSL
)
def get_client(self, tenant_id: str | None) -> Redis:

View File

@@ -251,7 +251,8 @@ def setup_vespa(
logger.notice("Vespa setup complete.")
return True
except Exception:
except Exception as e:
logger.debug(f"Error creating Vespa indices: {e}")
logger.notice(
f"Vespa setup did not succeed. The Vespa service may not be ready yet. Retrying in {WAIT_SECONDS} seconds."
)

View File

@@ -270,3 +270,10 @@ SUPPORTED_EMBEDDING_MODELS = [
index_name="danswer_chunk_intfloat_multilingual_e5_small",
),
]
"""
INTEGRATION TEST ONLY SETTINGS
"""
VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY = os.getenv(
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY"
)

View File

@@ -3,6 +3,19 @@ FROM python:3.11.7-slim-bookworm
# Currently needs all dependencies, since the ITs use some of the Onyx
# backend code.
# Add Docker's official GPG key and repository for Debian
RUN apt-get update && \
apt-get install -y ca-certificates curl && \
install -m 0755 -d /etc/apt/keyrings && \
curl -fsSL https://download.docker.com/linux/debian/gpg -o /etc/apt/keyrings/docker.asc && \
chmod a+r /etc/apt/keyrings/docker.asc && \
echo \
"deb [arch=$(dpkg --print-architecture) signed-by=/etc/apt/keyrings/docker.asc] https://download.docker.com/linux/debian \
$(. /etc/os-release && echo "$VERSION_CODENAME") stable" | \
tee /etc/apt/sources.list.d/docker.list > /dev/null && \
apt-get update
# Install system dependencies
# cmake needed for psycopg (postgres)
# libpq-dev needed for psycopg (postgres)
@@ -15,6 +28,9 @@ RUN apt-get update && \
curl \
zip \
ca-certificates \
postgresql-client \
# Install Docker for DinD
docker-ce docker-ce-cli containerd.io docker-buildx-plugin docker-compose-plugin \
libgnutls30=3.7.9-2+deb12u3 \
libblkid1=2.38.1-5+deb12u1 \
libmount1=2.38.1-5+deb12u1 \
@@ -29,37 +45,19 @@ RUN apt-get update && \
# Install Python dependencies
# Remove py which is pulled in by retry, py is not needed and is a CVE
COPY ./requirements/default.txt /tmp/requirements.txt
COPY ./requirements/model_server.txt /tmp/model_server-requirements.txt
COPY ./requirements/ee.txt /tmp/ee-requirements.txt
RUN pip install --no-cache-dir --upgrade \
--retries 5 \
--timeout 30 \
-r /tmp/requirements.txt \
-r /tmp/model_server-requirements.txt \
-r /tmp/ee-requirements.txt && \
pip uninstall -y py && \
playwright install chromium && \
playwright install-deps chromium && \
ln -s /usr/local/bin/supervisord /usr/bin/supervisord
# Cleanup for CVEs and size reduction
# https://github.com/tornadoweb/tornado/issues/3107
# xserver-common and xvfb included by playwright installation but not needed after
# perl-base is part of the base Python Debian image but not needed for Onyx functionality
# perl-base could only be removed with --allow-remove-essential
RUN apt-get update && \
apt-get remove -y --allow-remove-essential \
perl-base \
xserver-common \
xvfb \
cmake \
libldap-2.5-0 \
libxmlsec1-dev \
pkg-config \
gcc && \
apt-get install -y libxmlsec1-openssl && \
apt-get autoremove -y && \
rm -rf /var/lib/apt/lists/* && \
rm -f /usr/local/lib/python3.11/site-packages/tornado/test/test.key
# Set up application files
WORKDIR /app
@@ -76,6 +74,9 @@ COPY ./alembic.ini /app/alembic.ini
COPY ./pytest.ini /app/pytest.ini
COPY supervisord.conf /usr/etc/supervisord.conf
# need to copy over model server as well, since we're running it in the same container
COPY ./model_server /app/model_server
# Integration test stuff
COPY ./requirements/dev.txt /tmp/dev-requirements.txt
RUN pip install --no-cache-dir --upgrade \
@@ -84,5 +85,6 @@ COPY ./tests/integration /app/tests/integration
ENV PYTHONPATH=/app
ENTRYPOINT ["pytest", "-s"]
CMD ["/app/tests/integration", "--ignore=/app/tests/integration/multitenant_tests"]
ENTRYPOINT []
# let caller specify the command
CMD ["tail", "-f", "/dev/null"]

View File

@@ -1,6 +1,7 @@
import os
ADMIN_USER_NAME = "admin_user"
GUARANTEED_FRESH_SETUP = os.getenv("GUARANTEED_FRESH_SETUP") == "true"
API_SERVER_PROTOCOL = os.getenv("API_SERVER_PROTOCOL") or "http"
API_SERVER_HOST = os.getenv("API_SERVER_HOST") or "localhost"

View File

@@ -1,5 +1,9 @@
import contextlib
import io
import logging
import sys
import time
from logging import Logger
from types import SimpleNamespace
import psycopg2
@@ -11,10 +15,12 @@ from onyx.configs.app_configs import POSTGRES_HOST
from onyx.configs.app_configs import POSTGRES_PASSWORD
from onyx.configs.app_configs import POSTGRES_PORT
from onyx.configs.app_configs import POSTGRES_USER
from onyx.configs.app_configs import REDIS_PORT
from onyx.db.engine import build_connection_string
from onyx.db.engine import get_all_tenant_ids
from onyx.db.engine import get_session_context_manager
from onyx.db.engine import get_session_with_tenant
from onyx.db.engine import SqlEngine
from onyx.db.engine import SYNC_DB_API
from onyx.db.search_settings import get_current_search_settings
from onyx.db.swap_index import check_index_swap
@@ -22,17 +28,20 @@ from onyx.document_index.document_index_utils import get_multipass_config
from onyx.document_index.vespa.index import DOCUMENT_ID_ENDPOINT
from onyx.document_index.vespa.index import VespaIndex
from onyx.indexing.models import IndexingSetting
from onyx.redis.redis_pool import redis_pool
from onyx.setup import setup_postgres
from onyx.setup import setup_vespa
from onyx.utils.logger import setup_logger
from tests.integration.common_utils.timeout import run_with_timeout
logger = setup_logger()
def _run_migrations(
database_url: str,
database: str,
config_name: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
direction: str = "upgrade",
revision: str = "head",
schema: str = "public",
@@ -46,9 +55,28 @@ def _run_migrations(
alembic_cfg.attributes["configure_logger"] = False
alembic_cfg.config_ini_section = config_name
# Add environment variables to the config attributes
alembic_cfg.attributes["env_vars"] = {
"POSTGRES_HOST": postgres_host,
"POSTGRES_PORT": postgres_port,
"POSTGRES_DB": database,
# some migrations call redis directly, so we need to pass the port
"REDIS_PORT": str(redis_port),
}
alembic_cfg.cmd_opts = SimpleNamespace() # type: ignore
alembic_cfg.cmd_opts.x = [f"schema={schema}"] # type: ignore
# Build the database URL
database_url = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=postgres_host,
port=postgres_port,
db_api=SYNC_DB_API,
)
# Set the SQLAlchemy URL in the Alembic configuration
alembic_cfg.set_main_option("sqlalchemy.url", database_url)
@@ -71,6 +99,9 @@ def downgrade_postgres(
config_name: str = "alembic",
revision: str = "base",
clear_data: bool = False,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Downgrade Postgres database to base state."""
if clear_data:
@@ -81,8 +112,8 @@ def downgrade_postgres(
dbname=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
host=postgres_host,
port=postgres_port,
)
conn.autocommit = True # Need autocommit for dropping schema
cur = conn.cursor()
@@ -112,37 +143,32 @@ def downgrade_postgres(
return
# Downgrade to base
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="downgrade",
revision=revision,
)
def upgrade_postgres(
database: str = "postgres", config_name: str = "alembic", revision: str = "head"
database: str = "postgres",
config_name: str = "alembic",
revision: str = "head",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Upgrade Postgres database to latest version."""
conn_str = build_connection_string(
db=database,
user=POSTGRES_USER,
password=POSTGRES_PASSWORD,
host=POSTGRES_HOST,
port=POSTGRES_PORT,
db_api=SYNC_DB_API,
)
_run_migrations(
conn_str,
config_name,
database=database,
config_name=config_name,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
direction="upgrade",
revision=revision,
)
@@ -152,46 +178,44 @@ def reset_postgres(
database: str = "postgres",
config_name: str = "alembic",
setup_onyx: bool = True,
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
) -> None:
"""Reset the Postgres database."""
# this seems to hang due to locking issues, so run with a timeout with a few retries
NUM_TRIES = 10
TIMEOUT = 10
success = False
for _ in range(NUM_TRIES):
logger.info(f"Downgrading Postgres... ({_ + 1}/{NUM_TRIES})")
try:
run_with_timeout(
downgrade_postgres,
TIMEOUT,
kwargs={
"database": database,
"config_name": config_name,
"revision": "base",
"clear_data": True,
},
)
success = True
break
except TimeoutError:
logger.warning(
f"Postgres downgrade timed out, retrying... ({_ + 1}/{NUM_TRIES})"
)
if not success:
raise RuntimeError("Postgres downgrade failed after 10 timeouts.")
logger.info("Upgrading Postgres...")
upgrade_postgres(database=database, config_name=config_name, revision="head")
downgrade_postgres(
database=database,
config_name=config_name,
revision="base",
clear_data=True,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
upgrade_postgres(
database=database,
config_name=config_name,
revision="head",
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
if setup_onyx:
logger.info("Setting up Postgres...")
with get_session_context_manager() as db_session:
setup_postgres(db_session)
def reset_vespa() -> None:
"""Wipe all data from the Vespa index."""
def reset_vespa(
skip_creating_indices: bool, document_id_endpoint: str = DOCUMENT_ID_ENDPOINT
) -> None:
"""Wipe all data from the Vespa index.
Args:
skip_creating_indices: If True, the indices will not be recreated.
This is useful if the indices already exist and you do not want to
recreate them (e.g. when running parallel tests).
"""
with get_session_context_manager() as db_session:
# swap to the correct default model
check_index_swap(db_session)
@@ -200,18 +224,21 @@ def reset_vespa() -> None:
multipass_config = get_multipass_config(search_settings)
index_name = search_settings.index_name
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError("Could not connect to Vespa within the specified timeout.")
if not skip_creating_indices:
success = setup_vespa(
document_index=VespaIndex(
index_name=index_name,
secondary_index_name=None,
large_chunks_enabled=multipass_config.enable_large_chunks,
secondary_large_chunks_enabled=None,
),
index_setting=IndexingSetting.from_db_model(search_settings),
secondary_index_setting=None,
)
if not success:
raise RuntimeError(
"Could not connect to Vespa within the specified timeout."
)
for _ in range(5):
try:
@@ -222,7 +249,7 @@ def reset_vespa() -> None:
if continuation:
params = {**params, "continuation": continuation}
response = requests.delete(
DOCUMENT_ID_ENDPOINT.format(index_name=index_name), params=params
document_id_endpoint.format(index_name=index_name), params=params
)
response.raise_for_status()
@@ -336,11 +363,99 @@ def reset_vespa_multitenant() -> None:
time.sleep(5)
def reset_all() -> None:
def reset_all(
database: str = "postgres",
postgres_host: str = POSTGRES_HOST,
postgres_port: str = POSTGRES_PORT,
redis_port: int = REDIS_PORT,
silence_logs: bool = False,
skip_creating_indices: bool = False,
document_id_endpoint: str = DOCUMENT_ID_ENDPOINT,
) -> None:
if not silence_logs:
with contextlib.redirect_stdout(sys.stdout), contextlib.redirect_stderr(
sys.stderr
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
return
# Store original logging levels
loggers_to_silence: list[Logger] = [
logging.getLogger(), # Root logger
logging.getLogger("alembic"),
logger.logger, # Our custom logger
]
original_levels = [logger.level for logger in loggers_to_silence]
# Temporarily set all loggers to ERROR level
for log in loggers_to_silence:
log.setLevel(logging.ERROR)
stdout_redirect = io.StringIO()
stderr_redirect = io.StringIO()
try:
with contextlib.redirect_stdout(stdout_redirect), contextlib.redirect_stderr(
stderr_redirect
):
_do_reset(
database,
postgres_host,
postgres_port,
redis_port,
skip_creating_indices,
document_id_endpoint,
)
except Exception as e:
print(stdout_redirect.getvalue(), file=sys.stdout)
print(stderr_redirect.getvalue(), file=sys.stderr)
raise e
finally:
# Restore original logging levels
for logger_, level in zip(loggers_to_silence, original_levels):
logger_.setLevel(level)
def _do_reset(
database: str,
postgres_host: str,
postgres_port: str,
redis_port: int,
skip_creating_indices: bool,
document_id_endpoint: str,
) -> None:
"""NOTE: should only be be running in one worker/thread a time."""
# force re-create the engine to allow for the same worker to reset
# different databases
with SqlEngine._lock:
SqlEngine._engine = SqlEngine._init_engine(
host=postgres_host,
port=postgres_port,
db=database,
)
# same with redis
redis_pool._init_pools(redis_port=redis_port)
logger.info("Resetting Postgres...")
reset_postgres()
reset_postgres(
database=database,
postgres_host=postgres_host,
postgres_port=postgres_port,
redis_port=redis_port,
)
logger.info("Resetting Vespa...")
reset_vespa()
reset_vespa(
skip_creating_indices=skip_creating_indices,
document_id_endpoint=document_id_endpoint,
)
def reset_all_multitenant() -> None:

View File

@@ -7,6 +7,7 @@ from onyx.db.engine import get_session_context_manager
from onyx.db.search_settings import get_current_search_settings
from tests.integration.common_utils.constants import ADMIN_USER_NAME
from tests.integration.common_utils.constants import GENERAL_HEADERS
from tests.integration.common_utils.constants import GUARANTEED_FRESH_SETUP
from tests.integration.common_utils.managers.user import build_email
from tests.integration.common_utils.managers.user import DEFAULT_PASSWORD
from tests.integration.common_utils.managers.user import UserManager
@@ -14,25 +15,11 @@ from tests.integration.common_utils.reset import reset_all
from tests.integration.common_utils.reset import reset_all_multitenant
from tests.integration.common_utils.test_models import DATestUser
from tests.integration.common_utils.vespa import vespa_fixture
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
from tests.integration.introspection import load_env_vars
# Load environment variables at the module level
load_env_vars()
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
"""NOTE: for some reason using this seems to lead to misc
@@ -57,6 +44,10 @@ def vespa_client() -> vespa_fixture:
@pytest.fixture
def reset() -> None:
if GUARANTEED_FRESH_SETUP:
print("GUARANTEED_FRESH_SETUP is true, skipping reset")
return None
reset_all()
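
A minimal sketch of how the parallel runner (run.py, later in this diff) takes advantage of this: each pytest subprocess is started with GUARANTEED_FRESH_SETUP set so the fixture above becomes a no-op, and the worker performs reset_all() itself between tests. The API port is an assumed example; run.py passes each instance's real port:

import os
import subprocess

env = {
    **os.environ,
    "GUARANTEED_FRESH_SETUP": "true",  # skip the per-test reset fixture
    "API_SERVER_PORT": "8080",  # assumed example port
}
# Run from the backend directory so the test package resolves.
subprocess.run(["pytest", "tests/integration/openai_assistants_api", "-v"], env=env)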

View File

@@ -118,6 +118,7 @@ def test_google_permission_sync(
GoogleDriveService, str, DATestCCPair, DATestUser, DATestUser, DATestUser
],
) -> None:
print("Running test_google_permission_sync")
(
drive_service,
drive_id,

View File

@@ -0,0 +1,71 @@
import os
from pathlib import Path
import pytest
from _pytest.nodes import Item
def list_all_tests(directory: str | Path = ".") -> list[str]:
"""
List all pytest test functions under the specified directory.
Args:
directory: Directory path to search for tests (defaults to current directory)
Returns:
List of test function names with their module paths
"""
directory = Path(directory).absolute()
print(f"Searching for tests in: {directory}")
class TestCollector:
def __init__(self) -> None:
self.collected: list[str] = []
def pytest_collection_modifyitems(self, items: list[Item]) -> None:
for item in items:
if isinstance(item, Item):
# Get the relative path from the test file to the directory we're searching from
rel_path = Path(item.fspath).relative_to(directory)
# Remove the .py extension
module_path = str(rel_path.with_suffix(""))
# Replace directory separators with dots
module_path = module_path.replace("/", ".")
test_name = item.name
self.collected.append(f"{module_path}::{test_name}")
collector = TestCollector()
# Run pytest in collection-only mode
pytest.main(
[
str(directory),
"--collect-only",
"-q", # quiet mode
],
plugins=[collector],
)
return sorted(collector.collected)
def load_env_vars(env_file: str = ".env") -> None:
current_dir = os.path.dirname(os.path.abspath(__file__))
env_path = os.path.join(current_dir, env_file)
try:
with open(env_path, "r") as f:
for line in f:
line = line.strip()
if line and not line.startswith("#"):
key, value = line.split("=", 1)
os.environ[key] = value.strip()
print("Successfully loaded environment variables")
except FileNotFoundError:
print(f"File {env_file} not found")
if __name__ == "__main__":
tests = list_all_tests()
print("\nFound tests:")
for test in tests:
print(f"- {test}")

View File

@@ -0,0 +1,637 @@
#!/usr/bin/env python3
import atexit
import os
import random
import signal
import socket
import subprocess
import sys
import time
import uuid
from collections.abc import Callable
from multiprocessing.pool import ThreadPool
from pathlib import Path
from typing import NamedTuple
import requests
import yaml
from onyx.configs.constants import AuthType
from onyx.document_index.vespa.index import VespaIndex
BACKEND_DIR_PATH = Path(__file__).parent.parent.parent
COMPOSE_DIR_PATH = BACKEND_DIR_PATH.parent / "deployment/docker_compose"
DEFAULT_EMBEDDING_DIMENSION = 768
DEFAULT_SCHEMA_NAME = "danswer_chunk_nomic_ai_nomic_embed_text_v1"
class DeploymentConfig(NamedTuple):
instance_num: int
api_port: int
web_port: int
nginx_port: int
redis_port: int
postgres_db: str
class SharedServicesConfig(NamedTuple):
run_id: uuid.UUID
postgres_port: int
vespa_port: int
vespa_tenant_port: int
model_server_port: int
def get_random_port() -> int:
"""Find a random available port."""
while True:
port = random.randint(10000, 65535)
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as sock:
if sock.connect_ex(("localhost", port)) != 0:
return port
def cleanup_pid(pid: int) -> None:
"""Cleanup a specific PID."""
print(f"Killing process {pid}")
try:
os.kill(pid, signal.SIGTERM)
except ProcessLookupError:
print(f"Process {pid} not found")
def get_shared_services_stack_name(run_id: uuid.UUID) -> str:
return f"base-onyx-{run_id}"
def get_db_name(instance_num: int) -> str:
"""Get the database name for a given instance number."""
return f"onyx_{instance_num}"
def get_vector_db_prefix(instance_num: int) -> str:
"""Get the vector DB prefix for a given instance number."""
return f"test_instance_{instance_num}"
def setup_db(
instance_num: int,
postgres_port: int,
) -> str:
env = os.environ.copy()
# Wait for postgres to be ready
max_attempts = 10
for attempt in range(max_attempts):
try:
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
"SELECT 1",
],
env={**env, "PGPASSWORD": "password"},
check=True,
stdout=subprocess.DEVNULL,
stderr=subprocess.DEVNULL,
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise RuntimeError("Postgres failed to become ready within timeout")
time.sleep(1)
db_name = get_db_name(instance_num)
# Create the database first
subprocess.run(
[
"psql",
"-h",
"localhost",
"-p",
str(postgres_port),
"-U",
"postgres",
"-c",
f"CREATE DATABASE {db_name}",
],
env={**env, "PGPASSWORD": "password"},
check=True,
)
# NEW: Stamp this brand-new DB at 'base' so Alembic doesn't fail
subprocess.run(
[
"alembic",
"stamp",
"base",
],
env={
**env,
"PGPASSWORD": "password",
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
# Run alembic upgrade to create tables
max_attempts = 3
for attempt in range(max_attempts):
try:
subprocess.run(
["alembic", "upgrade", "head"],
env={
**env,
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
},
check=True,
cwd=str(BACKEND_DIR_PATH),
)
break
except subprocess.CalledProcessError:
if attempt == max_attempts - 1:
raise
print("Alembic upgrade failed, retrying in 5 seconds...")
time.sleep(5)
return db_name
def start_api_server(
instance_num: int,
model_server_port: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start the API server.
NOTE: assumes that Postgres is all set up (database exists, migrations ran)
"""
print("Starting API server...")
db_name = get_db_name(instance_num)
vector_db_prefix = get_vector_db_prefix(instance_num)
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": db_name,
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"MODEL_SERVER_PORT": str(model_server_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": vector_db_prefix,
"LOG_LEVEL": "debug",
"AUTH_TYPE": AuthType.BASIC,
}
)
port = get_random_port()
# Open log file for API server in /tmp
log_file = open(f"/tmp/api_server_{instance_num}.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"onyx.main:app",
"--host",
"localhost",
"--port",
str(port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
return port
def start_background(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
redis_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> None:
"""Start the background process."""
print("Starting background process...")
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"POSTGRES_DB": get_db_name(instance_num),
"REDIS_HOST": "localhost",
"REDIS_PORT": str(redis_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"VECTOR_DB_INDEX_NAME_PREFIX__INTEGRATION_TEST_ONLY": get_vector_db_prefix(
instance_num
),
"LOG_LEVEL": "debug",
}
)
str(Path(__file__).parent / "backend")
# Open log file for background process in /tmp
log_file = open(f"/tmp/background_{instance_num}.txt", "w")
process = subprocess.Popen(
["supervisord", "-n", "-c", "./supervisord.conf"],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
register_process(process)
def start_shared_services(run_id: uuid.UUID) -> SharedServicesConfig:
"""Start Postgres and Vespa using docker-compose.
Returns (postgres_port, vespa_port, vespa_tenant_port, model_server_port)
"""
print("Starting database services...")
postgres_port = get_random_port()
vespa_port = get_random_port()
vespa_tenant_port = get_random_port()
model_server_port = get_random_port()
minimal_compose = {
"services": {
"relational_db": {
"image": "postgres:15.2-alpine",
"command": "-c 'max_connections=1000'",
"environment": {
"POSTGRES_USER": os.getenv("POSTGRES_USER", "postgres"),
"POSTGRES_PASSWORD": os.getenv("POSTGRES_PASSWORD", "password"),
},
"ports": [f"{postgres_port}:5432"],
},
"index": {
"image": "vespaengine/vespa:8.277.17",
"ports": [
f"{vespa_port}:8081", # Main Vespa port
f"{vespa_tenant_port}:19071", # Tenant port
],
},
},
}
# Write the minimal compose file
temp_compose = Path("/tmp/docker-compose.minimal.yml")
with open(temp_compose, "w") as f:
yaml.dump(minimal_compose, f)
# Start the services
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
get_shared_services_stack_name(run_id),
"up",
"-d",
],
check=True,
)
# Start the shared model server
env = os.environ.copy()
env.update(
{
"POSTGRES_HOST": "localhost",
"POSTGRES_PORT": str(postgres_port),
"VESPA_HOST": "localhost",
"VESPA_PORT": str(vespa_port),
"VESPA_TENANT_PORT": str(vespa_tenant_port),
"LOG_LEVEL": "debug",
}
)
# Open log file for shared model server in /tmp
log_file = open("/tmp/shared_model_server.txt", "w")
process = subprocess.Popen(
[
"uvicorn",
"model_server.main:app",
"--host",
"0.0.0.0",
"--port",
str(model_server_port),
],
env=env,
cwd=str(BACKEND_DIR_PATH),
stdout=log_file,
stderr=subprocess.STDOUT,
)
atexit.register(cleanup_pid, process.pid)
shared_services_config = SharedServicesConfig(
run_id, postgres_port, vespa_port, vespa_tenant_port, model_server_port
)
print(f"Shared services config: {shared_services_config}")
return shared_services_config
def prepare_vespa(instance_ids: list[int], vespa_tenant_port: int) -> None:
schema_names = [
(
f"{get_vector_db_prefix(instance_id)}_{DEFAULT_SCHEMA_NAME}",
DEFAULT_EMBEDDING_DIMENSION,
False,
)
for instance_id in instance_ids
]
print(f"Creating indices: {schema_names}")
for _ in range(7):
try:
VespaIndex.create_indices(
schema_names, f"http://localhost:{vespa_tenant_port}/application/v2"
)
return
except Exception as e:
print(f"Error creating indices: {e}. Trying again in 5 seconds...")
time.sleep(5)
raise RuntimeError("Failed to create indices in Vespa")
def start_redis(
instance_num: int,
register_process: Callable[[subprocess.Popen], None],
) -> int:
"""Start a Redis instance for a specific deployment."""
print(f"Starting Redis for instance {instance_num}...")
redis_port = get_random_port()
container_name = f"redis-onyx-{instance_num}"
# Start Redis using docker run
subprocess.run(
[
"docker",
"run",
"-d",
"--name",
container_name,
"-p",
f"{redis_port}:6379",
"redis:7.4-alpine",
"redis-server",
"--save",
'""',
"--appendonly",
"no",
],
check=True,
)
return redis_port
def launch_instance(
instance_num: int,
postgres_port: int,
vespa_port: int,
vespa_tenant_port: int,
model_server_port: int,
register_process: Callable[[subprocess.Popen], None],
) -> DeploymentConfig:
"""Launch a Docker Compose instance with custom ports."""
api_port = get_random_port()
web_port = get_random_port()
nginx_port = get_random_port()
# Start Redis for this instance
redis_port = start_redis(instance_num, register_process)
try:
db_name = setup_db(instance_num, postgres_port)
api_port = start_api_server(
instance_num,
model_server_port, # Use shared model server port
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
start_background(
instance_num,
postgres_port,
vespa_port,
vespa_tenant_port,
redis_port,
register_process,
)
except Exception as e:
print(f"Failed to start API server for instance {instance_num}: {e}")
raise
return DeploymentConfig(
instance_num, api_port, web_port, nginx_port, redis_port, db_name
)
def wait_for_instance(
ports: DeploymentConfig, max_attempts: int = 120, wait_seconds: int = 2
) -> None:
"""Wait for an instance to be healthy."""
print(f"Waiting for instance {ports.instance_num} to be ready...")
for attempt in range(1, max_attempts + 1):
try:
response = requests.get(f"http://localhost:{ports.api_port}/health")
if response.status_code == 200:
print(
f"Instance {ports.instance_num} is ready on port {ports.api_port}"
)
return
raise ConnectionError(
f"Health check returned status {response.status_code}"
)
except (requests.RequestException, ConnectionError):
if attempt == max_attempts:
raise TimeoutError(
f"Timeout waiting for instance {ports.instance_num} "
f"on port {ports.api_port}"
)
print(
f"Waiting for instance {ports.instance_num} on port "
f" {ports.api_port}... ({attempt}/{max_attempts})"
)
time.sleep(wait_seconds)
def cleanup_instance(instance_num: int) -> None:
"""Cleanup a specific instance."""
print(f"Cleaning up instance {instance_num}...")
temp_compose = Path(f"/tmp/docker-compose.dev.instance{instance_num}.yml")
try:
subprocess.run(
[
"docker",
"compose",
"-f",
str(temp_compose),
"-p",
f"onyx-stack-{instance_num}",
"down",
],
check=True,
)
print(f"Instance {instance_num} cleaned up successfully")
except subprocess.CalledProcessError:
print(f"Error cleaning up instance {instance_num}")
except FileNotFoundError:
print(f"No compose file found for instance {instance_num}")
finally:
# Clean up the temporary compose file if it exists
if temp_compose.exists():
temp_compose.unlink()
print(f"Removed temporary compose file for instance {instance_num}")
def run_x_instances(
num_instances: int,
) -> tuple[SharedServicesConfig, list[DeploymentConfig]]:
"""Start x instances of the application and return their configurations."""
run_id = uuid.uuid4()
instance_ids = list(range(1, num_instances + 1))
_pids: list[int] = []
def register_process(process: subprocess.Popen) -> None:
_pids.append(process.pid)
def cleanup_all_instances() -> None:
"""Cleanup all instances."""
print("Cleaning up all instances...")
# Stop the database services
subprocess.run(
[
"docker",
"compose",
"-p",
get_shared_services_stack_name(run_id),
"-f",
"/tmp/docker-compose.minimal.yml",
"down",
],
check=True,
)
# Stop and remove all Redis containers
for instance_id in range(1, num_instances + 1):
container_name = f"redis-onyx-{instance_id}"
try:
subprocess.run(["docker", "rm", "-f", container_name], check=True)
except subprocess.CalledProcessError:
print(f"Error cleaning up Redis container {container_name}")
for pid in _pids:
cleanup_pid(pid)
# Register cleanup handler
atexit.register(cleanup_all_instances)
# Start database services first
print("Starting shared services...")
shared_services_config = start_shared_services(run_id)
# create documents
print("Creating indices in Vespa...")
prepare_vespa(instance_ids, shared_services_config.vespa_tenant_port)
# Use ThreadPool to launch instances in parallel and collect results
# NOTE: only kick off 10 at a time to avoid overwhelming the system
print("Launching instances...")
with ThreadPool(processes=len(instance_ids)) as pool:
# Create list of arguments for each instance
launch_args = [
(
i,
shared_services_config.postgres_port,
shared_services_config.vespa_port,
shared_services_config.vespa_tenant_port,
shared_services_config.model_server_port,
register_process,
)
for i in instance_ids
]
# Launch instances and get results
port_configs = pool.starmap(launch_instance, launch_args)
# Wait for all instances to be healthy
print("Waiting for instances to be healthy...")
with ThreadPool(processes=len(port_configs)) as pool:
pool.map(wait_for_instance, port_configs)
print("All instances launched!")
print("Database Services:")
print(f"Postgres port: {shared_services_config.postgres_port}")
print(f"Vespa main port: {shared_services_config.vespa_port}")
print(f"Vespa tenant port: {shared_services_config.vespa_tenant_port}")
print("\nApplication Instances:")
for ports in port_configs:
print(
f"Instance {ports.instance_num}: "
f"API={ports.api_port}, Web={ports.web_port}, Nginx={ports.nginx_port}"
)
return shared_services_config, port_configs
def main() -> None:
shared_services_config, port_configs = run_x_instances(1)
# Run pytest with the API server port set
api_port = port_configs[0].api_port # Use first instance's API port
try:
subprocess.run(
["pytest", "tests/integration/openai_assistants_api"],
env={**os.environ, "API_SERVER_PORT": str(api_port)},
cwd=str(BACKEND_DIR_PATH),
check=True,
)
except subprocess.CalledProcessError as e:
print(f"Tests failed with exit code {e.returncode}")
sys.exit(e.returncode)
time.sleep(5)
if __name__ == "__main__":
main()
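
A minimal local driver sketch, assuming Docker and the backend Python environment are available and the snippet is run from the backend directory so tests.integration.kickoff is importable:

from tests.integration.kickoff import run_x_instances

# Boots shared Postgres/Vespa/model server plus two app instances; cleanup is
# handled by the atexit hooks registered inside run_x_instances.
shared, deployments = run_x_instances(2)
print(f"Postgres on {shared.postgres_port}, Vespa on {shared.vespa_port}")
for d in deployments:
    print(f"instance {d.instance_num}: API={d.api_port}, DB={d.postgres_db}")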

View File

@@ -0,0 +1,266 @@
import multiprocessing
import os
import queue
import subprocess
import sys
import threading
import time
from dataclasses import dataclass
from multiprocessing.synchronize import Lock as LockType
from pathlib import Path
from tests.integration.common_utils.reset import reset_all
from tests.integration.introspection import list_all_tests
from tests.integration.introspection import load_env_vars
from tests.integration.kickoff import BACKEND_DIR_PATH
from tests.integration.kickoff import DeploymentConfig
from tests.integration.kickoff import run_x_instances
from tests.integration.kickoff import SharedServicesConfig
@dataclass
class TestResult:
test_name: str
success: bool
output: str
error: str | None = None
def run_single_test(
test_name: str,
deployment_config: DeploymentConfig,
shared_services_config: SharedServicesConfig,
result_queue: multiprocessing.Queue,
) -> None:
"""Run a single test with the given API port."""
test_path, test_name = test_name.split("::")
processed_test_name = f"{test_path.replace('.', '/')}.py::{test_name}"
print(f"Running test: {processed_test_name}", flush=True)
try:
env = {
**os.environ,
"API_SERVER_PORT": str(deployment_config.api_port),
"PYTHONPATH": ".",
"GUARANTEED_FRESH_SETUP": "true",
"POSTGRES_PORT": str(shared_services_config.postgres_port),
"POSTGRES_DB": deployment_config.postgres_db,
"REDIS_PORT": str(deployment_config.redis_port),
"VESPA_PORT": str(shared_services_config.vespa_port),
"VESPA_TENANT_PORT": str(shared_services_config.vespa_tenant_port),
}
result = subprocess.run(
["pytest", processed_test_name, "-v"],
env=env,
cwd=str(BACKEND_DIR_PATH),
capture_output=True,
text=True,
)
result_queue.put(
TestResult(
test_name=test_name,
success=result.returncode == 0,
output=result.stdout,
error=result.stderr if result.returncode != 0 else None,
)
)
except Exception as e:
result_queue.put(
TestResult(
test_name=test_name,
success=False,
output="",
error=str(e),
)
)
def worker(
test_queue: queue.Queue[str],
instance_queue: queue.Queue[int],
result_queue: multiprocessing.Queue,
shared_services_config: SharedServicesConfig,
deployment_configs: list[DeploymentConfig],
reset_lock: LockType,
) -> None:
"""Worker process that runs tests on available instances."""
while True:
# Get the next test from the queue
try:
test = test_queue.get(block=False)
except queue.Empty:
break
# Get an available instance
instance_idx = instance_queue.get()
deployment_config = deployment_configs[
instance_idx - 1
] # Convert to 0-based index
try:
# Run the test
run_single_test(
test, deployment_config, shared_services_config, result_queue
)
# get instance ready for next test
print(
f"Resetting instance for next. DB: {deployment_config.postgres_db}, "
f"Port: {shared_services_config.postgres_port}"
)
# alembic is NOT thread-safe, so we need to make sure only one worker is resetting at a time
with reset_lock:
reset_all(
database=deployment_config.postgres_db,
postgres_port=str(shared_services_config.postgres_port),
redis_port=deployment_config.redis_port,
silence_logs=True,
# indices are created during the kickoff process, no need to recreate them
skip_creating_indices=True,
# use the special vespa port
document_id_endpoint=(
f"http://localhost:{shared_services_config.vespa_port}"
"/document/v1/default/{{index_name}}/docid"
),
)
except Exception as e:
# Log the error and put it in the result queue
error_msg = f"Critical error in worker thread for test {test}: {str(e)}"
print(error_msg, file=sys.stderr)
result_queue.put(
TestResult(
test_name=test,
success=False,
output="",
error=error_msg,
)
)
# Re-raise to stop the worker
raise
finally:
# Put the instance back in the queue
instance_queue.put(instance_idx)
test_queue.task_done()
def main() -> None:
NUM_INSTANCES = 7
# Get all tests
prefixes = ["tests", "connector_job_tests"]
tests = []
for prefix in prefixes:
tests += [
f"tests/integration/{prefix}/{test_path}"
for test_path in list_all_tests(Path(__file__).parent / prefix)
]
print(f"Found {len(tests)} tests to run")
# load env vars which will be passed into the tests
load_env_vars(os.environ.get("IT_ENV_FILE_PATH", ".env"))
# For debugging
# tests = [test for test in tests if "openai_assistants_api" in test]
# tests = tests[:2]
print(f"Running {len(tests)} tests")
# Start all instances at once
shared_services_config, deployment_configs = run_x_instances(NUM_INSTANCES)
# Create queues and lock
test_queue: queue.Queue[str] = queue.Queue()
instance_queue: queue.Queue[int] = queue.Queue()
result_queue: multiprocessing.Queue = multiprocessing.Queue()
reset_lock: LockType = multiprocessing.Lock()
# Fill the instance queue with available instance numbers
for i in range(1, NUM_INSTANCES + 1):
instance_queue.put(i)
# Fill the test queue with all tests
for test in tests:
test_queue.put(test)
# Start worker threads
workers = []
for _ in range(NUM_INSTANCES):
worker_thread = threading.Thread(
target=worker,
args=(
test_queue,
instance_queue,
result_queue,
shared_services_config,
deployment_configs,
reset_lock,
),
)
worker_thread.start()
workers.append(worker_thread)
# Monitor workers and fail fast if any die
try:
while any(w.is_alive() for w in workers):
# Check if all tests are done
if test_queue.empty() and all(not w.is_alive() for w in workers):
break
# Check for dead workers that died with unfinished tests
if not test_queue.empty() and any(not w.is_alive() for w in workers):
print(
"\nCritical: Worker thread(s) died with tests remaining!",
file=sys.stderr,
)
sys.exit(1)
time.sleep(0.1) # Avoid busy waiting
# Collect results
print("Collecting results")
results: list[TestResult] = []
while not result_queue.empty():
results.append(result_queue.get())
# Print results
print("\nTest Results:")
failed = False
failed_tests: list[str] = []
total_tests = len(results)
passed_tests = 0
for result in results:
status = "✅ PASSED" if result.success else "❌ FAILED"
print(f"{status} - {result.test_name}")
if result.success:
passed_tests += 1
else:
failed = True
failed_tests.append(result.test_name)
print("Error output:")
print(result.error)
print("Test output:")
print(result.output)
print("-" * 80)
# Print summary
print("\nTest Summary:")
print(f"Total Tests: {total_tests}")
print(f"Passed: {passed_tests}")
print(f"Failed: {len(failed_tests)}")
if failed_tests:
print("\nFailed Tests:")
for test_name in failed_tests:
print(f"{test_name}")
print()
if failed:
sys.exit(1)
except KeyboardInterrupt:
print("\nTest run interrupted by user", file=sys.stderr)
sys.exit(130) # Standard exit code for SIGINT
except Exception as e:
print(f"\nCritical error during result collection: {str(e)}", file=sys.stderr)
sys.exit(1)
if __name__ == "__main__":
main()